fixed cast in ffn
This commit is contained in:
19
model.py
19
model.py
@@ -10,19 +10,20 @@ class MultiHeadAttention:
|
|||||||
|
|
||||||
|
|
||||||
class FeedForwardNetwork:
|
class FeedForwardNetwork:
|
||||||
def __init__(self,embeding_size,ratio=(8/3)):
|
def __init__(self,embed_size,ratio=(8/3)):
|
||||||
hidden_size = int(embeding_size*ratio)
|
hidden_size = int(embed_size*ratio)
|
||||||
self.norm = nn.RMSNorm(embeding_size)
|
self.norm = nn.RMSNorm(embed_size)
|
||||||
self.gate = nn.Linear(embeding_size,hidden_size,bias=False)
|
self.gate = nn.Linear(embed_size,hidden_size,bias=False)
|
||||||
self.up = nn.Linear(embeding_size, hidden_size,bias=False)
|
self.up = nn.Linear(embed_size, hidden_size,bias=False)
|
||||||
self.down = nn.Linear(hidden_size,embeding_size,bias=False)
|
self.down = nn.Linear(hidden_size,embed_size,bias=False)
|
||||||
def __call__(self,x):
|
def __call__(self,x):
|
||||||
x = self.norm(x)
|
x = self.norm(x)
|
||||||
return self.down(self.gate(x).silu() * self.up(x))
|
return self.down(self.gate(x).silu() * self.up(x))
|
||||||
def cast(self,dtype):
|
def cast(self,dtype):
|
||||||
self.gate.weight = gate.weight.cast(dtype)
|
self.gate.weight = self.gate.weight.cast(dtype)
|
||||||
self.up.weight = up.weight.cast(dtype)
|
self.up.weight = self.up.weight.cast(dtype)
|
||||||
self.down.weight = down.weight.cast(dtype)
|
self.down.weight = self.down.weight.cast(dtype)
|
||||||
|
return self
|
||||||
|
|
||||||
class Block:
|
class Block:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user