InfiniteMusic/train.py
2025-09-11 17:07:15 -04:00

107 lines
2.5 KiB
Python

#!/usr/bin/env python
# coding: utf-8
import data
import model as model
import show
import mlflow
import numpy as np
from tinygrad import nn,TinyJit,Tensor
# Record this run on the MLflow tracking server at localhost:5000, under a fixed experiment.
MLFLOW_TRACKING_URI = "http://127.0.0.1:5000"
MLFLOW_EXPERIMENT_ID = 804883409598823668
mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)
mlflow.start_run(experiment_id=MLFLOW_EXPERIMENT_ID)
# Hyperparameters.
BATCH_SIZE = 32
# Misspelled alias of BATCH_SIZE; the rest of the script still refers to it.
BACH_SIZE = BATCH_SIZE
glr = 2e-4  # generator (genOpt) learning rate
dlr = 1e-5  # discriminator (difOpt) learning rate
epochs = 100
# Dataset: load every sample, stack into one array, shuffle once, and hold out a test split.
x = data.load()
size = len(x)
# Insert a size-1 axis at position 1 (the `1` in the expected batch shape `eshape`).
x_np = np.stack(x)[:, np.newaxis]
permutation = np.random.permutation(size)
x_np = x_np[permutation]
# The first 30 shuffled samples are the test split; everything after them is training data.
test, train = x_np[:30], x_np[30:]
print(f"Train:{len(train)}")
print(f"Test:{len(test)}")
# Models: generator `gen` and discriminator `dif` (model.Check), each paired with
# its own AdamW optimizer (glr for the generator, dlr for the discriminator).
gen, dif = model.Gen(), model.Check()
genOpt, difOpt = (
    nn.optim.AdamW(nn.state.get_parameters(net), lr=rate)
    for net, rate in ((gen, glr), (dif, dlr))
)
#train
@TinyJit
def _step_dis_jit(x: Tensor) -> Tensor:
    """Run one JIT-compiled discriminator update and return the realized loss Tensor.

    The loss is returned as a Tensor rather than ``loss.numpy()`` on purpose. After
    TinyJit has captured this function, it replays the recorded kernels without running
    this Python body and returns the captured return value. A numpy value would
    therefore stay frozen at whatever it was during capture.
    """
    Tensor.training = True
    # backward() adds into existing .grad, and the generator step also backpropagates
    # through `dif`, so clear the discriminator gradients before this update.
    difOpt.zero_grad()
    real = Tensor.ones((BATCH_SIZE, 1))
    fake = Tensor.zeros((BATCH_SIZE, 1))
    noise = Tensor.randn(BATCH_SIZE, gen.ld)
    # Detach so this loss only updates the discriminator, never the generator.
    fake_data = gen(noise).detach()
    fake_loss = dif(fake_data).binary_crossentropy_logits(fake)
    real_loss = dif(x).binary_crossentropy_logits(real)
    loss = (fake_loss + real_loss) / 2
    loss.backward()
    difOpt.step()
    return loss.realize()


def step_dis(x: Tensor):
    """Train the discriminator on one real batch plus one freshly generated batch.

    Real samples are labelled 1 and generated samples 0. The two BCE-with-logits
    losses are averaged, and one ``difOpt`` step is taken.

    Args:
        x: A batch of real samples. Presumably shaped like ``eshape``; TODO confirm.

    Returns:
        The discriminator loss for this step, as a numpy array.
    """
    # Convert outside the JIT so every call reports the current step's loss.
    return _step_dis_jit(x).numpy()
@TinyJit
def _step_gen_jit() -> Tensor:
    """Run one JIT-compiled generator update and return the realized loss Tensor.

    The loss is returned as a Tensor (not numpy) so that JIT replays do not hand back a
    value frozen at capture time.
    """
    Tensor.training = True
    # backward() adds into existing .grad, so start the generator update from zero.
    genOpt.zero_grad()
    real = Tensor.ones((BATCH_SIZE, 1))
    noise = Tensor.randn(BATCH_SIZE, gen.ld)
    # Do NOT detach here: the loss must backpropagate into `gen`, otherwise genOpt has
    # no generator gradients and the generator never learns. The gradients this also
    # produces on `dif` are not applied by genOpt, and the discriminator step clears them.
    fake_data = gen(noise)
    # Score generated samples against the "real" label, so gen learns to fool dif.
    loss = dif(fake_data).binary_crossentropy_logits(real)
    loss.backward()
    genOpt.step()
    return loss.realize()


def step_gen():
    """Train the generator for one step against the current discriminator.

    Returns:
        The generator loss for this step, as a numpy array.
    """
    # Convert outside the JIT so every call reports the current step's loss.
    return _step_gen_jit().numpy()
# Expected shape of one training batch. The training loop skips any batch of a different
# shape. NOTE(review): 128 x 216 is presumably the spectrogram size produced by
# data.load(); confirm, because a mismatch here means no batch is ever trained on.
eshape = (BACH_SIZE, 1, 128, 216)
# Record the run's hyperparameters and split sizes in MLflow.
mlflow.log_param("generator_learning_rate", glr)
mlflow.log_param("discim_learning_rate", dlr)
mlflow.log_param("epochs", epochs)
# Batch size was previously not logged, so runs that differ only in it were indistinguishable.
mlflow.log_param("batch_size", BATCH_SIZE)
mlflow.log_param("train size", len(train))
mlflow.log_param("test size", len(test))
# Main training loop: one discriminator step then one generator step per full batch,
# with per-epoch mean losses logged to MLflow and a sample logged every 5 epochs.
for e in range(0, epochs):
    print(f"\n--- Starting Epoch {e} ---\n")
    dl = 0.0  # summed discriminator loss over this epoch's steps
    gl = 0.0  # summed generator loss over this epoch's steps
    steps = 0
    # Walk the training split only. `size` also counts the held-out test samples, so
    # iterating up to it just produced empty trailing slices.
    for i in range(0, len(train), BATCH_SIZE):
        batch = train[i:i + BATCH_SIZE]
        # The JIT-compiled steps need a fixed input shape, so skip the short tail batch
        # (and anything else that does not match `eshape`) before building a Tensor.
        if batch.shape != eshape:
            continue
        tx = Tensor(batch)
        dl += float(step_dis(tx))
        gl += float(step_gen())
        steps += 1
    if steps == 0:
        raise RuntimeError(
            f"epoch {e}: no training batch has the expected shape {eshape}; "
            "check the data shape and BATCH_SIZE"
        )
    # Average over the steps actually taken, not over size / BATCH_SIZE.
    dl /= steps
    gl /= steps
    if e % 5 == 0:
        # NOTE(review): Tensor.training is still True here (the step functions set it);
        # confirm gen has no training-only layers that should be disabled for sampling.
        noise = Tensor.randn(BATCH_SIZE, gen.ld)
        show.logSpec(gen(noise).numpy()[0][0], e)
    # TODO: evaluate on the test split.
    mlflow.log_metric("gen_loss", gl, step=e)
    mlflow.log_metric("dis_loss", dl, step=e)
    print(f"loss of gen:{gl} dis:{dl}")
# Save: log one last generated sample (tagged with step `epochs`), then write the
# generator weights to a safetensors file and attach it to the MLflow run.
noise = Tensor.randn(BACH_SIZE, gen.ld)
final_sample = gen(noise).numpy()
show.logSpec(final_sample[0, 0], epochs)
from tinygrad.nn.state import safe_save, get_state_dict
weights_path = "music.safetensors"
safe_save(get_state_dict(gen), weights_path)
mlflow.log_artifact(weights_path)