diff --git a/train.py b/train.py index c0c7f3e..80cfb27 100644 --- a/train.py +++ b/train.py @@ -12,20 +12,20 @@ import sys hypr = { "embed_size": 768, - "n_heads": 12, + "n_heads": 8, "n_blocks": 12, "block_size": 512, "batch_size": 8, "starting_lr": 6e-4, "minimum_lr": 6e-5, - "warmup": 1_000, - "steps": 20_000, - "encoding": "gpt2", + "warmup": 5_000, + "steps": 535_000, + "encoding": "TinyLlama/TinyLlama_v1.1", "dataset": "HuggingFaceTB/smollm-corpus", "subset": "cosmopedia-v2", "chat_dataset": "HuggingFaceTB/smoltalk", "chat_subset": "all", - "half": False, + "half": True, } print(Device.DEFAULT)