Fix errors

This commit is contained in:
k
2026-01-07 02:13:08 -05:00
parent 007c96e91b
commit 7f25dff1d1
3 changed files with 20 additions and 17 deletions

13
data.py
View File

@@ -4,18 +4,19 @@ import queue
def startDataWorker(dataset,encoding,batch_size,block_size):
data_q = queue.Queue(maxsize=100)
t = threading.Thread(target=data_worker, args=(data_q, dataset, encoding, batch_size, block_size), daemon=True)
t = threading.Thread(target=dataWorker, args=(data_q, dataset, encoding, batch_size, block_size), daemon=True)
t.start()
while (1):
try:
bx, by = data_q.get(timeout=30)
except queue.Empty:
print("queue empty ...")
continue
yield (bx,by)
def dataWorker(q, dataset, encoding, batch_size, block_size):
batch_x, batch_y = [], []
while(1):
while True:
for text in dataset["text"]:
tokens = encoding.encode(text)
for i in range(0, len(tokens)-block_size-1,block_size):
@@ -33,7 +34,7 @@ def dataWorker(q, dataset, encoding, batch_size, block_size):
batch_x.append(x)
batch_y.append(y)
if len(batch_x) == batch_size:
q.put((np.array(batch_x, dtype=np.int32),
np.array(batch_y, dtype=np.int32)))
batch_x, batch_y = [], []
if len(batch_x) == batch_size:
q.put((np.array(batch_x, dtype=np.int32),
np.array(batch_y, dtype=np.int32)))
batch_x, batch_y = [], []