from tinygrad import Tensor, nn
import numpy as np


class Gen:
    """Convolutional variational autoencoder (VAE).

    The encoder downsamples the input by 8x via three stride-2 convolutions,
    projects to a `latent_dim`-dimensional Gaussian (mu, log_var), and the
    decoder mirrors the encoder to reconstruct the input with a sigmoid
    output in [0, 1].
    """

    def __init__(self, input_channels=1, height=128, width=216, latent_dim=32):
        # Spatial dims after three stride-2 convs (each halves H and W).
        # assumes height and width are divisible by 8 — TODO confirm with callers
        self.w = width // 8
        self.h = height // 8
        self.flattened_size = 256 * self.h * self.w

        # Encoder: 1 -> 64 -> 128 -> 256 channels, stride-2 3x3 convs.
        self.e1 = nn.Conv2d(input_channels, 64, kernel_size=3, stride=2, padding=1)
        self.e2 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        self.e3 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)

        # VAE latent space: separate heads for mean and log-variance.
        self.fc_mu = nn.Linear(self.flattened_size, latent_dim)
        self.fc_logvar = nn.Linear(self.flattened_size, latent_dim)

        # Decoder: linear up-projection, then three stride-2 transposed convs
        # mirroring the encoder (output_padding=1 restores exact sizes).
        self.dl = nn.Linear(latent_dim, self.flattened_size)
        self.d1 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.d2 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.d3 = nn.ConvTranspose2d(64, input_channels, kernel_size=3, stride=2, padding=1, output_padding=1)

    def __call__(self, x: Tensor) -> Tensor:
        """Full VAE forward pass: encode, sample a latent, decode."""
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        return self.decode(z)

    # NOTE(review): `__Lcall__` is not a dunder Python ever invokes implicitly —
    # it looks like a typo for a loss-computing method. Name kept unchanged so
    # any existing explicit callers keep working; consider renaming to `loss`.
    def __Lcall__(self, inp: Tensor, otp: Tensor, epoch) -> tuple[Tensor, Tensor]:
        """Forward pass plus training loss.

        Args:
            inp: input batch fed to the encoder.
            otp: reconstruction target (may differ from `inp`).
            epoch: current epoch number, used to anneal the KL weight.

        Returns:
            (reconstruction, total_loss).
        """
        mu, log_var = self.encode(inp)
        z = self.reparameterize(mu, log_var)
        recon = self.decode(z)
        # Reconstruction term: mean absolute error (L1). The original comment
        # said "MSE" but the code uses .abs(), so L1 is what is computed.
        recon_loss = (recon - otp).abs().mean()
        # KL divergence to the unit Gaussian prior, averaged over all elements.
        kl_div = -0.5 * (1 + log_var - mu.pow(2) - log_var.exp()).mean()
        # KL weight warms up by 0.01 per epoch, capped at 0.1 — a standard
        # annealing schedule to discourage posterior collapse early in training.
        total_loss = recon_loss + min(0.1, 0.01 * epoch) * kl_div
        return recon, total_loss

    def encode(self, x: Tensor) -> tuple[Tensor, Tensor]:
        """Map an input batch to latent Gaussian parameters (mu, log_var)."""
        x = self.e1(x).relu()
        x = self.e2(x).relu()
        x = self.e3(x).relu()
        x = x.reshape(shape=(-1, self.flattened_size))
        return self.fc_mu(x), self.fc_logvar(x)

    def reparameterize(self, mu: Tensor, log_var: Tensor) -> Tensor:
        """Sample z = mu + sigma * eps via the reparameterization trick."""
        # Clamp log-variance before exponentiating for numerical stability.
        log_var = log_var.clip(-10, 10)
        std = (log_var * 0.5).exp()
        # FIX: Tensor.randn takes the shape as varargs (*shape); unpack the
        # tuple instead of passing it as a single argument.
        eps = Tensor.randn(*mu.shape)
        return mu + std * eps

    def decode(self, x: Tensor) -> Tensor:
        """Map a latent batch back to image space; sigmoid output in [0, 1]."""
        x = self.dl(x).relu()
        x = x.reshape(shape=(-1, 256, self.h, self.w))
        x = self.d1(x).relu()
        x = self.d2(x).relu()
        return self.d3(x).sigmoid()


class Check:
    """Convolutional discriminator/critic.

    Mirrors the VAE encoder geometry (8x downsample) and maps an input batch
    to a single raw logit per sample; no final activation is applied, so the
    caller chooses the loss (e.g. BCE-with-logits or a critic objective).
    """

    def __init__(self, input_channels=1, height=128, width=216):
        # Same 8x-downsampled geometry as Gen so both see matching shapes.
        self.w = width // 8
        self.h = height // 8
        self.flattened_size = 256 * self.h * self.w
        self.d1 = nn.Conv2d(input_channels, 64, kernel_size=3, stride=2, padding=1)
        self.d2 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        self.d3 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
        self.fc = nn.Linear(self.flattened_size, 1)

    def __call__(self, x: Tensor) -> Tensor:
        """Score an input batch; returns raw logits of shape (batch, 1)."""
        # LeakyReLU(0.2) is the conventional discriminator activation.
        x = self.d1(x).leakyrelu(0.2)
        x = self.d2(x).leakyrelu(0.2)
        x = self.d3(x).leakyrelu(0.2)
        x = x.reshape(shape=(-1, self.flattened_size))
        return self.fc(x)