made model bigger
commit 099400e1da (parent a28901fdfd)

model.py — 26 changed lines (13 additions, 13 deletions)
@@ -8,9 +8,9 @@ class Model:
         self.flattened_size = 128 * self.h * self.w
 
         # Encoder
-        self.e1 = nn.Conv2d(input_channels, 32, kernel_size=3, stride=2, padding=1)
-        self.e2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
-        self.e3 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
+        self.e1 = nn.Conv2d(input_channels, 64, kernel_size=3, stride=2, padding=1)
+        self.e2 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
+        self.e3 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
 
         # VAE Latent Space
         self.fc_mu = nn.Linear(self.flattened_size, latent_dim)
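Worth noting when reading this hunk: self.flattened_size stays at 128 * self.h * self.w even though e3 now emits 256 channels, while the decode() hunk below reshapes to (-1, 256, self.h, self.w). Assuming encode() flattens e3's output into fc_mu (the changed reshape in decode() suggests it does), the two sizes no longer agree unless the constant is also bumped to 256. A quick arithmetic check, with a hypothetical spatial size of 4x4 after the three stride-2 convs:

h = w = 4
flattened_size = 128 * h * w   # what the unchanged line still computes: 2048
e3_output_size = 256 * h * w   # what the widened e3 actually produces: 4096
print(flattened_size == e3_output_size)  # False -> fc_mu / reshape shapes would mismatch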
@@ -18,29 +18,29 @@ class Model:
 
         # Decoder
         self.dl = nn.Linear(latent_dim, self.flattened_size)
-        self.d1 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)
-        self.d2 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1)
-        self.d3 = nn.ConvTranspose2d(32, input_channels, kernel_size=3, stride=2, padding=1, output_padding=1)
+        self.d1 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1)
+        self.d2 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)
+        self.d3 = nn.ConvTranspose2d(64, input_channels, kernel_size=3, stride=2, padding=1, output_padding=1)
 
     def __call__(self, x: Tensor) -> Tensor:
         mu, log_var = self.encode(x)
         x = self.reparameterize(mu, log_var)
         return self.decode(x)
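reparameterize() is called by both __call__ and __Lcall__ but its body sits outside this diff. For reference, the standard VAE reparameterization, sketched with the tinygrad-style Tensor API the file appears to use (Tensor.randn here is an assumption, not taken from the diff):

from tinygrad import Tensor

def reparameterize(self, mu: Tensor, log_var: Tensor) -> Tensor:
    # z = mu + sigma * eps with eps ~ N(0, I) and sigma = exp(0.5 * log_var);
    # sampling through eps keeps z differentiable w.r.t. mu and log_var.
    eps = Tensor.randn(*mu.shape)
    return mu + (0.5 * log_var).exp() * eps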
-    def __Lcall__(self, inp: Tensor, otp) -> (Tensor, Tensor):
+    def __Lcall__(self, inp: Tensor, otp: Tensor, epoch) -> (Tensor, Tensor):
         mu, log_var = self.encode(inp)
         z = self.reparameterize(mu, log_var)
         recon = self.decode(z)
 
         # Normalized MSE (per-pixel)
-        recon_loss = ((otp - recon).pow(2).mean())
+        recon_loss = (recon - otp).abs().mean()
 
         # Stabilized KL
-        kl_div = -0.5 * (1 + log_var.clip(-10, 10) - mu.pow(2) - log_var.clip(-10, 10).exp()).mean()
-        kl_div = kl_div.relu()
+        kl_div = -0.5 * (1 + log_var - mu.pow(2) - log_var.exp()).mean()
 
         # Weighted loss
-        total_loss = recon_loss + 0.5 * kl_div
+        total_loss = recon_loss + min(0.1, 0.01 * epoch) * kl_div
         return recon, total_loss
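The new total_loss line replaces the fixed 0.5 KL weight with a warm-up schedule: the weight grows by 0.01 per epoch and saturates at 0.1, so early epochs are dominated by the reconstruction term. In isolation:

def kl_weight(epoch: int) -> float:
    # Linear KL warm-up, as in the new total_loss line:
    # 0.0 at epoch 0, +0.01 per epoch, capped at 0.1 from epoch 10 on.
    return min(0.1, 0.01 * epoch)

print([kl_weight(e) for e in (0, 1, 5, 10, 20)])  # [0.0, 0.01, 0.05, 0.1, 0.1]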
 
     def encode(self, x: Tensor) -> (Tensor, Tensor):
@@ -58,8 +58,8 @@ class Model:
 
     def decode(self, x: Tensor) -> Tensor:
         x = self.dl(x).relu()
-        x = x.reshape(shape=(-1, 128, self.h, self.w))
+        x = x.reshape(shape=(-1, 256, self.h, self.w))
         x = self.d1(x).relu()
         x = self.d2(x).relu()
-        x = self.d3(x).sigmoid()
+        x = self.d3(x)
         return x
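With the trailing .sigmoid() dropped, decode() now returns unbounded values instead of pixels squashed into [0, 1], which matches the move from squared error to the L1 reconstruction loss above. Anything downstream that expects image-range outputs would need to squash or clip explicitly; a hypothetical usage sketch (model, batch, and epoch are placeholders, not names from this commit):

recon, total_loss = model.__Lcall__(batch, batch, epoch)
preview = recon.clip(0, 1)  # or recon.sigmoid() to reproduce the old squashing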