added fine-tuning

This commit is contained in:
k
2026-01-07 13:01:06 -05:00
parent 121640bab6
commit 496916f428
2 changed files with 33 additions and 12 deletions

24
data.py
View File

@@ -2,9 +2,9 @@ import numpy as np
import threading
import queue
def startDataWorker(dataset,encoding,batch_size,block_size):
def startDataWorker(dataset,encoding,batch_size,block_size,chat):
data_q = queue.Queue(maxsize=100)
t = threading.Thread(target=dataWorker, args=(data_q, dataset, encoding, batch_size, block_size), daemon=True)
t = threading.Thread(target=dataWorker, args=(data_q, dataset, encoding, batch_size, block_size,chat), daemon=True)
t.start()
while (1):
try:
@@ -14,14 +14,22 @@ def startDataWorker(dataset,encoding,batch_size,block_size):
continue
yield (bx,by)
def dataWorker(q, dataset, encoding, batch_size, block_size):
def dataWorker(q, dataset, encoding, batch_size, block_size,chat):
batch_x, batch_y = [], []
while True:
for text in dataset["text"]:
tokens = encoding.encode(text)
for i in range(0, len(tokens)-block_size-1,block_size):
x = [encoding.bos_token_id] + tokens[i:i+block_size-1]
y = tokens[i:i+block_size]
for text in dataset:
tokens = None
if(chat):
txt = f"<|user|>{text['instruction']}"
if(text["input"] != None):
txt += f"\n{text['input']}"
txt = txt + f"<|end|>\n<|assistant|>{text['output']}<|end|>"
tokens = [encoding.bos_token_id]+encoding.encode(txt)
else:
tokens = [encoding.bos_token_id]+encoding.encode(text["text"])
for i in range(0, len(tokens)-block_size+1,block_size):
x = tokens[i:i+block_size]
y = tokens[i+1:i+block_size+1]
if len(x) < block_size:
pad = len(x)-(block_size-1)