Basic version working.

This commit is contained in:
k
2026-03-03 21:52:30 -05:00
parent 56cff9c37a
commit 4196b70681
2 changed files with 172 additions and 4 deletions

89
bot.py
View File

@@ -8,6 +8,8 @@ from tinygrad.nn.state import safe_load, load_state_dict
from transformers import AutoTokenizer
from model import Transformer
from tqdm import tqdm
import threading
hypr = {
"embed_size": 768, "n_heads": 8, "n_blocks": 12, "block_size": 512,
@@ -26,14 +28,67 @@ Tensor.training = False
@TinyJit
def run_model(input_buffer):
""" run model on gpu """
return model(input_buffer)
def inference_worker():
""" Runs in a separate thread to handle the heavy lifting. """
pass
""" consume tasks from que """
BatchSize=2
NewList = [None] * BatchSize
import time
while True:
if (not msg_q.empty() and None in NewList) or NewList.count(None) == len(NewList):
i = NewList.index(None)
out,inp = msg_q.get()
NewList[i] = (out,inp,None)
batch = []
for i in range(BatchSize):
t = None
if not NewList[i]:
t = Tensor.zeros(hypr['block_size'])
else:
_, t, _ = NewList[i]
if not isinstance(t, Tensor):
t = Tensor(t)
l = t.shape[0]
pad_len = hypr['block_size'] - l
a,b,_ = NewList[i]
NewList[i] = (a,t,l)
t = t.pad((0,pad_len))
else:
#t = t[:-hypr['block_size']]
l = t.shape[0]
pad_len = hypr['block_size'] - l
t = t.pad((0,pad_len))
batch.append(t)
chat_tensor = batch[0].stack(*batch[1:])
#infince here
logits = model(chat_tensor)
#return
for i in range(BatchSize):
if NewList[i] is None:
continue
out, t, lenth = NewList[i]
if lenth < 15:
tok = (logits[i, lenth-1, :] / 0.7).softmax().multinomial(1)
inp = t.cat(tok)
out.put(tok.numpy()[0])
NewList[i] = (out,inp,(lenth+1))
else:
print(encoding.decode(chat_tensor[i].numpy().astype(int))[:25])
out.shutdown()
NewList[i] = None
def warmup(count):
""" run count times with random data """
import random
tokens = encoding.encode("")
tokens = Tensor([tokens])
@@ -45,15 +100,41 @@ def warmup(count):
tokens = tokens.cat(token_tensor, dim=1).realize()
tokens = tokens[:-hypr['block_size']]
def test(msg):
tokens = queue.Queue()
inp = encoding.encode(msg)
t = []
msg_q.put((tokens,inp))
yield("Start:")
while True:
try:
i = tokens.get()
t.append(i)
yield(f"{i},")
except:
break
txt = encoding.decode(t)
yield f"\n{txt}"
return
app = flask.Flask(__name__)
from flask import request
@app.route('/',methods=['POST'])
def complete():
user_string = request.form.get('input', 'Default prompt')
return test(user_string),{"Content-Type": "text"}
def apiStart():
""" start api """
app.run()
pass
if __name__ == "__main__":
print(Device.DEFAULT)
print("warming up")
warmup(200)
#warmup(200)
t = threading.Thread(target=apiStart, daemon=True)
t.start()
inference_worker()