diff --git a/model.py b/model.py index 8df1c46..e484430 100644 --- a/model.py +++ b/model.py @@ -1,12 +1,13 @@ from tinygrad import Tensor,nn,TinyJit class MultiHeadAttention: - def __init__(self,embed_size,n_heads): + def __init__(self,embed_size,n_heads,lin): assert embed_size % n_heads == 0 self.head_size = embed_size//n_heads self.n_heads = n_heads self.qkv = nn.Linear(embed_size, embed_size*3,bias=False) self.projection = nn.Linear(embed_size, embed_size,bias=False) + self.lin = lin def __call__(self,x): B,T,C=x.shape @@ -15,10 +16,16 @@ class MultiHeadAttention: k = k.view(B, T, self.n_heads, self.head_size).transpose(1, 2) v = v.view(B, T, self.n_heads, self.head_size).transpose(1, 2) #B H T S - #TODO attention free transformer - out = q.scaled_dot_product_attention(k,v,is_causal=True,dropout_p=0.01) + out = None + if self.lin: + q = q.sigmoid() + k = k.sigmoid() + out = ((q*k).exp()/(q*k)) * v + else: + out = q.scaled_dot_product_attention(k,v,is_causal=True) out = out.transpose(1,2).view(B,T,C) + return self.projection(out) def cast(self,dtype): self.qkv.weight = self.qkv.weight.cast(dtype) @@ -43,8 +50,8 @@ class FeedForwardNetwork: return self class Block: - def __init__(self,embed_size,n_heads): - self.mha = MultiHeadAttention(embed_size,n_heads) + def __init__(self,embed_size,n_heads,lin): + self.mha = MultiHeadAttention(embed_size,n_heads,lin) self.ffn = FeedForwardNetwork(embed_size) self.mhaNorm = nn.RMSNorm(embed_size) self.ffnNorm = nn.RMSNorm(embed_size) @@ -61,9 +68,9 @@ class Transformer(): def __init__(self,vocab_size,embed_size,n_heads,n_blocks,block_size): self.tok_embed = nn.Embedding(vocab_size,embed_size) self.pos_embed = nn.Embedding(block_size,embed_size) - self.pos_idx = Tensor.arange(block_size, requires_grad=False) + self.pos_idx = Tensor.arange(block_size, requires_grad=False).sin() - self.blocks = [Block(embed_size,n_heads) for _ in range(n_blocks)] + self.blocks = [Block(embed_size,n_heads,i%4==0) for i in range(n_blocks)] self.norm = nn.RMSNorm(embed_size) self.output = nn.Linear(embed_size,vocab_size,bias=False) def __call__(self,x):