import time

import mlflow
import numpy as np
from tinygrad import Device, Tensor, nn, TinyJit
from tinygrad.nn.state import safe_save, get_state_dict
from tqdm import tqdm

import show
from model import gen

BATCH_SIZE = 16
EPOCHS = 100
LEARNING_RATE = 3e-4

print(Device.DEFAULT)

mdl = gen()
opt = nn.optim.AdamW(nn.state.get_parameters(mdl), lr=LEARNING_RATE)

def spec_loss(pred, target, eps=1e-6):
    # spectral convergence: relative Frobenius norm of the residual
    sc = ((target - pred).square().sum()) ** 0.5 / ((target.square().sum()) ** 0.5 + eps)
    # log-magnitude difference
    log_mag = ((target.abs() + eps).log() - (pred.abs() + eps).log()).abs().mean()
    # weighted sum: spectral convergence + log-magnitude + L1
    return 0.1 * sc + 1.0 * log_mag + 0.1 * (pred - target).abs().mean()

@TinyJit
def step_gen(x, y):
    Tensor.training = True
    z = mdl(x)
    loss = spec_loss(z, y)
    # alternative: plain L1 loss
    # loss = (y - z).abs().mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    return loss.numpy()

print("loading")
data = np.load("data.npz")  # open the archive once instead of once per array
x = data["arr_0"]
y = data["arr_1"]

run_name = f"vae_{int(time.time())}"
mlflow.set_tracking_uri("http://127.0.0.1:5000")
mlflow.start_run(run_name=run_name)  # run_name was previously built but never passed in
mlflow.log_params({"batch_size": BATCH_SIZE, "epochs": EPOCHS, "lr": LEARNING_RATE, "data_size": len(x)})

# log the raw input spectrogram as a baseline
show.logSpec(Tensor(x[0:1]).numpy()[0][0], "default")

print("training")
# TinyJit traces fixed shapes, so any ragged final batch is skipped below
eshape = (BATCH_SIZE, 1, 128, 216)
for epoch in range(EPOCHS):
    print(f"\n--- Starting Epoch {epoch} ---\n")
    loss, steps = 0.0, 0
    for i in tqdm(range(0, len(x), BATCH_SIZE)):
        tx = Tensor(x[i:i + BATCH_SIZE])
        ty = Tensor(y[i:i + BATCH_SIZE])
        if tx.shape != eshape:
            continue
        loss += step_gen(tx, ty)
        steps += 1
    # average over the batches actually trained on (skipped ragged batches excluded)
    loss /= steps
    if epoch % 5 == 0:
        show.logSpec(mdl(Tensor(x[0:1])).numpy()[0][0], epoch)
    if epoch % 15 == 0:
        state_dict = get_state_dict(mdl)
        safe_save(state_dict, f"model_{epoch}.safetensors")
        # apply the model to its own output repeatedly to inspect iterative stability
        show.logSpec(mdl(mdl(mdl(Tensor(y[0:1])))).numpy()[0][0], f"deep_{epoch}")
    mlflow.log_metric("loss", float(loss), step=epoch)
    print(f"loss of {loss}")

show.logSpec(mdl(Tensor(x[0:1])).numpy()[0][0], EPOCHS)
state_dict = get_state_dict(mdl)
safe_save(state_dict, "model.safetensors")
mlflow.end_run()
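
# A minimal sketch of reloading the final checkpoint for inference, assuming the
# same `gen` architecture; `safe_load` and `load_state_dict` are tinygrad's
# counterparts to `safe_save`/`get_state_dict` used above:
#
#   from tinygrad.nn.state import safe_load, load_state_dict
#   mdl = gen()
#   load_state_dict(mdl, safe_load("model.safetensors"))
#   Tensor.training = False
#   out = mdl(Tensor(x[0:1]))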