{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "7a4b3fe0-37b7-4dcc-928e-5d5981eb62bd",
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "%load_ext tensorboard\n",
    "import librosa\n",
    "import librosa.display  # provides specshow; imported explicitly instead of relying on lazy loading\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from pathlib import Path\n",
    "import IPython.display as ipd"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "25beb189-d3c5-4ba2-a31d-f692685291fc",
   "metadata": {},
   "source": [
    "# Prep\n",
    "\n",
    "Helpers for loading audio, building (input, target) mel-spectrogram pairs, and inspecting spectrograms by eye and by ear."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "8e45206f-7e1f-4a47-852f-39849a395a53",
   "metadata": {},
   "outputs": [],
   "source": [
    "SAMPLE_RATE = 22050  # Hz; load() resamples every file to this rate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "dc03a57f-8739-4f5e-a1c9-3db7a48e26de",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load(data_dir=\"./data/\", chunk_seconds=10):\n",
    "    \"\"\"\n",
    "    Load every *.mp3 file in `data_dir` as mono audio at SAMPLE_RATE and\n",
    "    split it into consecutive, non-overlapping chunks of `chunk_seconds`.\n",
    "\n",
    "    Only full-length chunks are returned: a trailing remainder shorter than\n",
    "    `chunk_seconds` is dropped. Files are visited in sorted order so the\n",
    "    resulting chunk list is the same on every filesystem.\n",
    "\n",
    "    Returns:\n",
    "        list of 1-D numpy arrays, each SAMPLE_RATE * chunk_seconds samples long.\n",
    "    \"\"\"\n",
    "    size = int(SAMPLE_RATE * chunk_seconds)\n",
    "    chunks = []\n",
    "    for file in sorted(Path(data_dir).glob(\"*.mp3\")):\n",
    "        y, _ = librosa.load(file, mono=True, sr=SAMPLE_RATE)\n",
    "        # The range stops before the last partial window, so every chunk has exactly `size` samples.\n",
    "        for start in range(0, len(y) - size + 1, size):\n",
    "            chunks.append(y[start:start + size])\n",
    "    return chunks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "f0f0c6d8-4cbf-46d3-a65d-28396fffc650",
   "metadata": {},
   "outputs": [],
   "source": [
    "def dataset(chunks, seconds=5):\n",
    "    \"\"\"\n",
    "    Build (input, target) mel-spectrogram pairs from audio chunks.\n",
    "\n",
    "    The first `seconds` of each chunk becomes the input and the following\n",
    "    `seconds` the target. Chunks too short to supply both full windows are skipped.\n",
    "\n",
    "    Returns:\n",
    "        (x, y): equal-length lists of power mel spectrograms produced by\n",
    "        librosa.feature.melspectrogram with its default parameters.\n",
    "    \"\"\"\n",
    "    size = int(SAMPLE_RATE * seconds)\n",
    "    x, y = [], []\n",
    "    for chunk in chunks:\n",
    "        head = chunk[:size]\n",
    "        tail = chunk[size:size * 2]\n",
    "        if len(head) == size and len(tail) == size:\n",
    "            x.append(librosa.feature.melspectrogram(y=head, sr=SAMPLE_RATE))\n",
    "            y.append(librosa.feature.melspectrogram(y=tail, sr=SAMPLE_RATE))\n",
    "    return x, y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "d953fafa-b119-4aa4-b17b-8606c0b366b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def showSpec(spec):\n",
    "    \"\"\"\n",
    "    Plot a power mel spectrogram on a dB scale.\n",
    "\n",
    "    dataset() stores power spectrograms, so convert with power_to_db (relative\n",
    "    to the peak) before plotting; otherwise the colour bar's dB label is wrong.\n",
    "    \"\"\"\n",
    "    spec_db = librosa.power_to_db(spec, ref=np.max)\n",
    "    fig, ax = plt.subplots(figsize=(10, 4))\n",
    "    img = librosa.display.specshow(spec_db, sr=SAMPLE_RATE,\n",
    "                                   x_axis='time', y_axis='mel',\n",
    "                                   cmap='viridis', ax=ax)\n",
    "    fig.colorbar(img, ax=ax, format='%+2.0f dB')\n",
    "    ax.set_title('Mel spectrogram')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "5a744602-e8ea-4c79-883a-1472c15df3ae",
   "metadata": {},
   "outputs": [],
   "source": [
    "def playSpec(spec):\n",
    "    \"\"\"\n",
    "    Invert a power mel spectrogram to audio (Griffin-Lim phase estimate),\n",
    "    plot the waveform and embed an audio player.\n",
    "\n",
    "    Raises ValueError if `spec` contains NaN/inf (e.g. a diverged model output);\n",
    "    without this check librosa fails deep inside griffinlim with\n",
    "    \"Audio buffer is not finite everywhere\".\n",
    "    \"\"\"\n",
    "    spec = np.asarray(spec)\n",
    "    if not np.isfinite(spec).all():\n",
    "        raise ValueError(\"spectrogram contains NaN or inf values; cannot invert to audio\")\n",
    "    S = librosa.feature.inverse.mel_to_stft(spec, sr=SAMPLE_RATE)\n",
    "    audio = librosa.griffinlim(S, n_iter=25, momentum=0.99)\n",
    "    # No pre-emphasis is applied when loading, so none is (re)applied here;\n",
    "    # doing so would only tilt the reconstruction toward high frequencies.\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(12, 4))\n",
    "    ax.plot(audio)\n",
    "    ax.set_title('waveform')\n",
    "    plt.show()\n",
    "\n",
    "    ipd.display(ipd.Audio(audio, rate=SAMPLE_RATE))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2488eaeb-a378-42dd-bd4a-b9290c445026",
   "metadata": {},
   "source": [
    "# Load Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "f411e54e-2a7d-4dfe-be90-457e2a9455a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "chunks = load()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "1c942a3f-8072-41b3-bb16-1ccb9812b505",
   "metadata": {},
   "outputs": [],
   "source": [
    "x, y = dataset(chunks)\n",
    "assert len(x) > 0, \"dataset() produced no samples; ./data/ needs mp3 files of at least 10 s\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "39241f9b-8cc0-40cf-96dc-1be6d3ff2fb7",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "image/png": "",
      "text/plain": [
       "
" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ],
   "source": [
    "# Verify the loaded data: show and play one input spectrogram.\n",
    "# The index is clamped so the cell also runs on small datasets.\n",
    "sample = x[min(420, len(x) - 1)]\n",
    "showSpec(sample)\n",
    "playSpec(sample)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "584543d0-89f7-4ae8-aedb-a01e0e110784",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "216"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(x[0][0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "03accbb9-48b7-40d9-85be-9ed6a92d2e86",
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "from tinygrad import TinyJit, Device, Tensor, nn\n",
    "from tinygrad.nn.state import safe_save, get_state_dict\n",
    "from model import Model\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "9aa99587-01e0-4156-84a1-15ba0c929aec",
   "metadata": {},
   "outputs": [],
   "source": [
    "x_np = np.array(x)[:, np.newaxis, :, :]  # Shape: (N, 1, 128, 216)\n",
    "y_np = np.array(y)[:, np.newaxis, :, :]  # Shape: (N, 1, 128, 216)\n",
    "\n",
    "# Training parameters\n",
    "num_epochs = 15\n",
    "batch_size = 20\n",
    "num_samples = len(x_np)\n",
    "num_batches = num_samples // batch_size  # any trailing partial batch is dropped"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "bcf1bec0-be3f-411b-8324-7877f0ead016",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = Model()\n",
    "optimizer = nn.optim.Adam(nn.state.get_parameters(model), lr=1e-3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "ea3a98ca-756b-4658-9c4a-3cbc36c61cdc",
   "metadata": {},
   "outputs": [],
   "source": [
    "@TinyJit\n",
    "def jit_step(X: Tensor, Y: Tensor, show=False) -> Tensor:\n",
    "    \"\"\"\n",
    "    Run one optimisation step on a batch and return the realised loss.\n",
    "\n",
    "    `show` is accepted only so existing callers keep working; it is ignored.\n",
    "    Python side effects (plotting, audio) inside a TinyJit function only run\n",
    "    during its first warm-up/capture calls, not when captured kernels replay.\n",
    "    \"\"\"\n",
    "    Tensor.training = True  # the tinygrad optimizer requires training mode\n",
    "    optimizer.zero_grad()\n",
    "    _, loss = model(X, Y)  # presumably (predicted spectrogram, loss); only the loss is used\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    return loss.realize()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id":
"bc6dfda7-4249-4e22-a931-85deef0e5347",
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Seeded shuffling so the epoch order is reproducible across runs.\n",
    "rng = np.random.default_rng(0)\n",
    "assert num_batches > 0, f\"need at least batch_size={batch_size} samples, got {num_samples}\"\n",
    "\n",
    "for epoch in range(num_epochs):\n",
    "    epoch_loss = 0.0\n",
    "    indices = rng.permutation(num_samples)\n",
    "\n",
    "    for batch_idx in range(num_batches):\n",
    "        start = batch_idx * batch_size\n",
    "        end = start + batch_size\n",
    "        batch_x = Tensor(x_np[indices[start:end]])\n",
    "        batch_y = Tensor(y_np[indices[start:end]])\n",
    "\n",
    "        loss = jit_step(batch_x, batch_y, (batch_idx % 4 == 0))\n",
    "        loss_value = loss.item()\n",
    "        # Fail fast on divergence instead of silently averaging NaN into the epoch loss.\n",
    "        if not np.isfinite(loss_value):\n",
    "            raise FloatingPointError(f\"non-finite loss {loss_value} at epoch {epoch+1}, batch {batch_idx}\")\n",
    "        epoch_loss += loss_value\n",
    "\n",
    "    avg_epoch_loss = epoch_loss / num_batches\n",
    "    print(f\"Epoch {epoch+1}/{num_epochs}, Loss: {avg_epoch_loss:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "753ebcea-310b-43a0-a8d2-67994328d74a",
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "nn.state.safe_save(nn.state.get_state_dict(model), \"vae_weights.safetensors\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
"mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.11" } }, "nbformat": 4, "nbformat_minor": 5 }