Basic version of the OpenAI API implemented

This commit is contained in:
k
2026-03-04 00:28:46 -05:00
parent 4196b70681
commit a37d05a2de

195
bot.py
View File

@@ -3,13 +3,16 @@
import queue import queue
import flask import flask
from flask import request, Response, jsonify
from tinygrad import Tensor, TinyJit, dtypes, Device from tinygrad import Tensor, TinyJit, dtypes, Device
from tinygrad.nn.state import safe_load, load_state_dict from tinygrad.nn.state import safe_load, load_state_dict
from transformers import AutoTokenizer from transformers import AutoTokenizer
from model import Transformer from model import Transformer
from tqdm import tqdm from tqdm import tqdm
import threading import threading
import json
import time
import uuid
hypr = { hypr = {
"embed_size": 768, "n_heads": 8, "n_blocks": 12, "block_size": 512, "embed_size": 768, "n_heads": 8, "n_blocks": 12, "block_size": 512,
@@ -25,6 +28,41 @@ model = Transformer(encoding.vocab_size, hypr["embed_size"], hypr["n_heads"], hy
load_state_dict(model, safe_load(CHECKPOINT_PATH)) load_state_dict(model, safe_load(CHECKPOINT_PATH))
Tensor.training = False Tensor.training = False
class messageOBJ:
    """A single chat-completion request plus its generation state.

    Created on a Flask request thread from a raw prompt string; the token
    list is converted to a tinygrad Tensor lazily on the inference thread
    (tensor ops must only happen there). Generated token ids are streamed
    back to the request thread through outputQueue, terminated by a None
    sentinel (see finish()).
    """
    def __init__(self, chat, maxGen=50, temp=0.7):
        self.outputQueue = queue.Queue()   # generated token ids; None = done
        self.chat = encoding.encode(chat)  # plain token list until __readyChat__
        self.ready = False                 # True once self.chat is a Tensor
        self.chatLen = 0                   # unpadded length, set by getTensor
        self.maxGen = maxGen               # generation budget in tokens
        self.runGen = 0                    # tokens generated so far
        self.temp = temp                   # sampling temperature

    def __readyChat__(self):
        """ Must only be called on tinygrad thread """
        if not self.ready:
            self.chat = Tensor(self.chat)
            self.ready = True
        # Trim to the most recent block_size tokens (implements the old
        # "TODO trim chat tensor"): otherwise pad_len in getTensor() goes
        # negative and crashes once the chat outgrows the context window.
        if self.chat.shape[0] > hypr['block_size']:
            self.chat = self.chat[-hypr['block_size']:]

    def getTensor(self):
        """ Must only be called on tinygrad thread """
        self.__readyChat__()
        self.chatLen = self.chat.shape[0]
        pad_len = hypr['block_size'] - self.chatLen
        return self.chat.pad((0, pad_len))

    def step(self):
        # Consume one unit of the generation budget; True while more
        # tokens may still be generated.
        self.runGen += 1
        return self.runGen < self.maxGen

    def add(self, token):
        # Append the sampled token (shape (1,)) and stream its id out.
        self.chat = self.chat.cat(token)
        self.outputQueue.put(token.numpy()[0])

    def finish(self):
        # End-of-stream sentinel for the consumer side.
        self.outputQueue.put(None)
@TinyJit @TinyJit
def run_model(input_buffer): def run_model(input_buffer):
@@ -33,15 +71,13 @@ def run_model(input_buffer):
def inference_worker():
    """Batch-inference loop: consume messageOBJ tasks from msg_q forever.

    Maintains up to BatchSize in-flight requests in a slot list. Empty
    slots are filled with zero tensors so the model always sees a full
    batch; one token is sampled per occupied slot per model call.
    """
    BatchSize = 8
    slots = [None] * BatchSize
    while True:
        # Block for work when completely idle; otherwise opportunistically
        # admit one waiting request into a free slot without blocking.
        idle = slots.count(None) == len(slots)
        if idle or (not msg_q.empty() and None in slots):
            slots[slots.index(None)] = msg_q.get()

        # Assemble the padded batch: zeros for empty slots, the request's
        # padded chat tensor for occupied ones.
        batch = []
        for slot in slots:
            if slot:
                batch.append(slot.getTensor())
            else:
                batch.append(Tensor.zeros(hypr['block_size']))
        chat_tensor = batch[0].stack(*batch[1:])
        logits = model(chat_tensor)

        # Sample one token per active request; retire requests whose
        # generation budget is exhausted.
        for i, slot in enumerate(slots):
            if slot is None:
                continue
            if slot.step():
                tok = (logits[i, slot.chatLen - 1, :] / slot.temp).softmax().multinomial(1)
                slot.add(tok)
            else:
                slot.finish()
                slots[i] = None
def warmup(count):
    """Run `count` generation steps on an empty prompt to warm the TinyJit.

    Mirrors the shapes used by inference_worker: pad to block_size, run the
    model, sample one token, append, repeat.

    NOTE(review): prompt length + count must stay <= block_size or the pad
    below goes negative — confirm callers keep count small enough.
    """
    # Removed an unused function-local `import random` and a dead trailing
    # statement (`tokens = tokens[:-hypr['block_size']]`) that rebound a
    # local just before returning and had no effect.
    tokens = Tensor([encoding.encode("")])
    for _ in tqdm(range(count)):
        pad_len = hypr['block_size'] - tokens.shape[1]
        input_buffer = tokens.pad(((0, 0), (0, pad_len))).contiguous()
        out = model(input_buffer)
        token_tensor = (out[:, tokens.shape[1] - 1, :] / 0.7).softmax().multinomial(1)
        tokens = tokens.cat(token_tensor, dim=1).realize()
def test(msg):
    """Legacy plain-text streaming helper: queue a prompt, stream token ids,
    then the decoded text.

    Updated to the messageOBJ protocol: inference_worker now expects
    messageOBJ instances on msg_q — the old `(queue, token_list)` tuple
    would crash it — and signals end-of-stream with a None sentinel from
    messageOBJ.finish(), so the old bare `except`/Queue.shutdown dance is
    gone.
    """
    msgobj = messageOBJ(msg)
    msg_q.put(msgobj)
    yield "Start:"
    ids = []
    while True:
        tok = msgobj.outputQueue.get()
        if tok is None:  # sentinel: generation finished
            break
        ids.append(tok)
        yield f"{tok},"
    yield f"\n{encoding.decode(ids)}"
app = flask.Flask(__name__) app = flask.Flask(__name__)
def fullReturn(msgobj, model):
    """Block until generation finishes, then return a non-streaming
    OpenAI-style `chat.completion` JSON response for msgobj."""
    token_ids = []
    # Drain the output queue; a None sentinel marks end-of-stream.
    while (token := msgobj.outputQueue.get()) is not None:
        token_ids.append(token)
    body = {
        "id": f"chatcmpl-{uuid.uuid4()}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": model,
        "choices": [{
            "index": 0,
            "message": {
                "role": "assistant",
                "content": encoding.decode(token_ids)
            },
            "finish_reason": "stop"
        }]
    }
    return jsonify(body)
def chunkReturn(msgobj, model):
    """Yield OpenAI-style `chat.completion.chunk` SSE events as tokens arrive.

    The token list is re-decoded cumulatively and only the newly appended
    text is emitted in each delta, so characters spanning multiple tokens
    come out intact.
    """
    chat_id = f"chatcmpl-{uuid.uuid4()}"
    token_ids = []
    sent = ""  # text already delivered to the client
    while (token_id := msgobj.outputQueue.get()) is not None:
        token_ids.append(token_id)
        decoded = encoding.decode(token_ids)
        delta, sent = decoded[len(sent):], decoded
        payload = {
            "id": chat_id,
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "model": model,
            "choices": [{
                "index": 0,
                "delta": {"content": delta},
                "finish_reason": None
            }]
        }
        yield f"data: {json.dumps(payload)}\n\n"
    yield "data: [DONE]\n\n"
@app.route('/v1/chat/completions', methods=['POST'])
def completions():
    """OpenAI-compatible chat completions endpoint (streaming and blocking).

    Only the last message's content is fed to the model; earlier history in
    the `messages` array is currently ignored.
    """
    data = request.json
    messageArray = data.get("messages", [])
    stream = data.get("stream", False)
    model = data.get("model", "RatChat")
    chat = messageArray[-1]["content"] if messageArray else ""
    msgobj = messageOBJ(chat)
    msg_q.put(msgobj)
    if stream:
        # Fixed header name: was "Content-Type:" (stray colon inside the
        # key string), which sent a malformed header so clients never saw
        # the text/event-stream content type.
        return chunkReturn(msgobj, model), {"Content-Type": "text/event-stream"}
    else:
        return fullReturn(msgobj, model)
@app.route('/v1/models', methods=['GET'])
def list_models():
    """OpenAI-compatible /v1/models listing: exposes the single local model."""
    model_entry = {
        "id": "RatChat",
        "object": "model",
        "created": int(time.time()),
        "owned_by": "tinygrad"
    }
    return jsonify({"object": "list", "data": [model_entry]})
def apiStart(): def apiStart():
""" start api """ """ start api """