Implimented TransformerBlock
This commit is contained in:
parent
77aa0de0eb
commit
23f62c7e64
1 changed files with 13 additions and 6 deletions
19
model.py
19
model.py
|
|
@ -42,12 +42,19 @@ class FeedForwardNetwork:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
class Block:
|
class Block:
|
||||||
def __init__(self):
|
def __init__(self,embed_size,n_heads):
|
||||||
pass #TODO
|
self.mha = MultiHeadAttention(embed_size,n_heads)
|
||||||
def __call__(self):
|
self.ffn = FeedForwardNetwork(embed_size)
|
||||||
pass #TODO
|
self.mhaNorm = nn.RMSNorm(embed_size)
|
||||||
def cast(self):
|
self.ffnNorm = nn.RMSNorm(embed_size)
|
||||||
pass #TODO
|
def __call__(self,x):
|
||||||
|
x = x + self.mha(self.mhaNorm(x))
|
||||||
|
x = x + self.ffn(self.ffnNorm(x))
|
||||||
|
return x
|
||||||
|
def cast(self,dtype):
|
||||||
|
self.mha = self.mha.cast(dtype)
|
||||||
|
self.ffn = self.ffn.cast(dtype)
|
||||||
|
return self
|
||||||
|
|
||||||
class Transformer():
|
class Transformer():
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue