improved model

This commit is contained in:
parent b227e9515d
commit a28901fdfd

model.py: 18 lines changed
@@ -1,4 +1,5 @@
 from tinygrad import Tensor, nn
+import numpy as np
 
 class Model:
     def __init__(self, input_channels=1, height=128, width=216, latent_dim=32):
@@ -31,9 +32,15 @@ class Model:
         z = self.reparameterize(mu, log_var)
         recon = self.decode(z)
 
-        recon_loss = (otp - recon).pow(2).sum()
-        kl_div = -0.5 * (1 + log_var - mu.pow(2) - log_var.exp()).sum()
-        total_loss = recon_loss + kl_div
+        # Normalized MSE (per-pixel)
+        recon_loss = (otp - recon).pow(2).mean()
+
+        # Stabilized KL
+        kl_div = -0.5 * (1 + log_var.clip(-10, 10) - mu.pow(2) - log_var.clip(-10, 10).exp()).mean()
+        kl_div = kl_div.relu()
+
+        # Weighted loss
+        total_loss = recon_loss + 0.5 * kl_div
         return recon, total_loss
 
     def encode(self, x: Tensor) -> (Tensor, Tensor):
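For context, the loss this hunk arrives at can be written as a standalone function. This is a minimal sketch, not code from the commit; the name vae_loss, the beta parameter, and the target argument are hypothetical, and it assumes inputs scaled to [0, 1]:

    from tinygrad import Tensor

    def vae_loss(recon: Tensor, target: Tensor, mu: Tensor, log_var: Tensor, beta: float = 0.5) -> Tensor:
        # Per-element MSE instead of a raw sum, so the value no longer grows
        # with batch size or image resolution.
        recon_loss = (target - recon).pow(2).mean()
        # Clip log-variance before exponentiating so the KL term stays finite,
        # average instead of summing, and floor the result at zero with relu.
        lv = log_var.clip(-10, 10)
        kl_div = (-0.5 * (1 + lv - mu.pow(2) - lv.exp()).mean()).relu()
        # Down-weight the KL term relative to reconstruction (beta = 0.5 here,
        # matching the 0.5 factor introduced in the commit).
        return recon_loss + beta * kl_div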
@@ -44,6 +51,7 @@ class Model:
         return self.fc_mu(x), self.fc_logvar(x)
 
     def reparameterize(self, mu: Tensor, log_var: Tensor) -> Tensor:
+        log_var = log_var.clip(-10, 10)
         std = (log_var * 0.5).exp()
         eps = Tensor.randn(mu.shape)
         return mu + std * eps
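One way to see what the new clip buys: with log_var bounded to [-10, 10], the standard deviation exp(0.5 * log_var) used in the reparameterization is confined to a finite range, so exp() can neither overflow nor collapse to exactly zero. A quick check (plain numpy, not part of the commit):

    import numpy as np

    # std = exp(0.5 * log_var); with log_var clipped to [-10, 10] it stays within:
    lo, hi = np.exp(0.5 * -10), np.exp(0.5 * 10)
    print(lo, hi)   # ~0.0067 and ~148.4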
@@ -52,6 +60,6 @@ class Model:
         x = self.dl(x).relu()
         x = x.reshape(shape=(-1, 128, self.h, self.w))
         x = self.d1(x).relu()
-        x = self.d2(x).sigmoid()
-        x = self.d3(x)
+        x = self.d2(x).relu()
+        x = self.d3(x).sigmoid()
         return x
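The activation swap moves the sigmoid from d2 to the final layer d3, so the decoder's output is squashed into (0, 1) and matches targets normalized to [0, 1], which keeps the per-pixel MSE above well scaled. A small illustration of that property (hypothetical shapes, not code from the commit):

    from tinygrad import Tensor

    # Unbounded pre-activations from a hypothetical final decoder layer...
    logits = Tensor.randn(4, 1, 128, 216) * 10
    # ...are squashed into (0, 1) by the output sigmoid.
    out = logits.sigmoid()
    print(out.min().item(), out.max().item())   # both strictly inside (0, 1)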