#! /usr/bin/env nix-shell
#! nix-shell -i python3 -p python3Packages.tinygrad python3Packages.numpy python3Packages.discordpy python3Packages.transformers python3Packages.tqdm python3Packages.flask
"""OpenAI-compatible chat-completion server for a small tinygrad Transformer.

Flask request handlers enqueue generation requests onto ``msg_q``; a single
dedicated tinygrad thread (``inference_worker``) batches them, runs one
forward pass per step, and streams sampled token ids back through each
request's ``outputQueue``.
"""
import json
import queue
import threading
import time
import uuid

import flask
from flask import request, Response, jsonify
from tinygrad import Tensor, TinyJit, dtypes, Device
from tinygrad.nn.state import safe_load, load_state_dict
from transformers import AutoTokenizer
from tqdm import tqdm

from model import Transformer

# Model hyperparameters; must match the checkpoint in CHECKPOINT_PATH.
hypr = {
    "embed_size": 768,
    "n_heads": 8,
    "n_blocks": 12,
    "block_size": 512,
    "encoding": "TinyLlama/TinyLlama_v1.1",
}
CHECKPOINT_PATH = 'gpt.safetensors'

# Cross-thread handoff: HTTP handlers produce, the inference thread consumes.
msg_q = queue.Queue()

encoding = AutoTokenizer.from_pretrained(hypr['encoding'])
model = Transformer(encoding.vocab_size, hypr["embed_size"], hypr["n_heads"],
                    hypr["n_blocks"], hypr["block_size"])
load_state_dict(model, safe_load(CHECKPOINT_PATH))
Tensor.training = False


class messageOBJ:
    """One in-flight generation request.

    Created on a Flask thread; all Tensor work happens on the tinygrad
    thread. Token ids flow back to the HTTP handler via ``outputQueue``,
    terminated by a ``None`` sentinel.
    """

    def __init__(self, chat, maxGen=50, temp=0.7):
        self.outputQueue = queue.Queue()   # sampled token ids; None = finished
        self.chat = encoding.encode(chat)  # list[int] until __readyChat__ runs
        self.ready = False                 # has chat been converted to a Tensor?
        self.chatLen = 0                   # current (unpadded) context length
        self.maxGen = maxGen               # hard cap on generated tokens
        self.runGen = 0                    # tokens generated so far
        self.temp = temp                   # softmax temperature for sampling

    def __readyChat__(self):
        """Lazily convert the encoded prompt to a Tensor.

        Must only be called on the tinygrad thread.
        """
        if not self.ready:
            self.chat = Tensor(self.chat)
            self.ready = True

    def getTensor(self):
        """Return the context right-padded to block_size.

        Must only be called on the tinygrad thread.
        """
        self.__readyChat__()
        # Keep only the most recent block_size tokens so the pad below can
        # never go negative (resolves the original "trim chat tensor" TODO).
        if self.chat.shape[0] > hypr['block_size']:
            self.chat = self.chat[-hypr['block_size']:]
        self.chatLen = self.chat.shape[0]
        pad_len = hypr['block_size'] - self.chatLen
        return self.chat.pad((0, pad_len))

    def step(self):
        """Advance the generation counter; True while more tokens may follow."""
        self.runGen += 1
        return self.runGen < self.maxGen

    def add(self, token):
        """Append a sampled token to the context and publish its id."""
        self.chat = self.chat.cat(token)
        self.outputQueue.put(token.numpy()[0])

    def finish(self):
        """Signal end-of-generation to the waiting HTTP handler."""
        self.outputQueue.put(None)


@TinyJit
def run_model(input_buffer):
    """Run the model on the accelerator.

    JIT-compiled: the input shape must be identical on every call, which
    inference_worker guarantees with its fixed (BatchSize, block_size) batch.
    """
    return model(input_buffer)


