{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "title",
   "metadata": {},
   "source": [
    "# Train `model.Gen`\n",
    "\n",
    "Loads paired `(x, y)` samples via `data.detaset(data.load())`, trains `model.Gen` with AdamW inside a `TinyJit`-compiled step, and saves the weights to a safetensors checkpoint."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import numpy as np\n",
    "from tinygrad import Tensor, TinyJit, nn\n",
    "from tinygrad.nn.state import get_state_dict, safe_save\n",
    "\n",
    "import data\n",
    "import model\n",
    "import show"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "config",
   "metadata": {},
   "outputs": [],
   "source": [
    "SEED = 0\n",
    "EPOCHS = 12\n",
    "BATCH_SIZE = 32\n",
    "LEARNING_RATE = 1e-4\n",
    "# Per-sample shape after the channel axis is added; every JIT call sees (BATCH_SIZE, *SAMPLE_SHAPE).\n",
    "SAMPLE_SHAPE = (1, 128, 216)\n",
    "PREVIEW_INDEX = 420  # sample shown and played in the preview cell\n",
    "CHECKPOINT_PATH = \"music.safetensors\"\n",
    "\n",
    "# Seed numpy (batch shuffling) and tinygrad (presumably used by model.Gen for weight init).\n",
    "rng = np.random.default_rng(SEED)\n",
    "Tensor.manual_seed(SEED)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-data",
   "metadata": {},
   "source": [
    "## Load data\n",
    "\n",
    "`x` and `y` are paired by position: sample `i` of `x` is trained against sample `i` of `y`, so both must have the same length."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1",
   "metadata": {},
   "outputs": [],
   "source": [
    "x, y = data.detaset(data.load())\n",
    "assert len(x) == len(y), f\"inputs and targets must be paired one-to-one: {len(x)} vs {len(y)}\"\n",
    "len(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2",
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "sample = x[PREVIEW_INDEX]\n",
    "show.showSpec(sample)\n",
    "show.playSpec(sample)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-model",
   "metadata": {},
   "source": [
    "## Model and optimiser"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3",
   "metadata": {},
   "outputs": [],
   "source": [
    "gen = model.Gen()\n",
    "optimizer = nn.optim.AdamW(nn.state.get_parameters(gen), lr=LEARNING_RATE)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-train",
   "metadata": {},
   "source": [
    "## Training\n",
    "\n",
    "`jit_step` is compiled with `TinyJit`, which replays kernels captured for fixed input shapes. The loop therefore feeds only full `BATCH_SIZE` batches and drops the remainder each epoch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4",
   "metadata": {},
   "outputs": [],
   "source": [
    "@TinyJit\n",
    "def jit_step(X: Tensor, Y: Tensor, epoch) -> Tensor:\n",
    "    \"\"\"Run one AdamW step of `gen` on the batch (X, Y) and return the realized loss tensor.\"\"\"\n",
    "    # NOTE(review): TinyJit only treats Tensor arguments as inputs, so a Python int such as\n",
    "    # `epoch` is presumably frozen at capture time. If model.Gen depends on it, pass it as a\n",
    "    # Tensor or call the step without the JIT -- TODO confirm against model.Gen.\n",
    "    Tensor.training = True\n",
    "    optimizer.zero_grad()\n",
    "    # gen returns a pair whose second element is the scalar loss to backpropagate.\n",
    "    _, loss = gen(X, Y, epoch)\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    return loss.realize()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack the per-sample arrays and insert a channel axis at position 1 -> (N, 1, ...).\n",
    "x_np = np.expand_dims(np.stack(x), axis=1)\n",
    "y_np = np.expand_dims(np.stack(y), axis=1)\n",
    "assert x_np.shape == y_np.shape, f\"input/target shape mismatch: {x_np.shape} vs {y_np.shape}\"\n",
    "x_np.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6",
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Fail loudly instead of silently skipping every batch when the data has an unexpected shape.\n",
    "assert x_np.shape[1:] == SAMPLE_SHAPE, f\"expected samples of shape {SAMPLE_SHAPE}, got {x_np.shape[1:]}\"\n",
    "size = len(x_np)\n",
    "# Only full batches keep the JIT input shape fixed; the trailing size % BATCH_SIZE samples are skipped.\n",
    "n_batches = size // BATCH_SIZE\n",
    "assert n_batches > 0, f\"need at least {BATCH_SIZE} samples, got {size}\"\n",
    "\n",
    "for e in range(EPOCHS):\n",
    "    print(f\"\\n--- Starting Epoch {e} ---\\n\")\n",
    "    # Shuffle through an index permutation so x_np / y_np stay untouched and aligned.\n",
    "    permutation = rng.permutation(size)\n",
    "    total_loss = 0.0\n",
    "    for b in range(n_batches):\n",
    "        idx = permutation[b * BATCH_SIZE:(b + 1) * BATCH_SIZE]\n",
    "        tx, ty = Tensor(x_np[idx]), Tensor(y_np[idx])\n",
    "        total_loss += jit_step(tx, ty, e).item()\n",
    "    # Mean over the batches actually trained on this epoch.\n",
    "    epoch_loss = total_loss / n_batches\n",
    "    print(f\"loss of {epoch_loss}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-save",
   "metadata": {},
   "source": [
    "## Save weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7",
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "safe_save(get_state_dict(gen), CHECKPOINT_PATH)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}