Quick training script

This commit is contained in:
k
2026-01-07 02:14:09 -05:00
parent 7f25dff1d1
commit 6f037c4a9a

77
train.py Normal file
View File

@@ -0,0 +1,77 @@
from concurrent.futures import ThreadPoolExecutor
from tinygrad import Tensor,TinyJit,Device,nn
from tinygrad.nn.state import get_state_dict
from model import Transformer
from transformers import AutoTokenizer
from datasets import load_dataset
from tqdm import tqdm
import optm
import data
import log
hypr = {
"embed_size": 256,
"n_heads": 4,
"n_blocks": 4,
"block_size": 256,
"batch_size": 16,
"starting_lr": 3e-4,
"minimum_lr": 3e-5,
"warmup": 1_000,
"steps": 5_000,
"encoding": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"dataset": "HuggingFaceTB/smollm-corpus",
"subset": "cosmopedia-v2",
}
print(Device.DEFAULT)
#for loging
loger = ThreadPoolExecutor(max_workers=2)
dataset = load_dataset(hypr["dataset"],
hypr["subset"],
split="train",
streaming=True)
encoding = AutoTokenizer.from_pretrained(hypr["encoding"])
hypr["vocab_size"] = encoding.vocab_size
model = Transformer(hypr["vocab_size"],hypr["embed_size"],hypr["n_heads"],hypr["n_blocks"],hypr["block_size"])
batch = data.startDataWorker(dataset,encoding,hypr["batch_size"],hypr["block_size"])
params = nn.state.get_parameters(model)
optimizer = optm.llmOptimizer(params,hypr["steps"],hypr["starting_lr"],hypr["minimum_lr"])
@TinyJit
def step(x,y):
optimizer.zero_grad()
logits = model(x)
B,T,C = logits.shape
logits = logits.view(B*T,C)
y = y.view(B*T)
loss = logits.cross_entropy(y)
loss.backward()
optimizer.step()
return loss
Tensor.training=True
bar = tqdm(range(hypr["steps"]))
for steps in bar:
nx, ny = next(batch)
x = Tensor(nx, device=Device.DEFAULT).realize()
y = Tensor(ny, device=Device.DEFAULT).realize()
loss = step(x, y)
if steps % 10 == 0:
l = loss.numpy()
loger.submit(log.logLoss, steps, l)
bar.set_postfix(loss= f"{l:.4f}")
if steps % 500 == 0:
loss.realize()
m = get_state_dict(model)
log.logModel(steps,m)
#TODO non sycronus safetensor loging
#loger.submit(log.logModel,steps,m)
loger.shutdown(wait=True)