InfiniteMusic/model.py
2025-07-25 18:34:58 -04:00

57 lines
2.2 KiB
Python

from tinygrad import Tensor, nn
class Model:
    """Convolutional variational autoencoder (VAE) built on tinygrad.

    Encoder: three 3x3, stride-2, padding-1 convolutions
    (input_channels -> 32 -> 64 -> 128 channels, ReLU after each), flattened and
    projected by two linear heads to the mean and log-variance of a
    ``latent_dim``-dimensional Gaussian.

    Decoder: a linear layer back to the flattened feature size, reshaped to
    (-1, 128, h, w), then three 3x3, stride-2 transposed convolutions
    (128 -> 64 -> 32 -> input_channels). The result is cropped to
    (height, width).

    Args:
        input_channels: Number of channels of the input tensor.
        height: Spatial height of the input. Need not be a multiple of 8.
        width: Spatial width of the input. Need not be a multiple of 8.
        latent_dim: Dimensionality of the latent vector.

    Inputs are presumably laid out NCHW, i.e.
    (batch, input_channels, height, width), as tinygrad's Conv2d expects
    -- TODO confirm against callers.
    """

    def __init__(self, input_channels: int = 1, height: int = 128, width: int = 216, latent_dim: int = 32):
        # Remember the requested output size so decode() can crop to it.
        self.height = height
        self.width = width
        # Each 3x3/stride-2/padding-1 conv maps n -> ceil(n / 2), so three of
        # them map n -> ceil(n / 8). Ceiling division (not floor) keeps
        # flattened_size correct when a dimension is not a multiple of 8.
        self.w = (width + 7) // 8
        self.h = (height + 7) // 8
        self.flattened_size = 128 * self.h * self.w
        # Encoder
        self.e1 = nn.Conv2d(input_channels, 32, kernel_size=3, stride=2, padding=1)
        self.e2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.e3 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        # VAE latent space: separate heads for mean and log-variance.
        self.fc_mu = nn.Linear(self.flattened_size, latent_dim)
        self.fc_logvar = nn.Linear(self.flattened_size, latent_dim)
        # Decoder. Each transposed conv (k=3, s=2, p=1, output_padding=1)
        # exactly doubles the spatial size: (n - 1) * 2 - 2 + 3 + 1 = 2n.
        self.dl = nn.Linear(latent_dim, self.flattened_size)
        self.d1 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.d2 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.d3 = nn.ConvTranspose2d(32, input_channels, kernel_size=3, stride=2, padding=1, output_padding=1)

    def __call__(self, x: Tensor) -> Tensor:
        """Encode ``x``, sample a latent with the reparameterization trick, and decode it.

        Returns the reconstruction, cropped to (height, width) spatially.
        """
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        return self.decode(z)

    def __Lcall__(self, inp: Tensor, otp: Tensor) -> tuple[Tensor, Tensor]:
        """Run a forward pass and compute the VAE loss against a target.

        NOTE(review): this is not a Python dunder; it is presumably invoked
        explicitly by a training loop elsewhere -- confirm before renaming.

        Args:
            inp: Input batch fed to the encoder.
            otp: Target the reconstruction is compared against. It must
                broadcast with the reconstruction's shape.

        Returns:
            ``(recon, total_loss)``. ``total_loss`` is the summed squared
            reconstruction error plus the closed-form KL divergence between
            N(mu, exp(log_var)) and N(0, I), summed over all elements.
        """
        mu, log_var = self.encode(inp)
        z = self.reparameterize(mu, log_var)
        recon = self.decode(z)
        recon_loss = (otp - recon).pow(2).sum()
        kl_div = -0.5 * (1 + log_var - mu.pow(2) - log_var.exp()).sum()
        total_loss = recon_loss + kl_div
        return recon, total_loss

    def encode(self, x: Tensor) -> tuple[Tensor, Tensor]:
        """Map an input batch to ``(mu, log_var)``, each of shape (batch, latent_dim)."""
        x = self.e1(x).relu()
        x = self.e2(x).relu()
        x = self.e3(x).relu()
        x = x.reshape(shape=(-1, self.flattened_size))
        return self.fc_mu(x), self.fc_logvar(x)

    def reparameterize(self, mu: Tensor, log_var: Tensor) -> Tensor:
        """Sample z = mu + sigma * eps with eps ~ N(0, I) and sigma = exp(log_var / 2)."""
        std = (log_var * 0.5).exp()
        eps = Tensor.randn(mu.shape)
        return mu + std * eps

    def decode(self, x: Tensor) -> Tensor:
        """Map latent vectors back to tensors of shape (batch, input_channels, height, width)."""
        x = self.dl(x).relu()
        x = x.reshape(shape=(-1, 128, self.h, self.w))
        x = self.d1(x).relu()
        # NOTE(review): sigmoid is applied after a hidden layer (d2), while
        # d3's output is left unbounded. If reconstructions are meant to lie
        # in [0, 1], the sigmoid probably belongs after d3. It is kept as-is
        # because moving it changes the behaviour of trained weights.
        x = self.d2(x).sigmoid()
        x = self.d3(x)
        # The decoder produces 8 * ceil(n / 8) along each spatial axis; crop
        # back to the requested size (a no-op when height/width are multiples of 8).
        return x[:, :, : self.height, : self.width]