basic version of OpenAI api implimented
This commit is contained in:
195
bot.py
195
bot.py
@@ -3,13 +3,16 @@
|
|||||||
|
|
||||||
import queue
|
import queue
|
||||||
import flask
|
import flask
|
||||||
|
from flask import request, Response, jsonify
|
||||||
from tinygrad import Tensor, TinyJit, dtypes, Device
|
from tinygrad import Tensor, TinyJit, dtypes, Device
|
||||||
from tinygrad.nn.state import safe_load, load_state_dict
|
from tinygrad.nn.state import safe_load, load_state_dict
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
from model import Transformer
|
from model import Transformer
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import threading
|
import threading
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
|
||||||
hypr = {
|
hypr = {
|
||||||
"embed_size": 768, "n_heads": 8, "n_blocks": 12, "block_size": 512,
|
"embed_size": 768, "n_heads": 8, "n_blocks": 12, "block_size": 512,
|
||||||
@@ -25,6 +28,41 @@ model = Transformer(encoding.vocab_size, hypr["embed_size"], hypr["n_heads"], hy
|
|||||||
load_state_dict(model, safe_load(CHECKPOINT_PATH))
|
load_state_dict(model, safe_load(CHECKPOINT_PATH))
|
||||||
Tensor.training = False
|
Tensor.training = False
|
||||||
|
|
||||||
|
class messageOBJ:
|
||||||
|
def __init__(self,chat,maxGen=50,temp=0.7):
|
||||||
|
self.outputQueue = queue.Queue()
|
||||||
|
self.chat = encoding.encode(chat)
|
||||||
|
self.ready = False
|
||||||
|
self.chatLen = 0
|
||||||
|
self.maxGen = maxGen
|
||||||
|
self.runGen = 0
|
||||||
|
self.temp = temp
|
||||||
|
|
||||||
|
def __readyChat__(self):
|
||||||
|
""" Must only be called on tinygrad thread """
|
||||||
|
if not self.ready:
|
||||||
|
self.chat = Tensor(self.chat)
|
||||||
|
self.ready = True
|
||||||
|
#TODO trim chat tensor
|
||||||
|
|
||||||
|
def getTensor(self):
|
||||||
|
""" Must only be called on tinygrad thread """
|
||||||
|
self.__readyChat__()
|
||||||
|
self.chatLen = self.chat.shape[0]
|
||||||
|
pad_len = hypr['block_size'] - self.chatLen
|
||||||
|
return self.chat.pad((0,pad_len))
|
||||||
|
|
||||||
|
def step(self):
|
||||||
|
self.runGen += 1
|
||||||
|
return self.runGen < self.maxGen
|
||||||
|
|
||||||
|
def add(self,token):
|
||||||
|
self.chat = self.chat.cat(token)
|
||||||
|
self.outputQueue.put(token.numpy()[0])
|
||||||
|
|
||||||
|
def finish(self):
|
||||||
|
self.outputQueue.put(None)
|
||||||
|
|
||||||
|
|
||||||
@TinyJit
|
@TinyJit
|
||||||
def run_model(input_buffer):
|
def run_model(input_buffer):
|
||||||
@@ -33,15 +71,13 @@ def run_model(input_buffer):
|
|||||||
|
|
||||||
def inference_worker():
|
def inference_worker():
|
||||||
""" consume tasks from que """
|
""" consume tasks from que """
|
||||||
BatchSize=2
|
BatchSize = 8
|
||||||
NewList = [None] * BatchSize
|
NewList = [None] * BatchSize
|
||||||
|
|
||||||
import time
|
|
||||||
while True:
|
while True:
|
||||||
if (not msg_q.empty() and None in NewList) or NewList.count(None) == len(NewList):
|
if (not msg_q.empty() and None in NewList) or NewList.count(None) == len(NewList):
|
||||||
i = NewList.index(None)
|
i = NewList.index(None)
|
||||||
out,inp = msg_q.get()
|
NewList[i] = msg_q.get()
|
||||||
NewList[i] = (out,inp,None)
|
|
||||||
|
|
||||||
batch = []
|
batch = []
|
||||||
for i in range(BatchSize):
|
for i in range(BatchSize):
|
||||||
@@ -49,82 +85,107 @@ def inference_worker():
|
|||||||
if not NewList[i]:
|
if not NewList[i]:
|
||||||
t = Tensor.zeros(hypr['block_size'])
|
t = Tensor.zeros(hypr['block_size'])
|
||||||
else:
|
else:
|
||||||
_, t, _ = NewList[i]
|
msgobj = NewList[i]
|
||||||
if not isinstance(t, Tensor):
|
t = msgobj.getTensor()
|
||||||
t = Tensor(t)
|
|
||||||
l = t.shape[0]
|
|
||||||
pad_len = hypr['block_size'] - l
|
|
||||||
a,b,_ = NewList[i]
|
|
||||||
NewList[i] = (a,t,l)
|
|
||||||
|
|
||||||
t = t.pad((0,pad_len))
|
|
||||||
else:
|
|
||||||
#t = t[:-hypr['block_size']]
|
|
||||||
l = t.shape[0]
|
|
||||||
pad_len = hypr['block_size'] - l
|
|
||||||
t = t.pad((0,pad_len))
|
|
||||||
|
|
||||||
|
|
||||||
batch.append(t)
|
batch.append(t)
|
||||||
chat_tensor = batch[0].stack(*batch[1:])
|
|
||||||
#infince here
|
|
||||||
|
|
||||||
|
chat_tensor = batch[0].stack(*batch[1:])
|
||||||
logits = model(chat_tensor)
|
logits = model(chat_tensor)
|
||||||
|
|
||||||
#return
|
|
||||||
for i in range(BatchSize):
|
for i in range(BatchSize):
|
||||||
if NewList[i] is None:
|
if NewList[i] is None:
|
||||||
continue
|
continue
|
||||||
out, t, lenth = NewList[i]
|
msgobj = NewList[i]
|
||||||
if lenth < 15:
|
if msgobj.step():
|
||||||
tok = (logits[i, lenth-1, :] / 0.7).softmax().multinomial(1)
|
tok = (logits[i, msgobj.chatLen-1, :] / msgobj.temp).softmax().multinomial(1)
|
||||||
inp = t.cat(tok)
|
msgobj.add(tok)
|
||||||
out.put(tok.numpy()[0])
|
|
||||||
NewList[i] = (out,inp,(lenth+1))
|
|
||||||
else:
|
else:
|
||||||
print(encoding.decode(chat_tensor[i].numpy().astype(int))[:25])
|
msgobj.finish()
|
||||||
out.shutdown()
|
|
||||||
NewList[i] = None
|
NewList[i] = None
|
||||||
|
|
||||||
|
|
||||||
def warmup(count):
|
|
||||||
""" run count times with random data """
|
|
||||||
import random
|
|
||||||
tokens = encoding.encode("")
|
|
||||||
tokens = Tensor([tokens])
|
|
||||||
for i in tqdm(range(count)):
|
|
||||||
pad_len = hypr['block_size'] - tokens.shape[1]
|
|
||||||
input_buffer = tokens.pad(((0, 0), (0, pad_len))).contiguous()
|
|
||||||
out = model(input_buffer)
|
|
||||||
token_tensor = (out[:, tokens.shape[1] - 1, :] / 0.7).softmax().multinomial(1)
|
|
||||||
tokens = tokens.cat(token_tensor, dim=1).realize()
|
|
||||||
tokens = tokens[:-hypr['block_size']]
|
|
||||||
|
|
||||||
def test(msg):
|
|
||||||
tokens = queue.Queue()
|
|
||||||
inp = encoding.encode(msg)
|
|
||||||
t = []
|
|
||||||
|
|
||||||
msg_q.put((tokens,inp))
|
|
||||||
yield("Start:")
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
i = tokens.get()
|
|
||||||
t.append(i)
|
|
||||||
yield(f"{i},")
|
|
||||||
except:
|
|
||||||
break
|
|
||||||
txt = encoding.decode(t)
|
|
||||||
yield f"\n{txt}"
|
|
||||||
return
|
|
||||||
|
|
||||||
app = flask.Flask(__name__)
|
app = flask.Flask(__name__)
|
||||||
|
|
||||||
from flask import request
|
def fullReturn(msgobj,model):
|
||||||
@app.route('/',methods=['POST'])
|
ids = []
|
||||||
def complete():
|
while True:
|
||||||
user_string = request.form.get('input', 'Default prompt')
|
token = msgobj.outputQueue.get()
|
||||||
return test(user_string),{"Content-Type": "text"}
|
if token is None:
|
||||||
|
break
|
||||||
|
ids.append(token)
|
||||||
|
|
||||||
|
return jsonify({
|
||||||
|
"id": f"chatcmpl-{uuid.uuid4()}",
|
||||||
|
"object": "chat.completion",
|
||||||
|
"created": int(time.time()),
|
||||||
|
"model": model,
|
||||||
|
"choices": [{
|
||||||
|
"index": 0,
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": encoding.decode(ids)
|
||||||
|
},
|
||||||
|
"finish_reason": "stop"
|
||||||
|
}]
|
||||||
|
})
|
||||||
|
|
||||||
|
def chunkReturn(msgobj,model):
|
||||||
|
chat_id = f"chatcmpl-{uuid.uuid4()}"
|
||||||
|
ids = []
|
||||||
|
old = ""
|
||||||
|
while True:
|
||||||
|
token_id = msgobj.outputQueue.get()
|
||||||
|
if token_id is None:
|
||||||
|
break
|
||||||
|
ids.append(token_id)
|
||||||
|
tmp = encoding.decode(ids)
|
||||||
|
word = tmp[len(old):]
|
||||||
|
old = tmp
|
||||||
|
|
||||||
|
chunk = {
|
||||||
|
"id": chat_id,
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": int(time.time()),
|
||||||
|
"model": model,
|
||||||
|
"choices": [{
|
||||||
|
"index": 0,
|
||||||
|
"delta": {"content": word},
|
||||||
|
"finish_reason": None
|
||||||
|
}]
|
||||||
|
}
|
||||||
|
yield f"data: {json.dumps(chunk)}\n\n"
|
||||||
|
yield "data: [DONE]\n\n"
|
||||||
|
|
||||||
|
|
||||||
|
@app.route('/v1/chat/completions', methods=['POST'])
|
||||||
|
def completions():
|
||||||
|
data = request.json
|
||||||
|
messageArray = data.get("messages", [])
|
||||||
|
stream = data.get("stream", False)
|
||||||
|
model = data.get("model", "RatChat")
|
||||||
|
chat = messageArray[-1]["content"] if messageArray else ""
|
||||||
|
msgobj = messageOBJ(chat)
|
||||||
|
msg_q.put(msgobj)
|
||||||
|
|
||||||
|
if stream:
|
||||||
|
return chunkReturn(msgobj,model), {"Content-Type:":"text/event-stream"}
|
||||||
|
else:
|
||||||
|
return fullReturn(msgobj,model)
|
||||||
|
|
||||||
|
@app.route('/v1/models', methods=['GET'])
|
||||||
|
def list_models():
|
||||||
|
return jsonify({
|
||||||
|
"object": "list",
|
||||||
|
"data": [
|
||||||
|
{
|
||||||
|
"id": "RatChat",
|
||||||
|
"object": "model",
|
||||||
|
"created": int(time.time()),
|
||||||
|
"owned_by": "tinygrad"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
})
|
||||||
|
|
||||||
def apiStart():
|
def apiStart():
|
||||||
""" start api """
|
""" start api """
|
||||||
|
|||||||
Reference in New Issue
Block a user