Implemented TransformerBlock

This commit is contained in:
k
2026-01-06 19:53:37 -05:00
parent 77aa0de0eb
commit 23f62c7e64

View File

@@ -42,12 +42,19 @@ class FeedForwardNetwork:
return self
class Block:
    """Pre-norm residual block: an attention sub-layer, then a feed-forward one.

    Each sub-layer receives a normalised copy of its input, and its output is
    added back onto the un-normalised input (a residual connection).
    """

    def __init__(self, embed_size, n_heads):
        """Construct the two sub-layers and their dedicated norm layers.

        Args:
            embed_size: Forwarded to ``MultiHeadAttention``,
                ``FeedForwardNetwork`` and both ``nn.RMSNorm`` constructors.
            n_heads: Forwarded to ``MultiHeadAttention``.
        """
        self.mha = MultiHeadAttention(embed_size, n_heads)
        self.ffn = FeedForwardNetwork(embed_size)
        # One norm per sub-layer so the two are never shared.
        self.mhaNorm = nn.RMSNorm(embed_size)
        self.ffnNorm = nn.RMSNorm(embed_size)

    def __call__(self, x):
        """Apply attention and then feed-forward, each with a residual add.

        Args:
            x: Input passed to the first norm layer; must support ``+`` with
                the sub-layer outputs.

        Returns:
            The result of ``x + mha(mhaNorm(x))`` followed by
            ``x + ffn(ffnNorm(x))`` on that intermediate value.
        """
        # Normalise first, run the sub-layer, then add the skip connection.
        x = x + self.mha(self.mhaNorm(x))
        x = x + self.ffn(self.ffnNorm(x))
        return x

    def cast(self, dtype):
        """Cast the attention and feed-forward sub-layers to ``dtype``.

        Each sub-layer is replaced by whatever its own ``cast`` returns.

        Args:
            dtype: Passed through unchanged to each sub-layer's ``cast``.

        Returns:
            This block, so calls can be chained.
        """
        # NOTE(review): mhaNorm / ffnNorm are not cast here — confirm that
        # keeping the norm layers at their original dtype is intentional.
        self.mha = self.mha.cast(dtype)
        self.ffn = self.ffn.cast(dtype)
        return self
class Transformer():
def __init__(self):