added Positional encodings

This commit is contained in:
k
2026-01-06 21:38:12 -05:00
parent 3b590b3ce7
commit 478010c8cc


@@ -58,13 +58,18 @@ class Block:
         return self
 class Transformer():
-    def __init__(self,vocab_size,embed_size,n_heads,n_blocks):
+    def __init__(self,vocab_size,embed_size,n_heads,n_blocks,max_len):
         self.tok_embed = nn.Embedding(vocab_size,embed_size)
+        self.pos_embed = nn.Embedding(max_len,embed_size)
+        self.pos_idx = Tensor.arange(max_len, requires_grad=False)
         self.blocks = [Block(embed_size,n_heads) for _ in range(n_blocks)]
         self.norm = nn.RMSNorm(embed_size)
         self.output = nn.Linear(embed_size,vocab_size,bias=False)
     def __call__(self,x):
-        x = self.tok_embed(x)
+        B,T = x.shape
+        pos_embeds = self.pos_embed(self.pos_idx[:T])
+        x = self.tok_embed(x) + pos_embeds
         x = x.sequential(self.blocks)
         x = self.norm(x)
         return self.output(x)
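
For reference, a minimal standalone sketch of the same learned positional embedding pattern, assuming the tinygrad Tensor/nn API used in the diff; the DemoEmbed class name and the sizes in the usage example are illustrative, not part of this repo.

from tinygrad import Tensor, nn

class DemoEmbed:
    # token embedding plus learned positional embedding, mirroring the diff above
    def __init__(self, vocab_size, embed_size, max_len):
        self.tok_embed = nn.Embedding(vocab_size, embed_size)       # one row per token id
        self.pos_embed = nn.Embedding(max_len, embed_size)          # one row per position 0..max_len-1
        self.pos_idx = Tensor.arange(max_len, requires_grad=False)  # fixed position indices

    def __call__(self, x):                             # x: (B, T) integer token ids, T <= max_len
        B, T = x.shape
        pos_embeds = self.pos_embed(self.pos_idx[:T])  # (T, embed_size)
        return self.tok_embed(x) + pos_embeds          # broadcasts to (B, T, embed_size)

if __name__ == "__main__":
    emb = DemoEmbed(vocab_size=100, embed_size=32, max_len=16)  # illustrative sizes
    ids = Tensor([[3, 1, 4, 1, 5]])                             # batch of 1, sequence length 5
    print(emb(ids).shape)                                       # -> (1, 5, 32)

Slicing pos_idx[:T] reuses one precomputed index tensor for any sequence length up to max_len, so shorter inputs simply take the first T rows of the positional table.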