50 lines
1.8 KiB
Python
50 lines
1.8 KiB
Python
import numpy as np
|
|
import threading
|
|
import queue
|
|
|
|
def startDataWorker(dataset, encoding, batch_size, block_size, chat, *,
                    queue_size=100, timeout=30, worker=None):
    """Run a batch producer on a daemon thread and yield its ``(bx, by)`` batches.

    Args:
        dataset, encoding, batch_size, block_size, chat: forwarded unchanged
            to *worker* together with the internal queue.
        queue_size: capacity of the hand-off queue (default 100, as before).
        timeout: seconds to wait for a batch before printing
            ``"queue empty ..."`` and waiting again (default 30, as before).
        worker: producer callable with ``dataWorker``'s signature
            ``(q, dataset, encoding, batch_size, block_size, chat)``;
            defaults to ``dataWorker``.

    Yields:
        Whatever ``(bx, by)`` pairs the worker puts on the queue, in order.

    Raises:
        RuntimeError: if the worker raises; the original exception is
            attached as ``__cause__``. Previously such a failure left this
            generator printing "queue empty ..." forever.

    The generator ends (``StopIteration``) if the worker returns normally.
    """
    producer = dataWorker if worker is None else worker
    data_q = queue.Queue(maxsize=queue_size)
    done = object()  # end-of-stream marker, always enqueued by _run
    failure = []     # receives the worker's exception, if it raises

    def _run():
        # Report both crashes and normal returns to the consumer through the
        # queue itself, so it never waits on a producer that no longer exists.
        try:
            producer(data_q, dataset, encoding, batch_size, block_size, chat)
        except Exception as exc:
            failure.append(exc)
        finally:
            data_q.put(done)

    t = threading.Thread(target=_run, daemon=True)
    t.start()

    while True:
        try:
            item = data_q.get(timeout=timeout)
        except queue.Empty:
            # The end marker has not arrived, so the producer is still working.
            print("queue empty ...")
            continue
        if item is done:
            if failure:
                raise RuntimeError("data worker thread failed") from failure[0]
            return
        bx, by = item
        yield (bx, by)
|
|
|
|
def _render_chat(messages):
    """Concatenate chat messages as ``<|role|>content<|end|>`` segments."""
    return "".join(f"<|{msg['role']}|>{msg['content']}<|end|>" for msg in messages)


def _pad_window(seq, block_size, encoding):
    """Return *seq* extended to ``block_size`` with one EOS and then PAD tokens.

    A window that is already ``block_size`` long is returned unchanged.
    """
    if len(seq) >= block_size:
        return seq
    # The old count ``len(seq) - (block_size - 1)`` went negative for short
    # windows, and ``[pad] * negative`` is ``[]``, giving under-length rows.
    fill = block_size - len(seq) - 1
    return seq + [encoding.eos_token_id] + [encoding.pad_token_id] * fill


def dataWorker(q, dataset, encoding, batch_size, block_size, chat):
    """Tokenize *dataset* in endless passes and put int32 batches on *q*.

    Each item of *dataset* is read as ``item["messages"]`` (a list of dicts
    with ``"role"`` and ``"content"``) when *chat* is true, otherwise as
    ``item["text"]``. The text is encoded with ``encoding.encode`` and
    prefixed with ``encoding.bos_token_id``.

    The token list is cut into non-overlapping windows of ``block_size``.
    Each input ``x`` is paired with the target ``y`` shifted one token
    right. Trailing tokens that do not fill a whole ``x`` window are
    dropped. When ``y`` runs past the end of the document, it gets an EOS
    token appended.

    Every ``batch_size`` windows are put on *q* as a pair of
    ``np.int32`` arrays. ``q.put`` blocks while a bounded queue is full.
    A partial batch carries over into the next pass.

    Raises:
        ValueError: if ``batch_size`` or ``block_size`` is < 1, or if a full
            pass over *dataset* produces no window. That happens when the
            dataset is empty, is a single-pass iterator already consumed, or
            every document has fewer than ``block_size`` tokens. Previously
            these cases spun the CPU forever.
    """
    if batch_size < 1 or block_size < 1:
        raise ValueError(
            f"batch_size and block_size must be >= 1, got {batch_size} and {block_size}"
        )

    batch_x, batch_y = [], []

    while True:
        produced = False
        for text in dataset:
            body = _render_chat(text["messages"]) if chat else text["text"]
            tokens = [encoding.bos_token_id] + encoding.encode(body)

            for i in range(0, len(tokens) - block_size + 1, block_size):
                produced = True
                batch_x.append(_pad_window(tokens[i:i + block_size], block_size, encoding))
                batch_y.append(_pad_window(tokens[i + 1:i + block_size + 1], block_size, encoding))

                if len(batch_x) == batch_size:
                    q.put((np.array(batch_x, dtype=np.int32),
                           np.array(batch_y, dtype=np.int32)))
                    batch_x, batch_y = [], []

        if not produced:
            # Every later pass would be just as empty, so fail instead of spinning.
            raise ValueError(
                "dataset produced no training windows in a full pass "
                "(empty, exhausted iterator, or all documents shorter than block_size tokens)"
            )
|