import queue
import threading

import numpy as np


def startDataWorker(dataset, encoding, batch_size, block_size):
    """Yield ``(x, y)`` token batches produced by a background worker thread.

    Starts a daemon thread running :func:`dataWorker`, which fills a bounded
    queue (at most 100 pending batches), and yields batches from that queue
    indefinitely.

    Args:
        dataset: Object whose ``dataset["text"]`` is a re-iterable collection
            of strings (presumably a Hugging Face ``Dataset`` -- confirm).
        encoding: Tokenizer exposing ``encode(text)`` and the integer
            attributes ``bos_token_id``, ``eos_token_id`` and ``pad_token_id``.
        batch_size: Number of sequences per batch.
        block_size: Length of every sequence in a batch.

    Yields:
        ``(x, y)`` tuples of ``np.int32`` arrays, each of shape
        ``(batch_size, block_size)``.

    Raises:
        RuntimeError: If the worker thread has died and no batches remain,
            so waiting longer could never produce more data.
    """
    data_q = queue.Queue(maxsize=100)
    worker = threading.Thread(
        target=dataWorker,
        args=(data_q, dataset, encoding, batch_size, block_size),
        daemon=True,
    )
    worker.start()
    while True:
        try:
            bx, by = data_q.get(timeout=30)
        except queue.Empty:
            # Without this check a crashed worker would leave us polling forever.
            if not worker.is_alive() and data_q.empty():
                raise RuntimeError(
                    "data worker thread exited; no more batches will arrive"
                ) from None
            print("queue empty ...")
            continue
        yield (bx, by)


def _pad_to(seq, length, eos_id, pad_id):
    """Return *seq* if it already has *length* items, else append EOS then PAD up to *length*."""
    if len(seq) >= length:
        return seq
    return seq + [eos_id] + [pad_id] * (length - len(seq) - 1)


def dataWorker(q, dataset, encoding, batch_size, block_size):
    """Tokenize ``dataset["text"]`` forever and put fixed-size batches on *q*.

    Each text is split into consecutive, non-overlapping windows of
    ``block_size`` tokens. For a window ``w`` the target is ``y = w`` and the
    input is ``x = [bos] + w[:-1]``, i.e. ``y`` shifted right by one position
    behind a BOS token. A final window shorter than ``block_size`` is
    terminated with ``eos_token_id`` and filled with ``pad_token_id``, so every
    row has exactly ``block_size`` entries. Windows are accumulated across
    texts (and across passes over the dataset) until ``batch_size`` rows are
    collected, then put on *q* as two ``np.int32`` arrays.

    Args:
        q: Queue receiving ``(x, y)`` array tuples; ``put`` blocks when full.
        dataset: Object whose ``dataset["text"]`` can be iterated repeatedly,
            once per pass, yielding strings.
        encoding: Tokenizer with ``encode(text)`` returning a sequence of
            token ids, plus ``bos_token_id``, ``eos_token_id`` and
            ``pad_token_id`` (must be ints: they are stored in int32 arrays).
        batch_size: Rows per batch.
        block_size: Tokens per row.

    Raises:
        ValueError: If a complete pass over ``dataset["text"]`` yields no
            windows at all (empty dataset, only empty texts, or an exhausted
            one-shot iterator); continuing would spin forever producing nothing.
    """
    bos_id = encoding.bos_token_id
    eos_id = encoding.eos_token_id
    pad_id = encoding.pad_token_id
    batch_x, batch_y = [], []
    while True:
        produced = 0
        for text in dataset["text"]:
            # list() so "[bos] + slice" concatenates even if encode() returns a
            # tuple or array rather than a list.
            tokens = list(encoding.encode(text))
            # Step through every window, including a short trailing one.
            for start in range(0, len(tokens), block_size):
                x = [bos_id] + tokens[start:start + block_size - 1]
                y = tokens[start:start + block_size]
                batch_x.append(_pad_to(x, block_size, eos_id, pad_id))
                batch_y.append(_pad_to(y, block_size, eos_id, pad_id))
                produced += 1
                if len(batch_x) == batch_size:
                    q.put((np.array(batch_x, dtype=np.int32),
                           np.array(batch_y, dtype=np.int32)))
                    batch_x, batch_y = [], []
        if produced == 0:
            raise ValueError(
                "dataset['text'] produced no token windows; cannot build batches"
            )