From 6e0b3882bcc452b3152f40f898410c26da3fd4dd Mon Sep 17 00:00:00 2001
From: k
Date: Wed, 12 Nov 2025 12:13:03 -0500
Subject: [PATCH] added transformer block in latent space

---
 model.py | 35 ++++++++++++++++++++++++++++-------
 1 file changed, 28 insertions(+), 7 deletions(-)

diff --git a/model.py b/model.py
index dd18bd2..623d9ec 100644
--- a/model.py
+++ b/model.py
@@ -10,6 +10,9 @@ class gen:
         self.h = height // 8
         self.flattened_size = 256 * self.h * self.w
 
+        self.num_tokens = 16
+        self.dim_per_token = self.latent_dim // self.num_tokens
+
         self.e1 = nn.Conv2d(input_channels, 64, kernel_size=3, stride=2, padding=1)
         self.e2 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
 
@@ -18,6 +21,15 @@ class gen:
 
         self.el = nn.Linear(self.flattened_size, self.latent_dim)
 
+        self.q = nn.Linear(self.dim_per_token, self.dim_per_token)
+        self.k = nn.Linear(self.dim_per_token, self.dim_per_token)
+        self.v = nn.Linear(self.dim_per_token, self.dim_per_token)
+        self.norm1 = nn.LayerNorm(self.dim_per_token)
+
+        ffn_dim = self.dim_per_token * 4
+        self.ffn1 = nn.Linear(self.dim_per_token, ffn_dim)
+        self.ffn2 = nn.Linear(ffn_dim, self.dim_per_token)
+        self.norm2 = nn.LayerNorm(self.dim_per_token)
 
         self.dl = nn.Linear(self.latent_dim, self.flattened_size)
 
@@ -27,7 +39,7 @@ class gen:
 
     def __call__(self, x: Tensor) -> Tensor:
         y, shape = self.encode(x)
-        z = y#self.atten(y)
+        z = self.atten(y)
         return self.decode(z, shape)
 
     def encode(self, x: Tensor):
@@ -37,19 +49,28 @@ class gen:
 
         b, c, h, w = x.shape
         flattened_size = c * h * w
-
-
         x = x.reshape(shape=(b, flattened_size))
         z = self.el(x)
+
+        # reshape to multi-token: (batch, num_tokens, dim_per_token)
+        z = z.reshape(shape=(b, self.num_tokens, self.dim_per_token))
         return z, (c, h, w)
 
     def atten(self, x: Tensor):
-        q = self.q(x).relu()
-        k = self.k(x).relu()
-        v = self.v(x).relu()
-        return q.scaled_dot_product_attention(k,v)
+        q = self.q(x)
+        k = self.k(x)
+        v = self.v(x)
+        attn = q.scaled_dot_product_attention(k, v)
+        x = self.norm1(x + attn)
+
+        ffn = self.ffn1(x).relu()
+        ffn = self.ffn2(ffn)
+        x = self.norm2(x + ffn)
+
+        return x
 
     def decode(self, z: Tensor, shape):
+        z = z.reshape(shape=(z.shape[0], -1))
        x = self.dl(z).leakyrelu()
         x = x.reshape(shape=(-1, 256, self.h, self.w))
         x = self.d1(x).leakyrelu()
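
Note: for anyone who wants to exercise the new block outside the autoencoder,
here is a minimal standalone sketch of the same post-norm attention + FFN
structure the patch adds. The tinygrad import is an assumption inferred from
the Tensor / .relu() / nn.Conv2d style of the patched code, and LatentBlock
plus the sizes in the usage example are illustrative, not names from model.py.

from tinygrad import Tensor, nn

class LatentBlock:
    def __init__(self, dim: int):
        # single-head self-attention projections; like the patch, there is
        # no output projection after the attention
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.norm1 = nn.LayerNorm(dim)
        # position-wise feed-forward with the usual 4x expansion
        self.ffn1 = nn.Linear(dim, dim * 4)
        self.ffn2 = nn.Linear(dim * 4, dim)
        self.norm2 = nn.LayerNorm(dim)

    def __call__(self, x: Tensor) -> Tensor:
        # x: (batch, num_tokens, dim)
        attn = self.q(x).scaled_dot_product_attention(self.k(x), self.v(x))
        x = self.norm1(x + attn)               # residual, then post-norm
        ffn = self.ffn2(self.ffn1(x).relu())
        x = self.norm2(x + ffn)                # residual, then post-norm
        return x

if __name__ == "__main__":
    # hypothetical sizes: latent_dim must divide evenly by num_tokens,
    # or the reshape into tokens fails
    latent_dim, num_tokens = 1024, 16
    dim_per_token = latent_dim // num_tokens   # 64 dims per token
    z = Tensor.randn(2, latent_dim)            # batch of 2 flat latents
    tokens = z.reshape(2, num_tokens, dim_per_token)
    out = LatentBlock(dim_per_token)(tokens)
    print(out.shape)                           # (2, 16, 64)

Two details the sketch makes visible: dropping the .relu() from the q/k/v
projections (as the patch does) matters, because rectifying q and k before
the dot product forces every attention logit to be non-negative; and the
block is post-norm (LayerNorm after each residual add), matching the original
transformer, where many newer implementations normalize before the sublayer
instead.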