# llm/train.py
from tinygrad.nn.state import get_state_dict, safe_load, load_state_dict
from concurrent.futures import ThreadPoolExecutor
from tinygrad import Tensor, TinyJit, Device, nn
from transformers import AutoTokenizer
from datasets import load_dataset
from model import Transformer
from tqdm import tqdm
# local project modules: optimizer schedule, batch worker, loss/model logging
import optm
import data
import log
import sys
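# GPT-2-small-scale configuration: 768-dim embeddings, 12 heads, 12 blocks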
hypr = {
    "embed_size": 768,
    "n_heads": 12,
    "n_blocks": 12,
    "block_size": 512,
    "batch_size": 8,
    "starting_lr": 6e-4,
    "minimum_lr": 6e-5,
    "warmup": 1_000,
    "steps": 20_000,
    "encoding": "gpt2",
    "dataset": "HuggingFaceTB/smollm-corpus",
    "subset": "cosmopedia-v2",
    "chat_dataset": "HuggingFaceTB/smoltalk",
    "chat_subset": "all",
    "half": False,
}
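# usage: `python train.py` pretrains on cosmopedia-v2;
# `python train.py <checkpoint>` loads those weights and chat fine-tunes on smoltalk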
print(Device.DEFAULT)
# passing a checkpoint path as the first argument switches to chat fine-tuning
chat = len(sys.argv) > 1
if chat:
    hypr["dataset"] = hypr["chat_dataset"]
    hypr["subset"] = hypr["chat_subset"]
    # fine-tune at a 10x lower learning rate
    hypr["starting_lr"] *= 0.1
    hypr["minimum_lr"] *= 0.1
# background thread pool for logging
logger = ThreadPoolExecutor(max_workers=2)
dataset = load_dataset(hypr["dataset"],
                       hypr["subset"],
                       split="train",
                       streaming=True)
encoding = AutoTokenizer.from_pretrained(hypr["encoding"])
if encoding.pad_token_id is None:
    encoding.pad_token_id = encoding.eos_token_id
hypr["vocab_size"] = encoding.vocab_size
batch = data.startDataWorker(dataset, encoding, hypr["batch_size"], hypr["block_size"], chat)
model = Transformer(hypr["vocab_size"], hypr["embed_size"], hypr["n_heads"], hypr["n_blocks"], hypr["block_size"])
if chat:
    # resume from the pretrained checkpoint given on the command line
    load_state_dict(model, safe_load(sys.argv[1]))
if hypr["half"]:
    from tinygrad import dtypes
    model = model.cast(dtypes.float16)
params = nn.state.get_parameters(model)
optimizer = optm.llmOptimizer(params, hypr["steps"], hypr["starting_lr"], hypr["minimum_lr"])
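# (assumption: optm.llmOptimizer is the local scheduled optimizer, decaying the
# learning rate from starting_lr toward minimum_lr over the given number of steps)

# TinyJit traces the first couple of calls and then replays the captured kernels,
# so x and y must keep the same shape and dtype on every step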
@TinyJit
def step(x, y):
    optimizer.zero_grad()
    logits = model(x)
    B, T, C = logits.shape
    # flatten (batch, time) so each token position is one classification target
    logits = logits.view(B * T, C)
    y = y.view(B * T)
    loss = logits.cross_entropy(y)
    loss.backward()
    optimizer.step()
    return loss
Tensor.training = True  # enable training-mode behavior (e.g. dropout) in tinygrad
bar = tqdm(range(hypr["steps"]))
for steps in bar:
    nx, ny = next(batch)
    x = Tensor(nx, device=Device.DEFAULT).realize()
    y = Tensor(ny, device=Device.DEFAULT).realize()
    loss = step(x, y)
    if steps % 10 == 0:
        l = loss.numpy()
        logger.submit(log.logLoss, steps, l)
        bar.set_postfix(loss=f"{l:.4f}")
    if steps % 500 == 0:
        # presumably flushes pending computation before snapshotting the weights
        loss.realize()
        m = get_state_dict(model)
        log.logModel(steps, m)
        # TODO: asynchronous safetensors logging
        # logger.submit(log.logModel, steps, m)
# save the final weights and wait for any pending log writes to finish
m = get_state_dict(model)
log.logModel("final", m)
logger.shutdown(wait=True)