from tinygrad import Tensor, nn, TinyJit


class MultiHeadAttention:
    def __init__(self, embed_size, n_heads):
        assert embed_size % n_heads == 0
        self.head_size = embed_size // n_heads
        self.n_heads = n_heads
        self.qkv = nn.Linear(embed_size, embed_size * 3, bias=False)
        self.projection = nn.Linear(embed_size, embed_size, bias=False)

    def __call__(self, x):
        B, T, C = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # reshape to (B, n_heads, T, head_size) for attention
        q = q.view(B, T, self.n_heads, self.head_size).transpose(1, 2)
        k = k.view(B, T, self.n_heads, self.head_size).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.head_size).transpose(1, 2)
        # TODO: attention free transformer
        out = q.scaled_dot_product_attention(k, v, is_causal=True, dropout_p=0.01)
        out = out.transpose(1, 2).view(B, T, C)
        return self.projection(out)

    def cast(self, dtype):
        self.qkv.weight = self.qkv.weight.cast(dtype)
        self.projection.weight = self.projection.weight.cast(dtype)
        return self


class FeedForwardNetwork:
    # SwiGLU feed-forward block; ratio 8/3 keeps the parameter count close to a 4x MLP
    def __init__(self, embed_size, ratio=8 / 3):
        hidden_size = int(embed_size * ratio)
        self.norm = nn.RMSNorm(embed_size)
        self.gate = nn.Linear(embed_size, hidden_size, bias=False)
        self.up = nn.Linear(embed_size, hidden_size, bias=False)
        self.down = nn.Linear(hidden_size, embed_size, bias=False)

    def __call__(self, x):
        x = self.norm(x)
        return self.down(self.gate(x).silu() * self.up(x)).dropout(0.01)

    def cast(self, dtype):
        self.gate.weight = self.gate.weight.cast(dtype)
        self.up.weight = self.up.weight.cast(dtype)
        self.down.weight = self.down.weight.cast(dtype)
        return self


class Block:
    def __init__(self, embed_size, n_heads):
        self.mha = MultiHeadAttention(embed_size, n_heads)
        self.ffn = FeedForwardNetwork(embed_size)
        self.mhaNorm = nn.RMSNorm(embed_size)
        self.ffnNorm = nn.RMSNorm(embed_size)

    def __call__(self, x):
        # pre-norm residual connections
        x = x + self.mha(self.mhaNorm(x))
        x = x + self.ffn(self.ffnNorm(x))
        return x

    def cast(self, dtype):
        self.mha = self.mha.cast(dtype)
        self.ffn = self.ffn.cast(dtype)
        return self


class Transformer:
    def __init__(self, vocab_size, embed_size, n_heads, n_blocks, max_len):
        self.tok_embed = nn.Embedding(vocab_size, embed_size)
        self.pos_embed = nn.Embedding(max_len, embed_size)  # learned positions up to max_len
        self.pos_idx = Tensor.arange(max_len, requires_grad=False)
        self.blocks = [Block(embed_size, n_heads) for _ in range(n_blocks)]
        self.norm = nn.RMSNorm(embed_size)
        self.output = nn.Linear(embed_size, vocab_size, bias=False)

    def __call__(self, x):
        B, T = x.shape
        pos_embeds = self.pos_embed(self.pos_idx[:T])
        x = self.tok_embed(x) + pos_embeds
        x = x.sequential(self.blocks)
        x = self.norm(x)
        return self.output(x)

    def cast(self, dtype):
        self.tok_embed.weight = self.tok_embed.weight.cast(dtype)
        self.blocks = [b.cast(dtype) for b in self.blocks]
        self.output.weight = self.output.weight.cast(dtype)
        return self
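

# --- Minimal usage sketch (not part of the original module) ---
# The hyperparameters, batch size, and sequence length below are illustrative
# assumptions, shown only to demonstrate how the Transformer is constructed,
# optionally cast to fp16 via cast(), and run on a batch of token ids.
if __name__ == "__main__":
    from tinygrad import dtypes

    model = Transformer(vocab_size=256, embed_size=128, n_heads=4, n_blocks=2, max_len=64)
    model = model.cast(dtypes.half)           # casts linear/embedding weights; RMSNorm weights are left untouched
    tokens = Tensor.randint(2, 16, high=256)  # (batch=2, seq_len=16) of token ids
    logits = model(tokens)
    print(logits.shape)                       # expected: (2, 16, 256)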