import queue
import threading

import numpy as np


def startDataWorker(dataset, encoding, batch_size, block_size, chat):
    """Start a background tokenization worker and yield batches as they become ready."""
    data_q = queue.Queue(maxsize=100)
    t = threading.Thread(
        target=dataWorker,
        args=(data_q, dataset, encoding, batch_size, block_size, chat),
        daemon=True,
    )
    t.start()
    while True:
        try:
            bx, by = data_q.get(timeout=30)
        except queue.Empty:
            print("queue empty ...")
            continue
        yield bx, by


def dataWorker(q, dataset, encoding, batch_size, block_size, chat):
    """Tokenize the dataset, slice it into fixed-size blocks, and push batches onto the queue."""
    batch_x, batch_y = [], []
    while True:
        for text in dataset:
            if chat:
                # Build a single-turn chat sample from instruction/input/output fields.
                txt = f"<|user|>{text['instruction']}"
                if text["input"] is not None:
                    txt += f"\n{text['input']}"
                txt += f"<|end|>\n<|assistant|>{text['output']}<|end|>"
                tokens = [encoding.bos_token_id] + encoding.encode(txt)
            else:
                tokens = [encoding.bos_token_id] + encoding.encode(text["text"])

            # Slice the token stream into non-overlapping blocks; y is x shifted by one token.
            for i in range(0, len(tokens) - block_size + 1, block_size):
                x = tokens[i:i + block_size]
                y = tokens[i + 1:i + block_size + 1]
                # Pad short sequences with EOS followed by PAD tokens up to block_size.
                if len(x) < block_size:
                    pad = block_size - 1 - len(x)
                    x = x + [encoding.eos_token_id] + [encoding.pad_token_id] * pad
                if len(y) < block_size:
                    pad = block_size - 1 - len(y)
                    y = y + [encoding.eos_token_id] + [encoding.pad_token_id] * pad
                batch_x.append(x)
                batch_y.append(y)
                if len(batch_x) == batch_size:
                    q.put((np.array(batch_x, dtype=np.int32),
                           np.array(batch_y, dtype=np.int32)))
                    batch_x, batch_y = [], []
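

# A minimal usage sketch, assuming a plain-text dataset (chat=False). The
# DummyEncoding class and the toy dataset below are illustrative stand-ins,
# not part of this module: a real run would pass a tokenizer exposing
# bos_token_id, eos_token_id, pad_token_id and an encode() method.
if __name__ == "__main__":
    class DummyEncoding:
        bos_token_id = 1
        eos_token_id = 2
        pad_token_id = 0

        def encode(self, s):
            # Byte-level "tokenization", offset past the special ids above.
            return [b + 3 for b in s.encode("utf-8")]

    dataset = [{"text": "hello world, this is a tiny corpus for smoke-testing"}] * 4
    gen = startDataWorker(dataset, DummyEncoding(), batch_size=2, block_size=8, chat=False)
    bx, by = next(gen)
    print(bx.shape, by.shape)  # expected: (2, 8) (2, 8)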