from tinygrad import Tensor, nn


class Model:
  def __init__(self, input_channels=1, height=128, width=216, latent_dim=32):
    # Three stride-2 convolutions downsample each spatial dimension by 8.
    self.w = width // 8
    self.h = height // 8
    self.flattened_size = 128 * self.h * self.w

    # Encoder
    self.e1 = nn.Conv2d(input_channels, 32, kernel_size=3, stride=2, padding=1)
    self.e2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
    self.e3 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)

    # VAE latent space
    self.fc_mu = nn.Linear(self.flattened_size, latent_dim)
    self.fc_logvar = nn.Linear(self.flattened_size, latent_dim)

    # Decoder
    self.dl = nn.Linear(latent_dim, self.flattened_size)
    self.d1 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)
    self.d2 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1)
    self.d3 = nn.ConvTranspose2d(32, input_channels, kernel_size=3, stride=2, padding=1, output_padding=1)

  def __call__(self, x: Tensor) -> Tensor:
    mu, log_var = self.encode(x)
    z = self.reparameterize(mu, log_var)
    return self.decode(z)

  def loss(self, x: Tensor, target: Tensor) -> tuple[Tensor, Tensor]:
    # Forward pass that also returns the VAE objective:
    # reconstruction error plus the KL divergence of the latent distribution.
    mu, log_var = self.encode(x)
    z = self.reparameterize(mu, log_var)
    recon = self.decode(z)
    recon_loss = (target - recon).pow(2).sum()
    kl_div = -0.5 * (1 + log_var - mu.pow(2) - log_var.exp()).sum()
    total_loss = recon_loss + kl_div
    return recon, total_loss

  def encode(self, x: Tensor) -> tuple[Tensor, Tensor]:
    x = self.e1(x).relu()
    x = self.e2(x).relu()
    x = self.e3(x).relu()
    x = x.reshape(-1, self.flattened_size)
    return self.fc_mu(x), self.fc_logvar(x)

  def reparameterize(self, mu: Tensor, log_var: Tensor) -> Tensor:
    # Reparameterization trick: z = mu + sigma * eps keeps sampling differentiable.
    std = (log_var * 0.5).exp()
    eps = Tensor.randn(*mu.shape)
    return mu + std * eps

  def decode(self, x: Tensor) -> Tensor:
    x = self.dl(x).relu()
    x = x.reshape(-1, 128, self.h, self.w)
    x = self.d1(x).relu()
    x = self.d2(x).relu()
    # Sigmoid on the last layer bounds the reconstruction to [0, 1].
    return self.d3(x).sigmoid()
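
# --- Usage sketch (an assumption, not part of the original model) ------------
# A minimal training loop showing how this model might be driven, assuming
# tinygrad's nn.optim.Adam, nn.state.get_parameters, and the Tensor.train()
# context. The uniform random batch is a stand-in for real (N, 1, 128, 216)
# data scaled to [0, 1] to match the sigmoid output.
if __name__ == "__main__":
  model = Model()
  opt = nn.optim.Adam(nn.state.get_parameters(model), lr=1e-4)
  with Tensor.train():
    for step in range(10):
      batch = Tensor.rand(8, 1, 128, 216)       # placeholder training data
      opt.zero_grad()
      _, total_loss = model.loss(batch, batch)  # autoencode: target == input
      total_loss.backward()
      opt.step()
      print(f"step {step}: loss {total_loss.item():.2f}")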