#! /usr/bin/env nix-shell
#! nix-shell -i python3 -p python3Packages.tinygrad python3Packages.numpy python3Packages.discordpy python3Packages.transformers python3Packages.tqdm python3Packages.flask
"""Small HTTP text-completion server backed by a tinygrad Transformer.

The Flask endpoint enqueues tokenized prompts on ``msg_q``. ``inference_worker``
batches active requests, samples one token per request per step, and streams
each sampled token id back through a per-request reply queue.

NOTE: relies on ``queue.Queue.shutdown`` / ``queue.ShutDown`` (Python 3.13+).
"""
import queue
import threading

import flask
from flask import request
from tinygrad import Tensor, TinyJit, dtypes, Device
from tinygrad.nn.state import safe_load, load_state_dict
from tqdm import tqdm
from transformers import AutoTokenizer

from model import Transformer

# Model hyper-parameters; must match the checkpoint stored at CHECKPOINT_PATH.
hypr = {
    "embed_size": 768,
    "n_heads": 8,
    "n_blocks": 12,
    "block_size": 512,  # every model input row is padded to this length
    "encoding": "TinyLlama/TinyLlama_v1.1",  # tokenizer passed to AutoTokenizer
}
CHECKPOINT_PATH = 'gpt.safetensors'

# Default serving settings (overridable via inference_worker's parameters).
BATCH_SIZE = 2
# A request stops once its sequence (prompt + generated tokens) has this many
# tokens; long prompts therefore get no generated tokens.
MAX_LEN = 15
TEMPERATURE = 0.7

# Work queue of (reply_queue, prompt_token_ids) pairs: filled by test(),
# drained by inference_worker().
msg_q = queue.Queue()

encoding = AutoTokenizer.from_pretrained(hypr['encoding'])
model = Transformer(encoding.vocab_size, hypr["embed_size"], hypr["n_heads"],
                    hypr["n_blocks"], hypr["block_size"])
load_state_dict(model, safe_load(CHECKPOINT_PATH))
Tensor.training = False


@TinyJit
def run_model(input_buffer):
    """Run the model through TinyJit.

    NOTE(review): currently unused -- inference_worker() and warmup() call
    ``model`` directly, so no JIT-compiled path is exercised.
    """
    return model(input_buffer)


def _admit_request(slots):
    """Move at most one queued request from ``msg_q`` into a free slot.

    Blocks on ``msg_q`` when every slot is idle (there is nothing else to do);
    otherwise only admits a request that is already waiting. The prompt is
    truncated to its last ``block_size`` tokens so it always fits the model.
    """
    idle = slots.count(None)
    if not (idle == len(slots) or (idle and not msg_q.empty())):
        return
    out, prompt = msg_q.get()
    prompt = list(prompt)[-hypr['block_size']:]
    if not prompt:
        # Nothing to condition on: finish the request without generating.
        out.shutdown()
        return
    tokens = Tensor(prompt, dtype=dtypes.int32)
    slots[slots.index(None)] = (out, tokens, len(prompt))


def _slot_input(slot):
    """Return the ``(block_size,)`` int32 model input row for one slot.

    Idle slots become all-zero rows; active slots are their tokens
    right-padded with zeros.
    """
    block_size = hypr['block_size']
    if slot is None:
        # Same dtype as real token rows so the stacked batch stays integer.
        return Tensor.zeros(block_size, dtype=dtypes.int32)
    tokens = slot[1]
    return tokens.pad((0, block_size - tokens.shape[0]))


def inference_worker(batch_size=BATCH_SIZE, max_len=MAX_LEN,
                     temperature=TEMPERATURE):
    """Consume prompts from ``msg_q`` and stream sampled token ids back; never returns.

    Requests occupy one of ``batch_size`` slots. Each step runs the model once
    on a ``(batch_size, block_size)`` batch (idle slots are zero rows), samples
    one token for every active slot and puts its id on that request's reply
    queue. A request finishes when its sequence length reaches ``max_len``
    (capped at ``block_size``); its reply queue is then shut down so the
    consumer's ``get()`` raises ``queue.ShutDown``.

    Each slot is ``None`` (idle) or ``(reply_queue, tokens, length)`` where
    ``tokens`` is the 1-D sequence so far and ``length`` its size.

    Raises:
        ValueError: if ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    # Never let a sequence grow past the model's padded input length.
    limit = min(max_len, hypr['block_size'])
    slots = [None] * batch_size
    while True:
        _admit_request(slots)
        batch = [_slot_input(slot) for slot in slots]
        chat_tensor = batch[0].stack(*batch[1:])
        logits = model(chat_tensor)
        for i, slot in enumerate(slots):
            if slot is None:
                continue
            out, tokens, length = slot
            if length < limit:
                # Sample from the logits at the last real (unpadded) position.
                tok = (logits[i, length - 1, :] / temperature).softmax().multinomial(1)
                # Realize so the stored sequence does not carry an ever-growing
                # lazy graph of cat() operations into the next step.
                grown = tokens.cat(tok).realize()
                out.put(int(tok.numpy()[0]))
                slots[i] = (out, grown, length + 1)
            else:
                # Debug output: start of the finished (padded) input row.
                print(encoding.decode(chat_tensor[i].numpy().astype(int))[:25])
                out.shutdown()
                slots[i] = None


def warmup(count):
    """Sample ``count`` tokens from an empty prompt to exercise the model.

    Uses the same padded ``(1, block_size)`` input shape as serving and keeps
    only the last ``block_size`` tokens as context, so ``count`` may exceed the
    context window.
    NOTE(review): assumes encoding.encode("") is non-empty (e.g. a BOS token).
    """
    block_size = hypr['block_size']
    tokens = Tensor([encoding.encode("")], dtype=dtypes.int32)
    for _ in tqdm(range(count)):
        n = tokens.shape[1]
        input_buffer = tokens.pad(((0, 0), (0, block_size - n))).contiguous()
        out = model(input_buffer)
        token_tensor = (out[:, n - 1, :] / TEMPERATURE).softmax().multinomial(1)
        # Sliding window: keep at most block_size tokens of context.
        tokens = tokens.cat(token_tensor, dim=1)[:, -block_size:].realize()


def test(msg):
    """Submit ``msg`` for completion and stream the result as text chunks.

    Yields ``"Start:"``, then ``"<id>,"`` for each token id as the worker
    produces it, and finally a newline plus the decoded generated tokens
    (the prompt is not echoed).
    """
    replies = queue.Queue()
    msg_q.put((replies, encoding.encode(msg)))
    yield "Start:"
    generated = []
    while True:
        try:
            tok = replies.get()
        except queue.ShutDown:
            # The worker shuts the reply queue down when generation ends.
            break
        generated.append(tok)
        # Yield outside the try so GeneratorExit (client disconnect) is not
        # swallowed.
        yield f"{tok},"
    yield f"\n{encoding.decode(generated)}"


app = flask.Flask(__name__)


@app.route('/', methods=['POST'])
def complete():
    """POST / -- stream a completion for the ``input`` form field."""
    user_string = request.form.get('input', 'Default prompt')
    return test(user_string), {"Content-Type": "text/plain; charset=utf-8"}


def apiStart():
    """Run the Flask development server (blocks until it stops)."""
    app.run()


if __name__ == "__main__":
    print(Device.DEFAULT)
    print("warming up")
    # warmup(200)  # optional: exercise the model before serving
    # The HTTP server is a daemon thread; the worker owns the main thread, so
    # the process exits if the worker raises.
    threading.Thread(target=apiStart, daemon=True).start()
    inference_worker()