#! /usr/bin/env nix-shell
#! nix-shell -i python3 -p python3Packages.tinygrad python3Packages.numpy python3Packages.discordpy python3Packages.transformers python3Packages.tqdm python3Packages.flask

import queue
import threading
import json
import time
import uuid

import flask
from flask import request, Response, jsonify
from tinygrad import Tensor, TinyJit, dtypes, Device
from tinygrad.nn.state import safe_load, load_state_dict
from transformers import AutoTokenizer

from model import Transformer

# hyperparameters; must match the checkpoint being loaded
hypr = {
    "embed_size": 768, "n_heads": 8, "n_blocks": 12, "block_size": 512,
    "encoding": "TinyLlama/TinyLlama_v1.1"
}

CHECKPOINT_PATH = 'gpt.safetensors'

# pending requests; filled by the Flask handlers, drained by the inference worker
msg_q = queue.Queue()

encoding = AutoTokenizer.from_pretrained(hypr['encoding'])
model = Transformer(encoding.vocab_size, hypr["embed_size"], hypr["n_heads"], hypr["n_blocks"], hypr["block_size"])
load_state_dict(model, safe_load(CHECKPOINT_PATH))
Tensor.training = False  # inference only

class messageOBJ:
    """ one in-flight chat request; sampled tokens stream out through outputQueue """
    def __init__(self, chat, maxGen=50, temp=0.7):
        self.outputQueue = queue.Queue()
        self.chat = encoding.encode(chat)
        self.ready = False
        self.chatLen = 0
        self.maxGen = maxGen
        self.runGen = 0
        self.temp = temp

    def __readyChat__(self):
        """ Must only be called on the tinygrad thread """
        if not self.ready:
            self.chat = Tensor(self.chat)
            self.ready = True
        #TODO trim chat tensor

    def getTensor(self):
        """ Must only be called on the tinygrad thread """
        self.__readyChat__()
        self.chatLen = self.chat.shape[0]
        pad_len = hypr['block_size'] - self.chatLen
        return self.chat.pad((0, pad_len))

    def step(self):
        """ count one generated token; returns False once maxGen is reached """
        self.runGen += 1
        return self.runGen < self.maxGen

    def add(self, token):
        """ append a sampled token and hand it to the waiting HTTP handler """
        self.chat = self.chat.cat(token)
        self.outputQueue.put(token.numpy()[0])

    def finish(self):
        # None is the end-of-stream sentinel
        self.outputQueue.put(None)

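# Request lifecycle: a Flask handler wraps the prompt in a messageOBJ and puts it
# on msg_q; inference_worker batches it with up to BatchSize-1 other requests and
# samples tokens, which flow back through outputQueue until the None sentinel.
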
@TinyJit
def run_model(input_buffer):
    """ run the model on the GPU; TinyJit caches the compiled kernels, which is
        why every batch is padded to the same (BatchSize, block_size) shape """
    return model(input_buffer)

def inference_worker():
    """ consume requests from the queue and generate tokens in batches """
    BatchSize = 8
    NewList = [None] * BatchSize  # active requests, one slot per batch row

    while True:
        # admit one new request per iteration; block for work only when idle
        if (not msg_q.empty() and None in NewList) or NewList.count(None) == len(NewList):
            i = NewList.index(None)
            NewList[i] = msg_q.get()

        batch = []
        for i in range(BatchSize):
            if NewList[i] is None:
                # empty slot: pad the batch row with zeros
                # (int dtype so it stacks cleanly with the token tensors)
                t = Tensor.zeros(hypr['block_size'], dtype=dtypes.int32)
            else:
                t = NewList[i].getTensor()
            batch.append(t)

        chat_tensor = batch[0].stack(*batch[1:])
        logits = run_model(chat_tensor)  # use the jitted wrapper defined above

        for i in range(BatchSize):
            if NewList[i] is None:
                continue
            msgobj = NewList[i]
            if msgobj.step():
                # sample the next token from the logits at the last real position
                tok = (logits[i, msgobj.chatLen - 1, :] / msgobj.temp).softmax().multinomial(1)
                msgobj.add(tok)
            else:
                msgobj.finish()
                NewList[i] = None

app = flask.Flask(__name__)

def fullReturn(msgobj, model):
    """ block until generation finishes, then return one chat.completion response
        (model here is the model-name string echoed back to the client) """
    ids = []
    while True:
        token = msgobj.outputQueue.get()
        if token is None:  # end-of-stream sentinel from finish()
            break
        ids.append(token)

    return jsonify({
        "id": f"chatcmpl-{uuid.uuid4()}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": model,
        "choices": [{
            "index": 0,
            "message": {
                "role": "assistant",
                "content": encoding.decode(ids)
            },
            "finish_reason": "stop"
        }]
    })

def chunkReturn(msgobj, model):
    """ yield OpenAI-style server-sent events as tokens arrive """
    chat_id = f"chatcmpl-{uuid.uuid4()}"
    ids = []
    old = ""
    while True:
        token_id = msgobj.outputQueue.get()
        if token_id is None:  # end-of-stream sentinel
            break
        ids.append(token_id)
        # re-decode the whole sequence and diff against the previous text, so
        # characters spanning multiple tokens decode correctly
        tmp = encoding.decode(ids)
        word = tmp[len(old):]
        old = tmp

        chunk = {
            "id": chat_id,
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "model": model,
            "choices": [{
                "index": 0,
                "delta": {"content": word},
                "finish_reason": None
            }]
        }
        yield f"data: {json.dumps(chunk)}\n\n"
    yield "data: [DONE]\n\n"

@app.route('/v1/chat/completions', methods=['POST'])
def completions():
    data = request.json
    messageArray = data.get("messages", [])
    stream = data.get("stream", False)
    model = data.get("model", "RatChat")
    temp = data.get("temperature", 0.7)
    maxGen = data.get("max_tokens", 50)
    # only the latest message is used as the prompt
    chat = messageArray[-1]["content"] if messageArray else ""
    msgobj = messageOBJ(chat, maxGen, temp)
    msg_q.put(msgobj)

    if stream:
        return Response(chunkReturn(msgobj, model), mimetype="text/event-stream")
    else:
        return fullReturn(msgobj, model)

@app.route('/v1/models', methods=['GET'])
def list_models():
    return jsonify({
        "object": "list",
        "data": [
            {
                "id": "RatChat",
                "object": "model",
                "created": int(time.time()),
                "owned_by": "tinygrad"
            }
        ]
    })

def apiStart():
    """ start the HTTP API (runs on a daemon thread) """
    app.run()

if __name__ == "__main__":
    print(Device.DEFAULT)
    print("warming up")
    #warmup(200)
    # Flask serves on a daemon thread; the main thread stays the tinygrad thread
    t = threading.Thread(target=apiStart, daemon=True)
    t.start()
    inference_worker()
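
# Example usage (a sketch; assumes Flask's default bind of 127.0.0.1:5000):
#
#   curl http://127.0.0.1:5000/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "RatChat", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 50}'
#
#   # streaming variant (server-sent events)
#   curl -N http://127.0.0.1:5000/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "RatChat", "messages": [{"role": "user", "content": "hi"}], "stream": true}'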