Implemented Transformer (decoder only)

This commit is contained in:
k
2026-01-06 21:26:24 -05:00
parent 23f62c7e64
commit 957aad2239

View File

@@ -57,9 +57,18 @@ class Block:
return self
class Transformer:
    """Decoder-only Transformer stack.

    The forward pass is: token embedding -> ``n_blocks`` ``Block`` layers
    applied in order -> RMSNorm -> bias-free linear projection back to
    ``vocab_size``.

    NOTE(review): any attention masking is presumably handled inside
    ``Block``; nothing in this class applies a mask -- confirm there.
    """

    def __init__(self, vocab_size: int, embed_size: int, n_heads: int, n_blocks: int):
        """Create the embedding, the block stack, the final norm and the head.

        Args:
            vocab_size: Number of rows in the embedding table and width of
                the output projection.
            embed_size: Width shared by the embedding, every block and the
                final norm.
            n_heads: Forwarded unchanged to each ``Block``.
            n_blocks: Number of ``Block`` instances to stack.
        """
        self.tok_embed = nn.Embedding(vocab_size, embed_size)
        self.blocks = [Block(embed_size, n_heads) for _ in range(n_blocks)]
        self.norm = nn.RMSNorm(embed_size)
        # The output head is a pure matrix multiply (no bias term).
        self.output = nn.Linear(embed_size, vocab_size, bias=False)

    def __call__(self, x):
        """Run the full forward pass.

        Args:
            x: Input passed straight to the embedding layer
                (presumably a tensor of token ids -- TODO confirm).

        Returns:
            The result of the output projection, i.e. one ``vocab_size``-wide
            vector per embedded position (presumably unnormalised logits).
        """
        hidden = self.tok_embed(x)
        # ``sequential`` is a method of the embedded tensor that takes the
        # list of blocks (assumed to chain them in order, as tinygrad's
        # ``Tensor.sequential`` does -- verify against the tensor library).
        hidden = hidden.sequential(self.blocks)
        hidden = self.norm(hidden)
        return self.output(hidden)

    def cast(self, dtype):
        """Convert the model's weights to ``dtype`` in place.

        The embedding weight and the output-projection weight are replaced
        with cast copies, and every block is replaced by the result of its
        own ``cast(dtype)``.

        NOTE(review): ``self.norm`` is not cast here. This may be deliberate
        (keeping the final norm in its original precision) -- confirm.

        Args:
            dtype: Target dtype, forwarded unchanged to each ``cast`` call.

        Returns:
            ``self``, so calls can be chained.
        """
        self.tok_embed.weight = self.tok_embed.weight.cast(dtype)
        self.blocks = [block.cast(dtype) for block in self.blocks]
        self.output.weight = self.output.weight.cast(dtype)
        return self