added dropout to ffn
This commit is contained in:
parent
957aad2239
commit
3b590b3ce7
1 changed file with 2 additions and 1 deletion
3
model.py
3
model.py
|
|
@@ -15,6 +15,7 @@ class MultiHeadAttention:
|
||||||
k = k.view(B, T, self.n_heads, self.head_size).transpose(1, 2)
|
k = k.view(B, T, self.n_heads, self.head_size).transpose(1, 2)
|
||||||
v = v.view(B, T, self.n_heads, self.head_size).transpose(1, 2)
|
v = v.view(B, T, self.n_heads, self.head_size).transpose(1, 2)
|
||||||
#B H T S
|
#B H T S
|
||||||
|
#TODO attention free transformer
|
||||||
|
|
||||||
out = q.scaled_dot_product_attention(k,v,is_causal=True,dropout_p=0.01)
|
out = q.scaled_dot_product_attention(k,v,is_causal=True,dropout_p=0.01)
|
||||||
out = out.transpose(1,2).view(B,T,C)
|
out = out.transpose(1,2).view(B,T,C)
|
||||||
|
|
@@ -34,7 +35,7 @@ class FeedForwardNetwork:
|
||||||
self.down = nn.Linear(hidden_size,embed_size,bias=False)
|
self.down = nn.Linear(hidden_size,embed_size,bias=False)
|
||||||
def __call__(self,x):
|
def __call__(self,x):
|
||||||
x = self.norm(x)
|
x = self.norm(x)
|
||||||
return self.down(self.gate(x).silu() * self.up(x))
|
return self.down(self.gate(x).silu() * self.up(x)).dropout(0.01)
|
||||||
def cast(self,dtype):
|
def cast(self,dtype):
|
||||||
self.gate.weight = self.gate.weight.cast(dtype)
|
self.gate.weight = self.gate.weight.cast(dtype)
|
||||||
self.up.weight = self.up.weight.cast(dtype)
|
self.up.weight = self.up.weight.cast(dtype)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue