added transformer block in latent space
parent b076a0d123
commit 6e0b3882bc
model.py (35 changed lines)
@@ -10,6 +10,9 @@ class gen:
         self.h = height // 8
         self.flattened_size = 256 * self.h * self.w
 
+        self.num_tokens = 16
+        self.dim_per_token = self.latent_dim // self.num_tokens
+
 
         self.e1 = nn.Conv2d(input_channels, 64, kernel_size=3, stride=2, padding=1)
         self.e2 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
@@ -18,6 +21,15 @@ class gen:
 
         self.el = nn.Linear(self.flattened_size, self.latent_dim)
 
+        self.q = nn.Linear(self.dim_per_token,self.dim_per_token)
+        self.k = nn.Linear(self.dim_per_token,self.dim_per_token)
+        self.v = nn.Linear(self.dim_per_token,self.dim_per_token)
+        self.norm1 = nn.LayerNorm(self.dim_per_token)
+
+        ffn_dim = self.dim_per_token * 4
+        self.ffn1 = nn.Linear(self.dim_per_token, ffn_dim)
+        self.ffn2 = nn.Linear(ffn_dim, self.dim_per_token)
+        self.norm2 = nn.LayerNorm(self.dim_per_token)
 
         self.dl = nn.Linear(self.latent_dim, self.flattened_size)
 
@@ -27,7 +39,7 @@ class gen:
 
     def __call__(self, x: Tensor) -> Tensor:
         y, shape = self.encode(x)
-        z = y#self.atten(y)
+        z = self.atten(y)
         return self.decode(z, shape)
 
     def encode(self, x: Tensor):
@@ -37,19 +49,28 @@ class gen:
         b, c, h, w = x.shape
 
         flattened_size = c * h * w
 
 
         x = x.reshape(shape=(b, flattened_size))
         z = self.el(x)
 
+        # reshape to multi-token: (batch, num_tokens, dim_per_token)
+        z = z.reshape(shape=(b, self.num_tokens, self.dim_per_token))
         return z, (c, h, w)
 
     def atten(self, x: Tensor):
-        q = self.q(x).relu()
-        k = self.k(x).relu()
-        v = self.v(x).relu()
-        return q.scaled_dot_product_attention(k,v)
+        q = self.q(x)
+        k = self.k(x)
+        v = self.v(x)
+        attn = q.scaled_dot_product_attention(k, v)
+        x = self.norm1(x+attn)
+
+        ffn = self.ffn1(x).relu()
+        ffn = self.ffn2(ffn)
+        x = self.norm2(x+ffn)
+
+        return x
 
     def decode(self, z: Tensor, shape):
+        z = z.reshape(shape=(z.shape[0], -1))
         x = self.dl(z).leakyrelu()
         x = x.reshape(shape=(-1, 256, self.h, self.w))
         x = self.d1(x).leakyrelu()
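
Note: below is a minimal standalone sketch of the single-head block this commit inserts between encode() and decode(), written against the tinygrad-style API the file already uses (nn.Linear, nn.LayerNorm, Tensor.scaled_dot_product_attention). The class name LatentBlock and the shapes in the check at the bottom are illustrative only and are not part of model.py.

from tinygrad import Tensor
from tinygrad.nn import Linear, LayerNorm

class LatentBlock:
    # post-norm transformer block over (batch, num_tokens, dim_per_token) latents,
    # mirroring the q/k/v, norm and FFN layers added in __init__ and used in atten()
    def __init__(self, dim_per_token: int):
        self.q = Linear(dim_per_token, dim_per_token)
        self.k = Linear(dim_per_token, dim_per_token)
        self.v = Linear(dim_per_token, dim_per_token)
        self.norm1 = LayerNorm(dim_per_token)
        ffn_dim = dim_per_token * 4
        self.ffn1 = Linear(dim_per_token, ffn_dim)
        self.ffn2 = Linear(ffn_dim, dim_per_token)
        self.norm2 = LayerNorm(dim_per_token)

    def __call__(self, x: Tensor) -> Tensor:
        # single-head self-attention with a residual, normalized after the add
        attn = self.q(x).scaled_dot_product_attention(self.k(x), self.v(x))
        x = self.norm1(x + attn)
        # position-wise feed-forward, also with a post-norm residual
        x = self.norm2(x + self.ffn2(self.ffn1(x).relu()))
        return x

# illustrative shape check: a latent split into 16 tokens of 64 dims each
num_tokens, dim_per_token = 16, 64
block = LatentBlock(dim_per_token)
z = Tensor.randn(4, num_tokens, dim_per_token)
assert block(z).shape == (4, num_tokens, dim_per_token)

Note that the reshape added in encode() assumes the latent vector has exactly num_tokens * dim_per_token elements, so latent_dim needs to be divisible by self.num_tokens = 16.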