Fix errors

This commit is contained in:
k
2026-01-07 02:13:08 -05:00
parent 007c96e91b
commit 7f25dff1d1
3 changed files with 20 additions and 17 deletions

View File

@@ -58,10 +58,10 @@ class Block:
return self
class Transformer():
def __init__(self,vocab_size,embed_size,n_heads,n_blocks,max_len):
def __init__(self,vocab_size,embed_size,n_heads,n_blocks,block_size):
self.tok_embed = nn.Embedding(vocab_size,embed_size)
self.pos_embed = nn.Embedding(block_size,embed_size)
self.pos_idx = Tensor.arange(max_len, requires_grad=False)
self.pos_idx = Tensor.arange(block_size, requires_grad=False)
self.blocks = [Block(embed_size,n_heads) for _ in range(n_blocks)]
self.norm = nn.RMSNorm(embed_size)