add status bar for epoch progress
This commit is contained in:
parent
579b37cd70
commit
b076a0d123
1 changed files with 2 additions and 1 deletions
3
train.py
3
train.py
|
|
@ -6,6 +6,7 @@ import matplotlib.pyplot as plt
|
||||||
import time
|
import time
|
||||||
import show
|
import show
|
||||||
from model import gen
|
from model import gen
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
BATCH_SIZE = 16
|
BATCH_SIZE = 16
|
||||||
EPOCHS = 100
|
EPOCHS = 100
|
||||||
|
|
@ -48,7 +49,7 @@ eshape = (BATCH_SIZE, 1, 128, 216)
|
||||||
for epoch in range(0,EPOCHS):
|
for epoch in range(0,EPOCHS):
|
||||||
print(f"\n--- Starting Epoch {epoch} ---\n")
|
print(f"\n--- Starting Epoch {epoch} ---\n")
|
||||||
loss=0
|
loss=0
|
||||||
for i in range(0,len(x),BATCH_SIZE):
|
for i in tqdm(range(0,len(x),BATCH_SIZE)):
|
||||||
tx=Tensor(x[i:i+BATCH_SIZE])
|
tx=Tensor(x[i:i+BATCH_SIZE])
|
||||||
ty=Tensor(y[i:i+BATCH_SIZE])
|
ty=Tensor(y[i:i+BATCH_SIZE])
|
||||||
if(tx.shape != eshape):
|
if(tx.shape != eshape):
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue