from tinygrad import Tensor, nn


class gen:
    def __init__(self, input_channels=1, height=128, width=216, latent_dim=1024):
        self.height = height
        self.width = width
        self.latent_dim = latent_dim
        # Three stride-2 convs downsample each spatial dimension by 8x.
        self.w = width // 8
        self.h = height // 8
        self.flattened_size = 256 * self.h * self.w
        # The latent vector is split into tokens so self-attention can mix them.
        self.num_tokens = 16
        self.dim_per_token = self.latent_dim // self.num_tokens

        # Encoder: conv stack followed by a linear projection to the latent.
        self.e1 = nn.Conv2d(input_channels, 64, kernel_size=3, stride=2, padding=1)
        self.e2 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        self.e3 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
        self.el = nn.Linear(self.flattened_size, self.latent_dim)

        # Single post-norm transformer block over the latent tokens.
        self.q = nn.Linear(self.dim_per_token, self.dim_per_token)
        self.k = nn.Linear(self.dim_per_token, self.dim_per_token)
        self.v = nn.Linear(self.dim_per_token, self.dim_per_token)
        self.norm1 = nn.LayerNorm(self.dim_per_token)
        ffn_dim = self.dim_per_token * 4
        self.ffn1 = nn.Linear(self.dim_per_token, ffn_dim)
        self.ffn2 = nn.Linear(ffn_dim, self.dim_per_token)
        self.norm2 = nn.LayerNorm(self.dim_per_token)

        # Decoder: linear expansion, then transposed convs mirroring the encoder.
        self.dl = nn.Linear(self.latent_dim, self.flattened_size)
        self.d1 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.d2 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.d3 = nn.ConvTranspose2d(64, input_channels, kernel_size=3, stride=2, padding=1, output_padding=1)

    def __call__(self, x: Tensor) -> Tensor:
        y, shape = self.encode(x)
        z = self.atten(y)
        return self.decode(z, shape)

    def encode(self, x: Tensor):
        x = self.e1(x).leakyrelu()
        x = self.e2(x).leakyrelu()
        x = self.e3(x).leakyrelu()
        b, c, h, w = x.shape
        x = x.reshape(shape=(b, c * h * w))
        z = self.el(x)
        # Reshape to multi-token form: (batch, num_tokens, dim_per_token).
        z = z.reshape(shape=(b, self.num_tokens, self.dim_per_token))
        return z, (c, h, w)

    def atten(self, x: Tensor):
        q = self.q(x)
        k = self.k(x)
        v = self.v(x)
        attn = q.scaled_dot_product_attention(k, v)
        x = self.norm1(x + attn)
        ffn = self.ffn1(x).relu()
        ffn = self.ffn2(ffn)
        x = self.norm2(x + ffn)
        return x

    def decode(self, z: Tensor, shape):
        # Collapse the tokens back into a single latent vector.
        z = z.reshape(shape=(z.shape[0], -1))
        x = self.dl(z).leakyrelu()
        # Restore the spatial shape recorded by encode().
        c, h, w = shape
        x = x.reshape(shape=(-1, c, h, w))
        x = self.d1(x).leakyrelu()
        x = self.d2(x).leakyrelu()
        x = self.d3(x).sigmoid()
        # The transposed convs exactly double each dimension, so for inputs
        # divisible by 8 this is a no-op; crop or pad covers the other cases.
        out_h, out_w = x.shape[2], x.shape[3]
        if out_h > self.height:
            x = x[:, :, :self.height, :]
        elif out_h < self.height:
            x = x.pad2d((0, 0, 0, self.height - out_h))  # (left, right, top, bottom)
        if out_w > self.width:
            x = x[:, :, :, :self.width]
        elif out_w < self.width:
            x = x.pad2d((0, self.width - out_w, 0, 0))
        return x
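

# --- Usage sketch (illustrative; not part of the original file) ---
# A minimal smoke test plus one reconstruction step, assuming tinygrad's
# nn.state.get_parameters and nn.optim.Adam APIs. The dummy batch, learning
# rate, and MSE loss below are stand-ins chosen for demonstration.
if __name__ == "__main__":
    from tinygrad.nn.optim import Adam
    from tinygrad.nn.state import get_parameters

    model = gen()  # defaults: 1 channel, 128x216 inputs, 1024-dim latent
    x = Tensor.rand(2, 1, 128, 216)  # dummy batch in [0, 1), matching the sigmoid output range

    # Forward pass: decode() crops/pads the output back to the input size.
    out = model(x)
    assert out.shape == x.shape

    # One optimizer step on an MSE reconstruction loss.
    opt = Adam(get_parameters(model), lr=1e-4)
    with Tensor.train():
        opt.zero_grad()
        loss = ((model(x) - x) ** 2).mean()
        loss.backward()
        opt.step()
    print(f"reconstruction loss: {loss.item():.6f}")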