# Training script: spectrogram generator (tinygrad) with MLflow experiment logging.
import mlflow
|
|
import numpy as np
|
|
from tinygrad import Device,Tensor,nn,TinyJit
|
|
from tinygrad.nn.state import safe_save, get_state_dict
|
|
import matplotlib.pyplot as plt
|
|
import time
|
|
import show
|
|
from model import gen
|
|
from tqdm import tqdm
|
|
|
|
# --- hyperparameters ---
BATCH_SIZE = 16
EPOCHS = 100
LEARNING_RATE = 3e-4

# Show which tinygrad backend (CPU/GPU/...) training will run on.
print(Device.DEFAULT)

# Generator model and AdamW optimizer over all of its trainable parameters.
mdl = gen()
opt = nn.optim.AdamW(nn.state.get_parameters(mdl), lr=LEARNING_RATE)
|
|
|
|
def spec_loss(pred, target, eps=1e-6):
    """Spectrogram reconstruction loss.

    Weighted sum of three terms: spectral convergence (relative
    Frobenius-norm error), log-magnitude distance, and an L1 penalty.
    `pred` and `target` are tensors of matching shape; `eps` guards
    against division by zero and log(0).
    """
    residual = target - pred
    # spectral convergence
    convergence = residual.square().sum() ** 0.5 / (target.square().sum() ** 0.5 + eps)
    # log magnitude difference
    log_distance = ((target.abs() + eps).log() - (pred.abs() + eps).log()).abs().mean()
    l1 = residual.abs().mean()
    return 0.1 * convergence + log_distance + 0.1 * l1
|
|
|
|
|
|
@TinyJit
def step_gen(x, y):
    """Run one optimizer step of the generator on a batch.

    x, y: input/target Tensors of fixed shape (TinyJit requires stable shapes).
    Returns the scalar training loss as a numpy value.
    """
    # Bug fix: the original set Tensor.training = True and never reset it,
    # so every later inference/logging call also ran in training mode.
    # Tensor.train() scopes training mode to this step only.
    with Tensor.train():
        pred = mdl(x)
        loss = spec_loss(pred, y)
        opt.zero_grad()
        loss.backward()
        opt.step()
        return loss.numpy()
|
|
|
|
print("loading")

# Load both arrays from a single npz read instead of opening the file twice.
data = np.load("data.npz")
x = data["arr_0"]
y = data["arr_1"]

run_name = f"vae_{int(time.time())}"

mlflow.set_tracking_uri("http://127.0.0.1:5000")
# Bug fix: run_name was built but never handed to MLflow.
mlflow.start_run(run_name=run_name)
mlflow.log_params({"batch_size": BATCH_SIZE, "epochs": EPOCHS, "lr": LEARNING_RATE, "data size": len(x)})

# Log the first input sample's spectrogram as an untouched baseline.
show.logSpec(Tensor(x[0:1]).numpy()[0][0], "default")

print("training")

# Expected batch shape — presumably (batch, channel, mel_bins, frames); TODO confirm.
eshape = (BATCH_SIZE, 1, 128, 216)
|
|
for epoch in range(EPOCHS):
    print(f"\n--- Starting Epoch {epoch} ---\n")
    loss = 0.0
    steps = 0
    for i in tqdm(range(0, len(x), BATCH_SIZE)):
        tx = Tensor(x[i:i + BATCH_SIZE])
        ty = Tensor(y[i:i + BATCH_SIZE])
        # Skip the ragged final batch: the JIT-compiled step needs fixed shapes.
        if tx.shape != eshape:
            continue
        loss += step_gen(tx, ty)
        steps += 1

    # Bug fix: average over the batches actually trained on; the old divisor
    # len(x)/BATCH_SIZE also counted skipped (ragged) batches.
    loss /= max(steps, 1)

    if epoch % 5 == 0:
        # Periodically log the model's reconstruction of the first sample.
        show.logSpec(mdl(Tensor(x[0:1])).numpy()[0][0], epoch)
    if epoch % 15 == 0:
        # Checkpoint the weights and log a triple pass through the model.
        safe_save(get_state_dict(mdl), f"model_{epoch}.safetensors")
        show.logSpec(mdl(mdl(mdl(Tensor(y[0:1])))).numpy()[0][0], f"deep_{epoch}")

    mlflow.log_metric("loss", loss, step=epoch)
    print(f"loss of {loss}")
|
|
|
|
# Final reconstruction snapshot and model export.
show.logSpec(mdl(Tensor(x[0:1])).numpy()[0][0], EPOCHS)
safe_save(get_state_dict(mdl), "model.safetensors")
# Close the MLflow run explicitly instead of relying on interpreter exit.
mlflow.end_run()
|