"""Background-thread batching of tokenized text for language-model training."""

import queue
import threading

import numpy as np


class _WorkerFailure:
    """Queue item carrying an exception raised inside the worker thread."""

    __slots__ = ("exc",)

    def __init__(self, exc):
        self.exc = exc


def startDataWorker(dataset, encoding, batch_size, block_size, chat):
    """Yield ``(bx, by)`` int32 batches produced by a background worker thread.

    A daemon thread runs :func:`dataWorker`, which tokenizes *dataset* and
    pushes batches into a bounded queue (at most 100 pending batches); this
    generator drains that queue indefinitely.

    Args:
        dataset: re-iterable collection of records; each record is read as
            ``record["text"]``, or as ``record["messages"]`` (a list of dicts
            with ``"role"`` and ``"content"``) when *chat* is true.
        encoding: tokenizer exposing ``encode(str)`` and the attributes
            ``bos_token_id``, ``eos_token_id`` and ``pad_token_id``.
        batch_size: number of sequences per batch.
        block_size: number of tokens per sequence.
        chat: format records as chat transcripts instead of plain text.

    Yields:
        Tuple ``(bx, by)`` of ``np.int32`` arrays of shape
        ``(batch_size, block_size)``; ``by`` holds the next-token targets
        for ``bx``.

    Raises:
        RuntimeError: if the worker thread raised (the original exception is
            chained as ``__cause__``) or exited without delivering data.
    """
    data_q = queue.Queue(maxsize=100)
    worker = threading.Thread(
        target=_runDataWorker,
        args=(data_q, dataset, encoding, batch_size, block_size, chat),
        daemon=True,
    )
    worker.start()
    while True:
        try:
            item = data_q.get(timeout=30)
        except queue.Empty:
            # A dead worker can never refill the queue; fail instead of
            # waiting forever.
            if not worker.is_alive() and data_q.empty():
                raise RuntimeError("data worker thread exited without producing data")
            print("queue empty ...")
            continue
        if isinstance(item, _WorkerFailure):
            raise RuntimeError("data worker thread failed") from item.exc
        yield item


def _runDataWorker(q, *args):
    """Run dataWorker, forwarding any exception to the consumer through *q*."""
    try:
        dataWorker(q, *args)
    except Exception as exc:  # must be surfaced in the consuming thread
        q.put(_WorkerFailure(exc))


def dataWorker(q, dataset, encoding, batch_size, block_size, chat):
    """Tokenize *dataset* repeatedly and put ``(bx, by)`` int32 batches on *q*.

    Loops over *dataset* forever, one pass per epoch. Each record becomes
    ``[bos] + encode(text)``, which is cut into non-overlapping windows of
    *block_size* input tokens paired with their next-token targets. A trailing
    partial window is dropped, so records shorter than *block_size* tokens
    contribute nothing. Samples that do not fill a batch carry over into the
    next pass.

    Raises:
        ValueError: if *batch_size* or *block_size* is not positive, or if a
            full pass over *dataset* yields no samples (empty dataset, an
            already-exhausted iterator, or every record too short) -- looping
            further could never produce a batch.
    """
    if batch_size < 1 or block_size < 1:
        raise ValueError(
            f"batch_size and block_size must be positive, got {batch_size} and {block_size}"
        )
    batch_x, batch_y = [], []
    while True:
        produced = 0
        for record in dataset:
            tokens = _tokenize(record, encoding, chat)
            for x, y in _iterChunks(tokens, block_size, encoding):
                produced += 1
                batch_x.append(x)
                batch_y.append(y)
                if len(batch_x) == batch_size:
                    q.put((np.array(batch_x, dtype=np.int32), np.array(batch_y, dtype=np.int32)))
                    batch_x, batch_y = [], []
        if produced == 0:
            raise ValueError(
                "a full pass over dataset produced no samples: it is empty, "
                "exhausted, or every record is shorter than block_size tokens"
            )


def _formatChat(messages):
    """Render chat messages as ``<|role|>content<|end|>`` segments, concatenated."""
    return "".join(f"<|{msg['role']}|>{msg['content']}<|end|>" for msg in messages)


def _tokenize(record, encoding, chat):
    """Return ``[bos] + encode(text)`` for one dataset record."""
    text = _formatChat(record["messages"]) if chat else record["text"]
    return [encoding.bos_token_id, *encoding.encode(text)]


def _iterChunks(tokens, block_size, encoding):
    """Yield ``(x, y)`` windows of *block_size* tokens; ``y`` is ``x`` shifted by one.

    Only start offsets with a full window of inputs are used. The last
    window's target is one token short and is completed with EOS.
    """
    eos_id, pad_id = encoding.eos_token_id, encoding.pad_token_id
    for start in range(0, len(tokens) - block_size + 1, block_size):
        x = tokens[start:start + block_size]
        y = tokens[start + 1:start + block_size + 1]
        yield (
            _padToBlock(x, block_size, eos_id, pad_id),
            _padToBlock(y, block_size, eos_id, pad_id),
        )


def _padToBlock(seq, block_size, eos_id, pad_id):
    """Return *seq* extended to *block_size*: one EOS, then PAD tokens.

    Sequences already at least *block_size* long are returned unchanged.
    """
    if len(seq) >= block_size:
        return seq
    return seq + [eos_id] + [pad_id] * (block_size - 1 - len(seq))