def inference_worker():
    """Forever consume requests from msg_q, generating tokens in fixed batches.

    Keeps a slot list of BatchSize active requests. Each loop iteration admits
    at most one queued request into a free slot (blocking on the queue only
    when every slot is idle), runs one batched forward pass, then samples one
    token per active request.
    """
    BatchSize = 8
    NewList = [None] * BatchSize  # active requests; None marks a free slot
    while True:
        # Admit a waiting request when a slot is free; when ALL slots are
        # free there is no work at all, so block on the queue instead of
        # spinning through empty forward passes.
        if (not msg_q.empty() and None in NewList) or NewList.count(None) == len(NewList):
            i = NewList.index(None)
            NewList[i] = msg_q.get()
        batch = []
        for i in range(BatchSize):
            if NewList[i] is None:
                # Placeholder row so the batch shape stays constant.
                batch.append(Tensor.zeros(hypr['block_size']))
            else:
                batch.append(NewList[i].getTensor())
        chat_tensor = batch[0].stack(*batch[1:])
        # BUG FIX: the original defined the TinyJit entry point run_model but
        # called model() directly, bypassing the JIT entirely. The batch shape
        # is constant, so JIT capture is safe here.
        logits = run_model(chat_tensor)
        for i in range(BatchSize):
            if NewList[i] is None:
                continue
            msgobj = NewList[i]
            if msgobj.step():
                # Sample from the logits at the last real (unpadded) position.
                tok = (logits[i, msgobj.chatLen - 1, :] / msgobj.temp).softmax().multinomial(1)
                msgobj.add(tok)
            else:
                msgobj.finish()
                NewList[i] = None


app = flask.Flask(__name__)


def fullReturn(msgobj, model_name):
    """Block until generation completes; return one chat.completion response.

    ``model_name`` is the client-supplied model id echoed back in the reply
    (renamed from ``model`` to stop shadowing the global Transformer).
    """
    ids = []
    while True:
        token = msgobj.outputQueue.get()
        if token is None:
            break
        ids.append(token)
    return jsonify({
        "id": f"chatcmpl-{uuid.uuid4()}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": model_name,
        "choices": [{
            "index": 0,
            "message": {
                "role": "assistant",
                "content": encoding.decode(ids)
            },
            "finish_reason": "stop"
        }]
    })


def chunkReturn(msgobj, model_name):
    """Yield Server-Sent-Event chunks (chat.completion.chunk) as tokens arrive.

    Re-decodes the full id list each step and emits only the new suffix, so
    characters spanning multiple tokens are rendered correctly.
    """
    chat_id = f"chatcmpl-{uuid.uuid4()}"
    ids = []
    old = ""
    while True:
        token_id = msgobj.outputQueue.get()
        if token_id is None:
            break
        ids.append(token_id)
        tmp = encoding.decode(ids)
        word = tmp[len(old):]
        old = tmp
        chunk = {
            "id": chat_id,
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "model": model_name,
            "choices": [{
                "index": 0,
                "delta": {"content": word},
                "finish_reason": None
            }]
        }
        yield f"data: {json.dumps(chunk)}\n\n"
    yield "data: [DONE]\n\n"


@app.route('/v1/chat/completions', methods=['POST'])
def completions():
    """OpenAI-compatible chat completion endpoint (streaming and blocking)."""
    data = request.json
    messageArray = data.get("messages", [])
    stream = data.get("stream", False)
    model_name = data.get("model", "RatChat")
    # Only the latest user message is fed to the model.
    chat = messageArray[-1]["content"] if messageArray else ""
    msgobj = messageOBJ(chat)
    msg_q.put(msgobj)
    if stream:
        # BUG FIX: the header name was "Content-Type:" (trailing colon inside
        # the name), so the real Content-Type header was never set and SSE
        # clients could not recognize the event stream.
        return chunkReturn(msgobj, model_name), {"Content-Type": "text/event-stream"}
    else:
        return fullReturn(msgobj, model_name)


@app.route('/v1/models', methods=['GET'])
def list_models():
    """Minimal OpenAI-compatible model listing."""
    return jsonify({
        "object": "list",
        "data": [
            {
                "id": "RatChat",
                "object": "model",
                "created": int(time.time()),
                "owned_by": "tinygrad"
            }
        ]
    })


def apiStart():
    """Run the Flask API (blocking); intended for a daemon thread."""
    app.run()


if __name__ == "__main__":
    print(Device.DEFAULT)
    print("warming up")
    # warmup(200)
    # Flask runs on a daemon thread; the tinygrad inference loop owns main.
    t = threading.Thread(target=apiStart, daemon=True)
    t.start()
    inference_worker()