From a37d05a2de9dbdc2002e89978545890351f3a666 Mon Sep 17 00:00:00 2001 From: k Date: Wed, 4 Mar 2026 00:28:46 -0500 Subject: [PATCH] basic version of OpenAI api implimented --- bot.py | 195 +++++++++++++++++++++++++++++++++++++-------------------- 1 file changed, 128 insertions(+), 67 deletions(-) diff --git a/bot.py b/bot.py index c5128b9..0c86086 100755 --- a/bot.py +++ b/bot.py @@ -3,13 +3,16 @@ import queue import flask +from flask import request, Response, jsonify from tinygrad import Tensor, TinyJit, dtypes, Device from tinygrad.nn.state import safe_load, load_state_dict from transformers import AutoTokenizer from model import Transformer from tqdm import tqdm import threading - +import json +import time +import uuid hypr = { "embed_size": 768, "n_heads": 8, "n_blocks": 12, "block_size": 512, @@ -25,6 +28,41 @@ model = Transformer(encoding.vocab_size, hypr["embed_size"], hypr["n_heads"], hy load_state_dict(model, safe_load(CHECKPOINT_PATH)) Tensor.training = False +class messageOBJ: + def __init__(self,chat,maxGen=50,temp=0.7): + self.outputQueue = queue.Queue() + self.chat = encoding.encode(chat) + self.ready = False + self.chatLen = 0 + self.maxGen = maxGen + self.runGen = 0 + self.temp = temp + + def __readyChat__(self): + """ Must only be called on tinygrad thread """ + if not self.ready: + self.chat = Tensor(self.chat) + self.ready = True + #TODO trim chat tensor + + def getTensor(self): + """ Must only be called on tinygrad thread """ + self.__readyChat__() + self.chatLen = self.chat.shape[0] + pad_len = hypr['block_size'] - self.chatLen + return self.chat.pad((0,pad_len)) + + def step(self): + self.runGen += 1 + return self.runGen < self.maxGen + + def add(self,token): + self.chat = self.chat.cat(token) + self.outputQueue.put(token.numpy()[0]) + + def finish(self): + self.outputQueue.put(None) + @TinyJit def run_model(input_buffer): @@ -33,15 +71,13 @@ def run_model(input_buffer): def inference_worker(): """ consume tasks from que """ - BatchSize=2 + BatchSize = 8 NewList = [None] * BatchSize - import time while True: if (not msg_q.empty() and None in NewList) or NewList.count(None) == len(NewList): i = NewList.index(None) - out,inp = msg_q.get() - NewList[i] = (out,inp,None) + NewList[i] = msg_q.get() batch = [] for i in range(BatchSize): @@ -49,82 +85,107 @@ def inference_worker(): if not NewList[i]: t = Tensor.zeros(hypr['block_size']) else: - _, t, _ = NewList[i] - if not isinstance(t, Tensor): - t = Tensor(t) - l = t.shape[0] - pad_len = hypr['block_size'] - l - a,b,_ = NewList[i] - NewList[i] = (a,t,l) - - t = t.pad((0,pad_len)) - else: - #t = t[:-hypr['block_size']] - l = t.shape[0] - pad_len = hypr['block_size'] - l - t = t.pad((0,pad_len)) + msgobj = NewList[i] + t = msgobj.getTensor() batch.append(t) - chat_tensor = batch[0].stack(*batch[1:]) - #infince here + chat_tensor = batch[0].stack(*batch[1:]) logits = model(chat_tensor) - #return for i in range(BatchSize): if NewList[i] is None: continue - out, t, lenth = NewList[i] - if lenth < 15: - tok = (logits[i, lenth-1, :] / 0.7).softmax().multinomial(1) - inp = t.cat(tok) - out.put(tok.numpy()[0]) - NewList[i] = (out,inp,(lenth+1)) + msgobj = NewList[i] + if msgobj.step(): + tok = (logits[i, msgobj.chatLen-1, :] / msgobj.temp).softmax().multinomial(1) + msgobj.add(tok) else: - print(encoding.decode(chat_tensor[i].numpy().astype(int))[:25]) - out.shutdown() + msgobj.finish() NewList[i] = None - -def warmup(count): - """ run count times with random data """ - import random - tokens = encoding.encode("") - tokens = Tensor([tokens]) - for i in tqdm(range(count)): - pad_len = hypr['block_size'] - tokens.shape[1] - input_buffer = tokens.pad(((0, 0), (0, pad_len))).contiguous() - out = model(input_buffer) - token_tensor = (out[:, tokens.shape[1] - 1, :] / 0.7).softmax().multinomial(1) - tokens = tokens.cat(token_tensor, dim=1).realize() - tokens = tokens[:-hypr['block_size']] - -def test(msg): - tokens = queue.Queue() - inp = encoding.encode(msg) - t = [] - - msg_q.put((tokens,inp)) - yield("Start:") - while True: - try: - i = tokens.get() - t.append(i) - yield(f"{i},") - except: - break - txt = encoding.decode(t) - yield f"\n{txt}" - return - app = flask.Flask(__name__) -from flask import request -@app.route('/',methods=['POST']) -def complete(): - user_string = request.form.get('input', 'Default prompt') - return test(user_string),{"Content-Type": "text"} +def fullReturn(msgobj,model): + ids = [] + while True: + token = msgobj.outputQueue.get() + if token is None: + break + ids.append(token) + + return jsonify({ + "id": f"chatcmpl-{uuid.uuid4()}", + "object": "chat.completion", + "created": int(time.time()), + "model": model, + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": encoding.decode(ids) + }, + "finish_reason": "stop" + }] + }) + +def chunkReturn(msgobj,model): + chat_id = f"chatcmpl-{uuid.uuid4()}" + ids = [] + old = "" + while True: + token_id = msgobj.outputQueue.get() + if token_id is None: + break + ids.append(token_id) + tmp = encoding.decode(ids) + word = tmp[len(old):] + old = tmp + + chunk = { + "id": chat_id, + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model, + "choices": [{ + "index": 0, + "delta": {"content": word}, + "finish_reason": None + }] + } + yield f"data: {json.dumps(chunk)}\n\n" + yield "data: [DONE]\n\n" + + +@app.route('/v1/chat/completions', methods=['POST']) +def completions(): + data = request.json + messageArray = data.get("messages", []) + stream = data.get("stream", False) + model = data.get("model", "RatChat") + chat = messageArray[-1]["content"] if messageArray else "" + msgobj = messageOBJ(chat) + msg_q.put(msgobj) + + if stream: + return chunkReturn(msgobj,model), {"Content-Type:":"text/event-stream"} + else: + return fullReturn(msgobj,model) + +@app.route('/v1/models', methods=['GET']) +def list_models(): + return jsonify({ + "object": "list", + "data": [ + { + "id": "RatChat", + "object": "model", + "created": int(time.time()), + "owned_by": "tinygrad" + } + ] + }) def apiStart(): """ start api """