# llmServer/bot.py
#! /usr/bin/env nix-shell
#! nix-shell -i python3 -p python3Packages.tinygrad python3Packages.numpy python3Packages.discordpy python3Packages.transformers python3Packages.tqdm python3Packages.flask
import queue
import flask
from flask import request, Response, jsonify
from tinygrad import Tensor, TinyJit, dtypes, Device
from tinygrad.nn.state import safe_load, load_state_dict
from transformers import AutoTokenizer
from model import Transformer
from tqdm import tqdm
import threading
import json
import time
import uuid
# Model hyperparameters. block_size is the fixed context length every
# batch row is padded to before a forward pass; encoding names the
# HuggingFace tokenizer the model was trained with.
hypr = {
"embed_size": 768, "n_heads": 8, "n_blocks": 12, "block_size": 512,
"encoding": "TinyLlama/TinyLlama_v1.1"
}
CHECKPOINT_PATH = 'gpt.safetensors'
# Work queue: the Flask request threads put messageOBJ tasks here and the
# single tinygrad inference thread consumes them.
msg_q = queue.Queue()
encoding = AutoTokenizer.from_pretrained(hypr['encoding'])
model = Transformer(encoding.vocab_size, hypr["embed_size"], hypr["n_heads"], hypr["n_blocks"], hypr["block_size"])
load_state_dict(model, safe_load(CHECKPOINT_PATH))
# Inference-only server: disable training-mode behavior (e.g. dropout).
Tensor.training = False
class messageOBJ:
    """One in-flight chat completion.

    Holds the prompt tokens, generation limits, and a queue through which
    the inference thread hands produced tokens to the HTTP handlers.
    """

    def __init__(self, chat, maxGen=50, temp=0.7):
        # Generated tokens are delivered here; a trailing None sentinel
        # marks end-of-generation (see finish()).
        self.outputQueue = queue.Queue()
        self.chat = encoding.encode(chat)  # list of token ids until __readyChat__
        self.ready = False
        self.chatLen = 0      # current context length, set by getTensor()
        self.maxGen = maxGen  # maximum number of tokens to generate
        self.runGen = 0       # tokens generated so far
        self.temp = temp      # softmax temperature for sampling

    def __readyChat__(self):
        """ Must only be called on tinygrad thread """
        # Lazily convert the token-id list to a Tensor on first use so the
        # constructor stays cheap on the Flask request thread.
        if not self.ready:
            self.chat = Tensor(self.chat)
            self.ready = True

    def getTensor(self):
        """ Must only be called on tinygrad thread """
        self.__readyChat__()
        # Keep only the most recent block_size tokens so pad_len below can
        # never go negative once the conversation outgrows the context
        # (this implements the former "TODO trim chat tensor").
        if self.chat.shape[0] > hypr['block_size']:
            self.chat = self.chat[-hypr['block_size']:]
        self.chatLen = self.chat.shape[0]
        pad_len = hypr['block_size'] - self.chatLen
        return self.chat.pad((0, pad_len))

    def step(self):
        """Count one generation step; True while more tokens may be produced."""
        self.runGen += 1
        # <= (not <) so max_tokens of N yields exactly N tokens, not N-1.
        return self.runGen <= self.maxGen

    def add(self, token):
        # Append the sampled token to the context and hand its scalar value
        # to whichever HTTP handler is draining this chat's queue.
        self.chat = self.chat.cat(token)
        self.outputQueue.put(token.numpy()[0])

    def finish(self):
        # End-of-generation sentinel consumed by fullReturn/chunkReturn.
        self.outputQueue.put(None)
@TinyJit
def run_model(input_buffer):
    """ run model on gpu """
    # TinyJit wrapper: presumably lets tinygrad reuse the compiled kernel
    # graph across calls with identically-shaped batches — confirm against
    # tinygrad docs.
    return model(input_buffer)
