from tinygrad.nn.state import get_state_dict, safe_load, load_state_dict
from concurrent.futures import ThreadPoolExecutor
from tinygrad import Tensor, TinyJit, Device, nn
from transformers import AutoTokenizer
from datasets import load_dataset
from model import Transformer
from tqdm import tqdm
import optm
import data
import log
import sys

# hyperparameters (GPT-2-small-sized model)
hypr = {
    "embed_size": 768,
    "n_heads": 12,
    "n_blocks": 12,
    "block_size": 512,
    "batch_size": 8,
    "starting_lr": 6e-4,
    "minimum_lr": 6e-5,
    "warmup": 1_000,
    "steps": 20_000,
    "encoding": "gpt2",
    "dataset": "HuggingFaceTB/smollm-corpus",
    "subset": "cosmopedia-v2",
    "chat_dataset": "HuggingFaceTB/smoltalk",
    "chat_subset": "all",
    "half": False,
}

print(Device.DEFAULT)

# passing a checkpoint path as argv[1] switches to chat fine-tuning:
# swap in the chat dataset and scale both learning rates down 10x
chat = len(sys.argv) > 1
if chat:
    hypr["dataset"] = hypr["chat_dataset"]
    hypr["subset"] = hypr["chat_subset"]
    hypr["starting_lr"] *= 0.1
    hypr["minimum_lr"] *= 0.1

# background threads for logging so the training loop is not blocked on I/O
logger = ThreadPoolExecutor(max_workers=2)

dataset = load_dataset(hypr["dataset"], hypr["subset"], split="train", streaming=True)
encoding = AutoTokenizer.from_pretrained(hypr["encoding"])
if encoding.pad_token_id is None:
    encoding.pad_token_id = encoding.eos_token_id
hypr["vocab_size"] = encoding.vocab_size

# worker yields (x, y) token batches of shape (batch_size, block_size)
batch = data.startDataWorker(dataset, encoding, hypr["batch_size"], hypr["block_size"], chat)

model = Transformer(hypr["vocab_size"], hypr["embed_size"], hypr["n_heads"], hypr["n_blocks"], hypr["block_size"])
if chat:
    # resume from the pretrained checkpoint given on the command line
    load_state_dict(model, safe_load(sys.argv[1]))
if hypr["half"]:
    from tinygrad import dtypes
    model = model.cast(dtypes.float16)

params = nn.state.get_parameters(model)
optimizer = optm.llmOptimizer(params, hypr["steps"], hypr["starting_lr"], hypr["minimum_lr"])

@TinyJit
def step(x, y):
    optimizer.zero_grad()
    logits = model(x)
    # flatten (batch, time, vocab) -> (batch*time, vocab) for cross-entropy
    B, T, C = logits.shape
    logits = logits.view(B * T, C)
    y = y.view(B * T)
    loss = logits.cross_entropy(y)
    loss.backward()
    optimizer.step()
    return loss

Tensor.training = True
bar = tqdm(range(hypr["steps"]))
for steps in bar:
    nx, ny = next(batch)
    x = Tensor(nx, device=Device.DEFAULT).realize()
    y = Tensor(ny, device=Device.DEFAULT).realize()
    loss = step(x, y)
    if steps % 10 == 0:
        loss_val = loss.numpy()
        logger.submit(log.logLoss, steps, loss_val)
        bar.set_postfix(loss=f"{loss_val:.4f}")
    if steps % 500 == 0:
        # checkpoint every 500 steps; saving runs on the main thread for now
        loss.realize()
        m = get_state_dict(model)
        log.logModel(steps, m)
        # TODO: non-synchronous safetensors logging
        # logger.submit(log.logModel, steps, m)

m = get_state_dict(model)
log.logModel("final", m)
logger.shutdown(wait=True)