diff --git a/model.py b/model.py
index 22c13db..b4e3857 100644
--- a/model.py
+++ b/model.py
@@ -8,9 +8,9 @@ class Model:
-        self.flattened_size = 128 * self.h * self.w
+        self.flattened_size = 256 * self.h * self.w
 
         # Encoder
-        self.e1 = nn.Conv2d(input_channels, 32, kernel_size=3, stride=2, padding=1)
-        self.e2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
-        self.e3 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
+        self.e1 = nn.Conv2d(input_channels, 64, kernel_size=3, stride=2, padding=1)
+        self.e2 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
+        self.e3 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
 
         # VAE Latent Space
         self.fc_mu = nn.Linear(self.flattened_size, latent_dim)
@@ -18,29 +18,29 @@ class Model:
 
         # Decoder
         self.dl = nn.Linear(latent_dim, self.flattened_size)
-        self.d1 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1,output_padding=1)
-        self.d2 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1,output_padding=1)
-        self.d3 = nn.ConvTranspose2d(32, input_channels, kernel_size=3, stride=2, padding=1,output_padding=1)
+        self.d1 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1,output_padding=1)
+        self.d2 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1,output_padding=1)
+        self.d3 = nn.ConvTranspose2d(64, input_channels, kernel_size=3, stride=2, padding=1,output_padding=1)
 
     def __call__(self, x: Tensor) -> Tensor:
         mu, log_var = self.encode(x)
         x = self.reparameterize(mu, log_var)
         return self.decode(x)
 
-    def __Lcall__(self, inp: Tensor, otp) -> (Tensor, Tensor):
+    def __Lcall__(self, inp: Tensor, otp:Tensor, epoch) -> (Tensor, Tensor):
         mu, log_var = self.encode(inp)
         z = self.reparameterize(mu, log_var)
         recon = self.decode(z)
 
         # Normalized MSE (per-pixel)
-        recon_loss = ((otp - recon).pow(2).mean())
+        recon_loss = (recon - otp).abs().mean()
 
         # Stabilized KL
-        kl_div = -0.5 * (1 + log_var.clip(-10, 10) - mu.pow(2) - log_var.clip(-10, 10).exp()).mean()
-        kl_div = kl_div.relu()
+        kl_div = -0.5 * (1 + log_var - mu.pow(2) - log_var.exp()).mean()
+
 
         # Weighted loss
-        total_loss = recon_loss + 0.5 * kl_div
+        total_loss = recon_loss + min(0.1, 0.01 * epoch) * kl_div
         return recon, total_loss
 
     def encode(self, x: Tensor) -> (Tensor, Tensor):
@@ -58,8 +58,8 @@ class Model:
 
     def decode(self, x: Tensor) -> Tensor:
         x = self.dl(x).relu()
-        x = x.reshape(shape=(-1, 128, self.h, self.w))
+        x = x.reshape(shape=(-1, 256, self.h, self.w))
         x = self.d1(x).relu()
         x = self.d2(x).relu()
-        x = self.d3(x).sigmoid()
+        x = self.d3(x)
         return x
\ No newline at end of file