diff --git a/model.py b/model.py
index 06488f4..22c13db 100644
--- a/model.py
+++ b/model.py
@@ -1,4 +1,5 @@
 from tinygrad import Tensor, nn
+import numpy as np
 
 class Model:
     def __init__(self, input_channels=1, height=128, width=216, latent_dim=32):
@@ -30,10 +31,16 @@ class Model:
         mu, log_var = self.encode(inp)
         z = self.reparameterize(mu, log_var)
         recon = self.decode(z)
-
-        recon_loss = (otp - recon).pow(2).sum()
-        kl_div = -0.5 * (1 + log_var - mu.pow(2) - log_var.exp()).sum()
-        total_loss = recon_loss + kl_div
+
+        # Normalized MSE (per-pixel)
+        recon_loss = ((otp - recon).pow(2).mean())
+
+        # Stabilized KL
+        kl_div = -0.5 * (1 + log_var.clip(-10, 10) - mu.pow(2) - log_var.clip(-10, 10).exp()).mean()
+        kl_div = kl_div.relu()
+
+        # Weighted loss
+        total_loss = recon_loss + 0.5 * kl_div
         return recon, total_loss
 
     def encode(self, x: Tensor) -> (Tensor, Tensor):
@@ -44,6 +51,7 @@ class Model:
         return self.fc_mu(x), self.fc_logvar(x)
 
     def reparameterize(self, mu: Tensor, log_var: Tensor) -> Tensor:
+        log_var = log_var.clip(-10, 10)
         std = (log_var * 0.5).exp()
         eps = Tensor.randn(mu.shape)
         return mu + std * eps
@@ -52,6 +60,6 @@ class Model:
         x = self.dl(x).relu()
         x = x.reshape(shape=(-1, 128, self.h, self.w))
         x = self.d1(x).relu()
-        x = self.d2(x).sigmoid()
-        x = self.d3(x)
+        x = self.d2(x).relu()
+        x = self.d3(x).sigmoid()
         return x
\ No newline at end of file