Compare commits

2 Commits

Author  SHA1        Message                  Date
k       496916f428  added fine-tuning        2026-01-07 13:01:06 -05:00
k       121640bab6  updated hypr for my gpu  2026-01-07 12:59:44 -05:00
2 changed files with 40 additions and 19 deletions

data.py (24 changed lines)

@@ -2,9 +2,9 @@ import numpy as np
 import threading
 import queue
 
-def startDataWorker(dataset,encoding,batch_size,block_size):
+def startDataWorker(dataset,encoding,batch_size,block_size,chat):
     data_q = queue.Queue(maxsize=100)
-    t = threading.Thread(target=dataWorker, args=(data_q, dataset, encoding, batch_size, block_size), daemon=True)
+    t = threading.Thread(target=dataWorker, args=(data_q, dataset, encoding, batch_size, block_size,chat), daemon=True)
     t.start()
     while (1):
         try:
@@ -14,14 +14,22 @@ def startDataWorker(dataset,encoding,batch_size,block_size):
             continue
         yield (bx,by)
 
 
-def dataWorker(q, dataset, encoding, batch_size, block_size):
+def dataWorker(q, dataset, encoding, batch_size, block_size,chat):
     batch_x, batch_y = [], []
     while True:
-        for text in dataset["text"]:
-            tokens = encoding.encode(text)
-            for i in range(0, len(tokens)-block_size-1,block_size):
-                x = [encoding.bos_token_id] + tokens[i:i+block_size-1]
-                y = tokens[i:i+block_size]
+        for text in dataset:
+            tokens = None
+            if(chat):
+                txt = f"<|user|>{text['instruction']}"
+                if(text["input"] != None):
+                    txt += f"\n{text['input']}"
+                txt = txt + f"<|end|>\n<|assistant|>{text['output']}<|end|>"
+                tokens = [encoding.bos_token_id]+encoding.encode(txt)
+            else:
+                tokens = [encoding.bos_token_id]+encoding.encode(text["text"])
+            for i in range(0, len(tokens)-block_size+1,block_size):
+                x = tokens[i:i+block_size]
+                y = tokens[i+1:i+block_size+1]
             if len(x) < block_size:
                 pad = len(x)-(block_size-1)
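
The reworked dataWorker changes two things: bos_token_id is now prepended once per document rather than inside every window, and targets are simply the inputs shifted one token right; in chat mode, each Alpaca-style record is first flattened into a <|user|>...<|end|> / <|assistant|>...<|end|> string before tokenizing. A minimal standalone sketch of that path, with a made-up record and a byte-level stub standing in for the HuggingFace tokenizer (both hypothetical):

# Sketch of the chat-mode path in dataWorker, on a made-up Alpaca-style record.
record = {
    "instruction": "Summarize the text.",
    "input": "tinygrad is a small deep learning framework.",
    "output": "A compact DL framework.",
}

txt = f"<|user|>{record['instruction']}"
if record["input"] != None:               # mirrors the diff's None check
    txt += f"\n{record['input']}"
txt = txt + f"<|end|>\n<|assistant|>{record['output']}<|end|>"

bos_token_id = 1                          # stand-in for encoding.bos_token_id
def encode(s):                            # byte-level stub for encoding.encode
    return list(s.encode("utf-8"))
tokens = [bos_token_id] + encode(txt)

# Windowing: y is x shifted one token right, the usual next-token targets.
block_size = 8
for i in range(0, len(tokens) - block_size + 1, block_size):
    x = tokens[i : i + block_size]
    y = tokens[i + 1 : i + block_size + 1]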

(second file: 35 changed lines)

@@ -1,30 +1,37 @@
+from tinygrad.nn.state import get_state_dict,safe_load, load_state_dict
 from concurrent.futures import ThreadPoolExecutor
 from tinygrad import Tensor,TinyJit,Device,nn
-from tinygrad.nn.state import get_state_dict
-from model import Transformer
 from transformers import AutoTokenizer
 from datasets import load_dataset
+from model import Transformer
 from tqdm import tqdm
 import optm
 import data
 import log
+import sys
 
 hypr = {
-    "embed_size": 256,
-    "n_heads": 4,
-    "n_blocks": 4,
+    "embed_size": 512,
+    "n_heads": 8,
+    "n_blocks": 6,
     "block_size": 256,
     "batch_size": 16,
-    "starting_lr": 3e-4,
-    "minimum_lr": 3e-5,
+    "starting_lr": 6e-4,
+    "minimum_lr": 6e-5,
     "warmup": 1_000,
-    "steps": 5_000,
-    "encoding": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+    "steps": 20_000,
+    "encoding": "gpt2",
     "dataset": "HuggingFaceTB/smollm-corpus",
     "subset": "cosmopedia-v2",
+    "chat_dataset": "yahma/alpaca-cleaned",
+    "chat_subset": None,
 }
 
 print(Device.DEFAULT)
+chat = len(sys.argv) > 1
+if(chat):
+    hypr["dataset"] = hypr["chat_dataset"]
+    hypr["subset"] = hypr["chat_subset"]
 
 #for loging
 loger = ThreadPoolExecutor(max_workers=2)
@@ -34,10 +41,14 @@ dataset = load_dataset(hypr["dataset"],
     split="train",
     streaming=True)
 
 encoding = AutoTokenizer.from_pretrained(hypr["encoding"])
+if encoding.pad_token_id == None:
+    encoding.pad_token_id=encoding.eos_token_id
 hypr["vocab_size"] = encoding.vocab_size
-model = Transformer(hypr["vocab_size"],hypr["embed_size"],hypr["n_heads"],hypr["n_blocks"],hypr["block_size"])
-batch = data.startDataWorker(dataset,encoding,hypr["batch_size"],hypr["block_size"])
+batch = data.startDataWorker(dataset,encoding,hypr["batch_size"],hypr["block_size"],chat)
+model = Transformer(hypr["vocab_size"],hypr["embed_size"],hypr["n_heads"],hypr["n_blocks"],hypr["block_size"])
+if (chat):
+    load_state_dict(model,safe_load(sys.argv[1]))
 
 params = nn.state.get_parameters(model)
 optimizer = optm.llmOptimizer(params,hypr["steps"],hypr["starting_lr"],hypr["minimum_lr"])
@@ -74,4 +85,6 @@ for steps in bar:
 
 #TODO non sycronus safetensor loging
 #loger.submit(log.logModel,steps,m)
+m = get_state_dict(model)
+log.logModel("final",m)
 loger.shutdown(wait=True)
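
The final-save and chat-mode-load lines round-trip the model through tinygrad's safetensors helpers. A toy version of that cycle, with nn.Linear standing in for the Transformer and a made-up file name (assuming log.logModel ultimately safe_saves the state dict):

# Toy round-trip of the checkpoint path added in this compare:
# save a state dict at the end of pretraining, reload it when fine-tuning.
from tinygrad import nn
from tinygrad.nn.state import get_state_dict, safe_save, safe_load, load_state_dict

model = nn.Linear(4, 4)                                 # stand-in for Transformer(...)
safe_save(get_state_dict(model), "final.safetensors")   # assumed log.logModel behavior

restored = nn.Linear(4, 4)
load_state_dict(restored, safe_load("final.safetensors"))  # the if(chat) resume path

Run bare, the script pretrains on cosmopedia-v2; run with a checkpoint argument (e.g. python train.py final.safetensors, where train.py is an assumed name for the unnamed second file) it enables chat, switches to alpaca-cleaned, and fine-tunes from the loaded weights.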