diff --git a/data.py b/data.py index 0e50c24..56a893a 100644 --- a/data.py +++ b/data.py @@ -2,9 +2,9 @@ import numpy as np import threading import queue -def startDataWorker(dataset,encoding,batch_size,block_size,chat): +def startDataWorker(dataset,encoding,batch_size,block_size): data_q = queue.Queue(maxsize=100) - t = threading.Thread(target=dataWorker, args=(data_q, dataset, encoding, batch_size, block_size,chat), daemon=True) + t = threading.Thread(target=dataWorker, args=(data_q, dataset, encoding, batch_size, block_size), daemon=True) t.start() while (1): try: @@ -14,22 +14,14 @@ def startDataWorker(dataset,encoding,batch_size,block_size,chat): continue yield (bx,by) -def dataWorker(q, dataset, encoding, batch_size, block_size,chat): +def dataWorker(q, dataset, encoding, batch_size, block_size): batch_x, batch_y = [], [] while True: - for text in dataset: - tokens = None - if(chat): - txt = f"<|user|>{text['instruction']}" - if(text["input"] != None): - txt += f"\n{text['input']}" - txt = txt + f"<|end|>\n<|assistant|>{text['output']}<|end|>" - tokens = [encoding.bos_token_id]+encoding.encode(txt) - else: - tokens = [encoding.bos_token_id]+encoding.encode(text["text"]) - for i in range(0, len(tokens)-block_size+1,block_size): - x = tokens[i:i+block_size] - y = tokens[i+1:i+block_size+1] + for text in dataset["text"]: + tokens = encoding.encode(text) + for i in range(0, len(tokens)-block_size-1,block_size): + x = [encoding.bos_token_id] + tokens[i:i+block_size-1] + y = tokens[i:i+block_size] if len(x) < block_size: pad = len(x)-(block_size-1) diff --git a/train.py b/train.py index a2421d5..f231764 100644 --- a/train.py +++ b/train.py @@ -1,37 +1,30 @@ -from tinygrad.nn.state import get_state_dict,safe_load, load_state_dict from concurrent.futures import ThreadPoolExecutor from tinygrad import Tensor,TinyJit,Device,nn +from tinygrad.nn.state import get_state_dict +from model import Transformer from transformers import AutoTokenizer from datasets import load_dataset -from model import Transformer from tqdm import tqdm import optm import data import log -import sys hypr = { - "embed_size": 512, - "n_heads": 8, - "n_blocks": 6, + "embed_size": 256, + "n_heads": 4, + "n_blocks": 4, "block_size": 256, "batch_size": 16, - "starting_lr": 6e-4, - "minimum_lr": 6e-5, + "starting_lr": 3e-4, + "minimum_lr": 3e-5, "warmup": 1_000, - "steps": 20_000, - "encoding": "gpt2", + "steps": 5_000, + "encoding": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "dataset": "HuggingFaceTB/smollm-corpus", "subset": "cosmopedia-v2", - "chat_dataset": "yahma/alpaca-cleaned", - "chat_subset": None, } print(Device.DEFAULT) -chat = len(sys.argv) > 1 -if(chat): - hypr["dataset"] = hypr["chat_dataset"] - hypr["subset"] = hypr["chat_subset"] #for loging loger = ThreadPoolExecutor(max_workers=2) @@ -41,14 +34,10 @@ dataset = load_dataset(hypr["dataset"], split="train", streaming=True) encoding = AutoTokenizer.from_pretrained(hypr["encoding"]) -if encoding.pad_token_id == None: - encoding.pad_token_id=encoding.eos_token_id hypr["vocab_size"] = encoding.vocab_size -batch = data.startDataWorker(dataset,encoding,hypr["batch_size"],hypr["block_size"],chat) - model = Transformer(hypr["vocab_size"],hypr["embed_size"],hypr["n_heads"],hypr["n_blocks"],hypr["block_size"]) -if (chat): - load_state_dict(model,safe_load(sys.argv[1])) +batch = data.startDataWorker(dataset,encoding,hypr["batch_size"],hypr["block_size"]) + params = nn.state.get_parameters(model) optimizer = optm.llmOptimizer(params,hypr["steps"],hypr["starting_lr"],hypr["minimum_lr"]) @@ -85,6 +74,4 @@ for steps in bar: #TODO non sycronus safetensor loging #loger.submit(log.logModel,steps,m) -m = get_state_dict(model) -log.logModel("final",m) loger.shutdown(wait=True)