diff --git a/train.py b/train.py index e912da5..43f91e7 100644 --- a/train.py +++ b/train.py @@ -6,6 +6,7 @@ import matplotlib.pyplot as plt import time import show from model import gen +from tqdm import tqdm BATCH_SIZE = 16 EPOCHS = 100 @@ -48,7 +49,7 @@ eshape = (BATCH_SIZE, 1, 128, 216) for epoch in range(0,EPOCHS): print(f"\n--- Starting Epoch {epoch} ---\n") loss=0 - for i in range(0,len(x),BATCH_SIZE): + for i in tqdm(range(0,len(x),BATCH_SIZE)): tx=Tensor(x[i:i+BATCH_SIZE]) ty=Tensor(y[i:i+BATCH_SIZE]) if(tx.shape != eshape):