From 20891c4ddd05ae2d0cc4bcbc50f3fd9bf3fe157c Mon Sep 17 00:00:00 2001 From: k Date: Tue, 3 Mar 2026 10:25:11 -0500 Subject: [PATCH] Init --- bot.py | 59 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) create mode 100755 bot.py diff --git a/bot.py b/bot.py new file mode 100755 index 0000000..70000db --- /dev/null +++ b/bot.py @@ -0,0 +1,59 @@ +#! /usr/bin/env nix-shell +#! nix-shell -i python3 -p python3Packages.tinygrad python3Packages.numpy python3Packages.discordpy python3Packages.transformers python3Packages.tqdm python3Packages.flask + +import queue +import flask +from tinygrad import Tensor, TinyJit, dtypes, Device +from tinygrad.nn.state import safe_load, load_state_dict +from transformers import AutoTokenizer +from model import Transformer +from tqdm import tqdm + +hypr = { + "embed_size": 768, "n_heads": 8, "n_blocks": 12, "block_size": 512, + "encoding": "TinyLlama/TinyLlama_v1.1" +} + +CHECKPOINT_PATH = 'gpt.safetensors' + +msg_q = queue.Queue() + +encoding = AutoTokenizer.from_pretrained(hypr['encoding']) +model = Transformer(encoding.vocab_size, hypr["embed_size"], hypr["n_heads"], hypr["n_blocks"], hypr["block_size"]) +load_state_dict(model, safe_load(CHECKPOINT_PATH)) +Tensor.training = False + + +@TinyJit +def run_model(input_buffer): + return model(input_buffer) + +def inference_worker(): + """ Runs in a separate thread to handle the heavy lifting. """ + pass + + +def warmup(count): + import random + tokens = encoding.encode("") + tokens = Tensor([tokens]) + for i in tqdm(range(count)): + pad_len = hypr['block_size'] - tokens.shape[1] + input_buffer = tokens.pad(((0, 0), (0, pad_len))).contiguous() + out = model(input_buffer) + token_tensor = (out[:, tokens.shape[1] - 1, :] / 0.7).softmax().multinomial(1) + tokens = tokens.cat(token_tensor, dim=1).realize() + tokens = tokens[:-hypr['block_size']] + + +def apiStart(): + pass + + +if __name__ == "__main__": + print(Device.DEFAULT) + print("warming up") + warmup(200) + t = threading.Thread(target=apiStart, daemon=True) + t.start() + inference_worker()