def inference_worker():
    """Consume messageOBJ tasks from msg_q and generate tokens in batches.

    Runs forever on the tinygrad thread. Maintains a fixed number of batch
    slots; idle slots are padded with zero rows so the model always sees a
    (BatchSize, block_size) input.
    """
    BatchSize = 8
    slots = [None] * BatchSize  # active chats; None marks a free slot
    while True:
        # Block for work when the batch is completely idle (avoids a busy
        # spin), then admit as many queued chats as there are free slots.
        # (Previously only one chat was admitted per model step, starving
        # a busy queue.)
        if all(slot is None for slot in slots):
            slots[0] = msg_q.get()
        for idx in range(BatchSize):
            # This thread is the only consumer, so empty()->get() is safe.
            if slots[idx] is None and not msg_q.empty():
                slots[idx] = msg_q.get()
        # Build the padded input rows; idle slots contribute zeros.
        batch = []
        for slot in slots:
            if slot is None:
                batch.append(Tensor.zeros(hypr['block_size']))
            else:
                batch.append(slot.getTensor())
        chat_tensor = batch[0].stack(*batch[1:])
        # Use the TinyJit wrapper defined above (it was dead code before;
        # the worker called model() directly and never got JIT reuse).
        logits = run_model(chat_tensor)
        for idx in range(BatchSize):
            msgobj = slots[idx]
            if msgobj is None:
                continue
            if msgobj.step():
                # Sample the next token from the logits at the last real
                # (unpadded) position of this chat.
                tok = (logits[idx, msgobj.chatLen-1, :] / msgobj.temp).softmax().multinomial(1)
                msgobj.add(tok)
            else:
                msgobj.finish()
                slots[idx] = None
# Flask application exposing an OpenAI-compatible chat completions API.
app = flask.Flask(__name__)
def fullReturn(msgobj, model):
    """Block until generation finishes, then return the whole completion
    as a single OpenAI-style chat.completion JSON response."""
    # Drain the worker's output queue; None is the end-of-generation sentinel.
    ids = []
    while (token := msgobj.outputQueue.get()) is not None:
        ids.append(token)
    body = {
        "id": f"chatcmpl-{uuid.uuid4()}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": model,
        "choices": [{
            "index": 0,
            "message": {
                "role": "assistant",
                "content": encoding.decode(ids)
            },
            "finish_reason": "stop"
        }]
    }
    return jsonify(body)
def chunkReturn(msgobj, model):
    """Yield OpenAI-style chat.completion.chunk server-sent events as
    tokens arrive, terminated by the [DONE] marker."""
    chat_id = f"chatcmpl-{uuid.uuid4()}"
    token_ids = []
    emitted = ""
    while (token_id := msgobj.outputQueue.get()) is not None:
        token_ids.append(token_id)
        # Re-decode the whole sequence each step so BPE/multi-byte merges
        # render correctly; the delta is whatever text is newly visible.
        decoded = encoding.decode(token_ids)
        delta = decoded[len(emitted):]
        emitted = decoded
        payload = {
            "id": chat_id,
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "model": model,
            "choices": [{
                "index": 0,
                "delta": {"content": delta},
                "finish_reason": None
            }]
        }
        yield f"data: {json.dumps(payload)}\n\n"
    yield "data: [DONE]\n\n"
@app.route('/v1/chat/completions', methods=['POST'])
def completions():
    """OpenAI-compatible chat completions endpoint.

    Enqueues the last user message for the inference worker and returns
    either a streamed SSE response or a full JSON completion.
    """
    data = request.json
    messageArray = data.get("messages", [])
    stream = data.get("stream", False)
    model = data.get("model", "RatChat")
    # NOTE(review): temperature=0 would divide by zero in the sampler;
    # callers are assumed to send temperature > 0 — confirm.
    temp = data.get("temperature", 0.7)
    maxGen = data.get("max_tokens", 50)
    # Only the final message is used as the prompt; earlier turns are ignored.
    chat = messageArray[-1]["content"] if messageArray else ""
    msgobj = messageOBJ(chat, maxGen, temp)
    msg_q.put(msgobj)
    if stream:
        # Bug fix: the header name was "Content-Type:" (stray colon inside
        # the string), so the SSE content type was never actually set.
        return chunkReturn(msgobj, model), {"Content-Type": "text/event-stream"}
    else:
        return fullReturn(msgobj, model)
@app.route('/v1/models', methods=['GET'])
def list_models():
    """Return the single served model in OpenAI /v1/models list format."""
    model_entry = {
        "id": "RatChat",
        "object": "model",
        "created": int(time.time()),
        "owned_by": "tinygrad"
    }
    return jsonify({"object": "list", "data": [model_entry]})
def apiStart():
    """Run the Flask development server; blocks the calling thread."""
    app.run()
if __name__ == "__main__":
    # Report which tinygrad backend was selected (e.g. GPU/CPU).
    print(Device.DEFAULT)
    print("warming up")
    #warmup(200)
    # Serve HTTP on a daemon thread; the tinygrad inference loop owns the
    # main thread and runs forever.
    t = threading.Thread(target=apiStart, daemon=True)
    t.start()
    inference_worker()