from tinygrad import Tensor, nn


class Model:
    """Convolutional variational autoencoder (VAE).

    Three stride-2 convolutions downsample the input by 8x in each spatial
    dimension; fully-connected heads produce the mean and log-variance of a
    diagonal-Gaussian latent posterior; the decoder mirrors the encoder with
    transpose convolutions back to the original resolution.
    """

    def __init__(self, input_channels=1, height=128, width=216, latent_dim=32):
        # Each of the three stride-2 convs halves H and W, so the bottleneck
        # feature map is (128, height // 8, width // 8).
        self.w = width // 8
        self.h = height // 8
        self.flattened_size = 128 * self.h * self.w

        # Encoder: stride-2 convs, channel widths 32 -> 64 -> 128.
        self.e1 = nn.Conv2d(input_channels, 32, kernel_size=3, stride=2, padding=1)
        self.e2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.e3 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)

        # VAE latent space: mean and log-variance heads over the flattened map.
        self.fc_mu = nn.Linear(self.flattened_size, latent_dim)
        self.fc_logvar = nn.Linear(self.flattened_size, latent_dim)

        # Decoder: mirror of the encoder. output_padding=1 makes each
        # stride-2 transpose conv exactly double the spatial resolution.
        self.dl = nn.Linear(latent_dim, self.flattened_size)
        self.d1 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.d2 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.d3 = nn.ConvTranspose2d(32, input_channels, kernel_size=3, stride=2, padding=1, output_padding=1)

    def __call__(self, x: Tensor) -> Tensor:
        """Encode x, sample a latent via the reparameterization trick, and
        return the reconstruction."""
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        return self.decode(z)

    # NOTE(review): "__Lcall__" is not a real Python special method — it
    # looks like a typo (perhaps for a `loss` or training-step method).
    # The name is kept so existing callers keep working; consider renaming.
    def __Lcall__(self, inp: Tensor, otp: Tensor) -> tuple[Tensor, Tensor]:
        """Forward pass plus VAE training loss against target `otp`.

        Returns (reconstruction, total_loss), where total_loss is the summed
        squared-error reconstruction term plus the KL divergence of the
        approximate posterior N(mu, sigma^2) from the standard normal prior.
        """
        mu, log_var = self.encode(inp)
        z = self.reparameterize(mu, log_var)
        recon = self.decode(z)

        recon_loss = (otp - recon).pow(2).sum()
        # Closed-form KL(N(mu, sigma^2) || N(0, 1)), summed over batch and
        # latent dimensions.
        kl_div = -0.5 * (1 + log_var - mu.pow(2) - log_var.exp()).sum()
        total_loss = recon_loss + kl_div
        return recon, total_loss

    def encode(self, x: Tensor) -> tuple[Tensor, Tensor]:
        """Map an input batch to the (mu, log_var) of the latent posterior.

        Assumes x is (batch, input_channels, height, width) with height and
        width matching the constructor arguments — TODO confirm at call site.
        """
        x = self.e1(x).relu()
        x = self.e2(x).relu()
        x = self.e3(x).relu()
        x = x.reshape(shape=(-1, self.flattened_size))
        return self.fc_mu(x), self.fc_logvar(x)

    def reparameterize(self, mu: Tensor, log_var: Tensor) -> Tensor:
        """Sample z ~ N(mu, sigma^2) as mu + sigma * eps with eps ~ N(0, I),
        keeping the sampling step differentiable w.r.t. mu and log_var."""
        std = (log_var * 0.5).exp()
        # Explicit shape unpacking (Tensor.randn takes *shape).
        eps = Tensor.randn(*mu.shape)
        return mu + std * eps

    def decode(self, x: Tensor) -> Tensor:
        """Map a batch of latent vectors back to image space, output in [0, 1].

        Bug fix: the original applied sigmoid after the hidden layer d2 and
        left the final layer d3 unactivated, producing unbounded outputs.
        The activation order is now relu -> relu -> sigmoid so the
        reconstruction is bounded to [0, 1], matching the squared-error loss
        on normalized inputs.
        """
        x = self.dl(x).relu()
        x = x.reshape(shape=(-1, 128, self.h, self.w))
        x = self.d1(x).relu()
        x = self.d2(x).relu()
        return self.d3(x).sigmoid()