Basic version of the OpenAI API implemented
This commit is contained in:
195
bot.py
195
bot.py
@@ -3,13 +3,16 @@
|
||||
|
||||
import queue
|
||||
import flask
|
||||
from flask import request, Response, jsonify
|
||||
from tinygrad import Tensor, TinyJit, dtypes, Device
|
||||
from tinygrad.nn.state import safe_load, load_state_dict
|
||||
from transformers import AutoTokenizer
|
||||
from model import Transformer
|
||||
from tqdm import tqdm
|
||||
import threading
|
||||
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
|
||||
hypr = {
|
||||
"embed_size": 768, "n_heads": 8, "n_blocks": 12, "block_size": 512,
|
||||
@@ -25,6 +28,41 @@ model = Transformer(encoding.vocab_size, hypr["embed_size"], hypr["n_heads"], hy
|
||||
# Restore the pretrained weights from the checkpoint into the model.
load_state_dict(model, safe_load(CHECKPOINT_PATH))
# Inference only: keep tinygrad's global training flag off for every op below.
Tensor.training = False
|
||||
|
||||
class messageOBJ:
    """One chat-completion request.

    Holds the prompt tokens, the generation state, and an output queue
    that the HTTP thread drains while the tinygrad worker pushes newly
    sampled token ids into it. A ``None`` on the queue marks end-of-stream.
    """

    def __init__(self, chat, maxGen=50, temp=0.7):
        # Tokens produced by the worker land here; None is the end sentinel.
        self.outputQueue = queue.Queue()
        # Prompt stays a plain list of token ids until the tinygrad thread
        # converts it to a Tensor (see __readyChat__).
        self.chat = encoding.encode(chat)
        self.ready = False
        self.chatLen = 0
        self.maxGen = maxGen  # maximum number of tokens to generate
        self.runGen = 0       # tokens generated so far
        self.temp = temp      # sampling temperature

    def __readyChat__(self):
        """ Must only be called on tinygrad thread """
        if not self.ready:
            self.chat = Tensor(self.chat)
            self.ready = True

    def getTensor(self):
        """ Must only be called on tinygrad thread """
        self.__readyChat__()
        # Fix for the old "#TODO trim chat tensor": keep only the most
        # recent block_size tokens so pad_len below can never go negative
        # once the conversation outgrows the context window.
        if self.chat.shape[0] > hypr['block_size']:
            self.chat = self.chat[-hypr['block_size']:]
        self.chatLen = self.chat.shape[0]
        pad_len = hypr['block_size'] - self.chatLen
        return self.chat.pad((0, pad_len))

    def step(self):
        """Advance the generation counter; True while more tokens are allowed."""
        self.runGen += 1
        return self.runGen < self.maxGen

    def add(self, token):
        """Append a sampled token tensor to the chat and publish its id."""
        self.chat = self.chat.cat(token)
        self.outputQueue.put(token.numpy()[0])

    def finish(self):
        """Signal end-of-stream to whoever is draining outputQueue."""
        self.outputQueue.put(None)
|
||||
|
||||
@TinyJit
|
||||
def run_model(input_buffer):
|
||||
@@ -33,15 +71,13 @@ def run_model(input_buffer):
|
||||
|
||||
def inference_worker():
    """Consume messageOBJ tasks from msg_q and generate tokens in batches.

    Runs forever on the tinygrad thread. Up to BatchSize requests are kept
    active in NewList; every iteration builds a fixed-size padded batch,
    runs one model forward pass, then samples one token per active request.

    NOTE(review): reconstructed from diff residue — the source contained
    both the old (queue, tokens) tuple protocol and the new messageOBJ
    protocol; only the new protocol is kept here.
    """
    BatchSize = 8
    NewList = [None] * BatchSize  # active requests; None = free slot

    while True:
        # Admit one new request when a slot is free and work is queued;
        # when every slot is idle, block on the queue instead of spinning.
        if (not msg_q.empty() and None in NewList) or NewList.count(None) == len(NewList):
            i = NewList.index(None)
            NewList[i] = msg_q.get()

        # Build a fixed-size batch; free slots contribute a zero tensor.
        batch = []
        for i in range(BatchSize):
            if not NewList[i]:
                t = Tensor.zeros(hypr['block_size'])
            else:
                msgobj = NewList[i]
                t = msgobj.getTensor()
            batch.append(t)

        chat_tensor = batch[0].stack(*batch[1:])
        logits = model(chat_tensor)

        # Sample one token per active request; retire finished ones.
        for i in range(BatchSize):
            if NewList[i] is None:
                continue
            msgobj = NewList[i]
            if msgobj.step():
                tok = (logits[i, msgobj.chatLen - 1, :] / msgobj.temp).softmax().multinomial(1)
                msgobj.add(tok)
            else:
                # Debug print of the finished conversation's first 25 chars.
                print(encoding.decode(chat_tensor[i].numpy().astype(int))[:25])
                msgobj.finish()
                NewList[i] = None
|
||||
|
||||
def warmup(count):
    """Run the model `count` times on a growing token sequence.

    Intended to warm up the JIT/compilation caches before serving; the
    generated tokens themselves are discarded. (Removed an unused
    ``import random`` that the original carried.)
    """
    tokens = encoding.encode("")
    tokens = Tensor([tokens])
    for i in tqdm(range(count)):
        pad_len = hypr['block_size'] - tokens.shape[1]
        input_buffer = tokens.pad(((0, 0), (0, pad_len))).contiguous()
        out = model(input_buffer)
        token_tensor = (out[:, tokens.shape[1] - 1, :] / 0.7).softmax().multinomial(1)
        tokens = tokens.cat(token_tensor, dim=1).realize()
        # NOTE(review): this slices along dim 0 (the batch axis), not the
        # sequence axis — it likely should be `tokens[:, -hypr['block_size']:]`
        # to bound the context length. Left unchanged pending confirmation.
        tokens = tokens[:-hypr['block_size']]
|
||||
def test(msg):
    """Legacy streaming helper: enqueue `msg` and yield raw token ids.

    NOTE(review): this still puts a (queue, token_list) tuple on msg_q,
    while inference_worker now expects a messageOBJ — confirm this path
    is still exercised before relying on it.
    """
    tokens = queue.Queue()
    inp = encoding.encode(msg)
    t = []

    msg_q.put((tokens, inp))
    yield "Start:"
    while True:
        # Bug fix: the original bare `except:` wrapped the `yield` too, so
        # it swallowed GeneratorExit when the client disconnected and then
        # yielded again -> "generator ignored GeneratorExit". Only the
        # queue read is guarded now, and only real exceptions are caught.
        try:
            i = tokens.get()
        except Exception:  # e.g. queue shut down by the producer
            break
        t.append(i)
        yield f"{i},"
    txt = encoding.decode(t)
    yield f"\n{txt}"
|
||||
app = flask.Flask(__name__)


@app.route('/', methods=['POST'])
def complete():
    """Legacy plain-text completion endpoint: streams test()'s output."""
    # `request` is already imported at the top of the file; the duplicate
    # mid-file `from flask import request` was redundant and is removed.
    user_string = request.form.get('input', 'Default prompt')
    # NOTE(review): "text" is not a valid MIME type — probably meant
    # "text/plain"; left unchanged to avoid altering response headers.
    return test(user_string), {"Content-Type": "text"}
||||
def fullReturn(msgobj, model):
    """Drain msgobj's output queue to completion and return a single
    OpenAI-style ``chat.completion`` JSON response."""
    # iter(get, None) pulls token ids until the None end-of-stream sentinel.
    ids = list(iter(msgobj.outputQueue.get, None))

    body = {
        "id": f"chatcmpl-{uuid.uuid4()}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": model,
        "choices": [{
            "index": 0,
            "message": {
                "role": "assistant",
                "content": encoding.decode(ids),
            },
            "finish_reason": "stop",
        }],
    }
    return jsonify(body)
||||
def chunkReturn(msgobj, model):
    """Yield OpenAI-style server-sent-event chunks as tokens arrive.

    Re-decodes the full id list on every token and emits only the newly
    appended text, ending with the ``[DONE]`` sentinel event.
    """
    chat_id = f"chatcmpl-{uuid.uuid4()}"
    ids = []
    decoded_so_far = ""

    # iter(get, None) stops at the None end-of-stream sentinel.
    for token_id in iter(msgobj.outputQueue.get, None):
        ids.append(token_id)
        decoded = encoding.decode(ids)
        delta_text = decoded[len(decoded_so_far):]
        decoded_so_far = decoded

        payload = {
            "id": chat_id,
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "model": model,
            "choices": [{
                "index": 0,
                "delta": {"content": delta_text},
                "finish_reason": None,
            }],
        }
        yield f"data: {json.dumps(payload)}\n\n"
    yield "data: [DONE]\n\n"
|
||||
|
||||
@app.route('/v1/chat/completions', methods=['POST'])
def completions():
    """OpenAI-compatible chat completions endpoint (streaming and not)."""
    data = request.json
    messageArray = data.get("messages", [])
    stream = data.get("stream", False)
    model = data.get("model", "RatChat")
    # Only the most recent message is used as the prompt.
    chat = messageArray[-1]["content"] if messageArray else ""
    msgobj = messageOBJ(chat)
    msg_q.put(msgobj)

    if stream:
        # Bug fix: the header key was "Content-Type:" (trailing colon),
        # which produced a malformed header on the streaming path.
        return chunkReturn(msgobj, model), {"Content-Type": "text/event-stream"}
    else:
        return fullReturn(msgobj, model)
|
||||
@app.route('/v1/models', methods=['GET'])
def list_models():
    """OpenAI-compatible model listing: always exactly one model entry."""
    model_entry = {
        "id": "RatChat",
        "object": "model",
        "created": int(time.time()),
        "owned_by": "tinygrad",
    }
    return jsonify({"object": "list", "data": [model_entry]})
|
||||
def apiStart():
|
||||
""" start api """
|
||||
|
||||
Reference in New Issue
Block a user