InfiniteMusic/model.py
2025-11-10 22:34:17 -05:00

74 lines
2.3 KiB
Python

from tinygrad import Tensor, nn
class gen:
    """Convolutional autoencoder with an attention step on the latent code.

    The encoder is two stride-2 convolutions followed by a linear projection
    to ``latent_dim``; the decoder mirrors it with a linear layer and two
    stride-2 transposed convolutions, ending in a sigmoid. The output is
    cropped (or zero-padded) to the configured ``height`` x ``width``.

    Args:
        input_channels: Channel count of the input (and of the output).
        height: Height of the tensors the model is built for.
        width: Width of the tensors the model is built for.
        latent_dim: Size of the latent vector produced by the encoder.
    """

    @staticmethod
    def _encoded_size(size: int, layers: int = 2, kernel_size: int = 3,
                      stride: int = 2, padding: int = 1) -> int:
        """Return the length of one spatial axis after the encoder convs.

        Applies the standard convolution output-size formula
        ``(n + 2 * padding - kernel_size) // stride + 1`` once per layer.
        The defaults match the settings of ``e1`` and ``e2``.
        """
        for _ in range(layers):
            size = (size + 2 * padding - kernel_size) // stride + 1
        return size

    def __init__(self, input_channels=1, height=128, width=431, latent_dim=64):
        self.height = height
        self.width = width
        self.latent_dim = latent_dim

        # Spatial size of the feature map after both stride-2 convolutions.
        # Derived from the constructor arguments rather than hard-coded; for
        # the default 128 x 431 input this gives 32 x 108 (note that
        # 431 // 4 would give 107, which is one short).
        self.h = self._encoded_size(height)
        self.w = self._encoded_size(width)
        self.flattened_size = 128 * self.h * self.w

        # Encoder
        self.e1 = nn.Conv2d(input_channels, 64, kernel_size=3, stride=2, padding=1)
        self.e2 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        self.el = nn.Linear(self.flattened_size, self.latent_dim)

        # Attention projections applied to the latent vector
        self.q = nn.Linear(self.latent_dim, self.latent_dim)
        self.k = nn.Linear(self.latent_dim, self.latent_dim)
        self.v = nn.Linear(self.latent_dim, self.latent_dim)

        # Decoder
        self.dl = nn.Linear(self.latent_dim, self.flattened_size)
        self.d1 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.d2 = nn.ConvTranspose2d(64, input_channels, kernel_size=3, stride=2, padding=1, output_padding=1)

    def __call__(self, x: "Tensor") -> "Tensor":
        """Run encode -> attention -> decode and return the reconstruction."""
        y, shape = self.encode(x)
        z = self.atten(y)
        return self.decode(z, shape)

    def encode(self, x: "Tensor") -> "tuple[Tensor, tuple]":
        """Encode ``x`` into a latent vector.

        Returns:
            ``(z, (c, h, w))`` where ``z`` is the output of the latent linear
            layer and ``(c, h, w)`` is the shape of the convolutional feature
            map just before it was flattened.
        """
        x = self.e1(x).leakyrelu()
        x = self.e2(x).leakyrelu()
        b, c, h, w = x.shape
        flattened_size = c * h * w
        x = x.reshape(shape=(b, flattened_size))
        z = self.el(x)
        return z, (c, h, w)

    def atten(self, x: "Tensor") -> "Tensor":
        """Apply ReLU-activated q/k/v projections and dot-product attention.

        NOTE(review): ``encode`` returns a 2-D ``(batch, latent_dim)`` tensor,
        so this attention appears to mix samples across the batch rather than
        positions within one sample -- confirm this is intended.
        """
        q = self.q(x).relu()
        k = self.k(x).relu()
        v = self.v(x).relu()
        return q.scaled_dot_product_attention(k, v)

    def decode(self, z: "Tensor", shape) -> "Tensor":
        """Decode latent ``z`` into a ``height`` x ``width`` output.

        Args:
            z: Latent tensor whose last dimension is ``latent_dim``.
            shape: Feature-map shape returned by ``encode``. Currently unused;
                the reshape relies on the sizes computed in ``__init__``. The
                parameter is kept so existing callers continue to work.
        """
        x = self.dl(z).leakyrelu()
        x = x.reshape(shape=(-1, 128, self.h, self.w))
        x = self.d1(x).leakyrelu()
        x = self.d2(x).sigmoid()

        # Crop or pad to match the configured size. With the layer settings
        # above each transposed convolution should double the spatial size,
        # so the output is normally >= the target and only cropping happens;
        # the padding path is kept as a safety net.
        out_h, out_w = x.shape[2], x.shape[3]
        if out_h > self.height or out_w > self.width:
            x = x[:, :, :min(out_h, self.height), :min(out_w, self.width)]

        pad_h = max(self.height - out_h, 0)
        pad_w = max(self.width - out_w, 0)
        if pad_h or pad_w:
            # Per-axis (before, after) pairs: zero-pad the bottom and right.
            x = x.pad(((0, 0), (0, 0), (0, pad_h), (0, pad_w)))
        return x