cleanup/rewrite

This commit is contained in:
k 2025-07-31 14:23:34 -04:00
parent f1abc67462
commit 689e4df4aa
5 changed files with 277 additions and 482 deletions

65
data.py Normal file
View File

@ -0,0 +1,65 @@
import librosa
import numpy as np
from pathlib import Path
from multiprocessing import Pool, cpu_count
SAMPLE_RATE = 22050
def process_file(file_path):
"""
Load 10 second chunks single song.
"""
y, sr = librosa.load(file_path, mono=True, sr=SAMPLE_RATE)
size = int(SAMPLE_RATE * 10)
sample_len = len(y)
file_chunks = []
for start_pos in range(0, sample_len, size):
end = start_pos + size
if end <= sample_len:
chunk = y[start_pos:end]
file_chunks.append(chunk)
return file_chunks
def load():
"""
Load 10 second chunks of songs.
"""
audio = []
files = list(Path("./data/").glob("*.mp3"))
with Pool(cpu_count()) as pool:
chunk_list = pool.map(process_file, files)
for l in chunk_list:
audio.extend(l)
return audio
def audio_split(audio):
"""
Split 10 seconds of audio to 2 5 second clips
"""
size = int(SAMPLE_RATE*5)
x = audio[:size]
y = audio[size:size*2]
x = librosa.feature.melspectrogram(y=x, sr=SAMPLE_RATE)
y = librosa.feature.melspectrogram(y=y, sr=SAMPLE_RATE)
ma,mi = x.max(), x.min()
x = (x - mi) / (ma - mi)
ma,mi = y.max(), y.min()
y = (y - mi) / (ma - mi)
return x,y
def detaset(chunks):
"""
convert 10 second chunks to dataset
"""
x,y=[],[]
with Pool(cpu_count()) as pool:
audio_list = pool.map(audio_split,chunks)
for (ax,ay) in audio_list:
x.append(ax)
y.append(ay)
return x,y

View File

@ -1,11 +1,11 @@
from tinygrad import Tensor, nn
import numpy as np
class Model:
class Gen:
def __init__(self, input_channels=1, height=128, width=216, latent_dim=32):
self.w = width // 8
self.h = height // 8
self.flattened_size = 128 * self.h * self.w
self.flattened_size = 256 * self.h * self.w
# Encoder
self.e1 = nn.Conv2d(input_channels, 64, kernel_size=3, stride=2, padding=1)
@ -61,5 +61,24 @@ class Model:
x = x.reshape(shape=(-1, 256, self.h, self.w))
x = self.d1(x).relu()
x = self.d2(x).relu()
x = self.d3(x)
return x
x = self.d3(x).sigmoid()
return x
class Check():
def __init__(self, input_channels=1, height=128, width=216):
self.w = width // 8
self.h = height // 8
self.flattened_size = 256 * self.h * self.w
self.d1 = nn.Conv2d(input_channels, 64, kernel_size=3, stride=2, padding=1)
self.d2 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
self.d3 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
self.fc = nn.Linear(self.flattened_size, 1)
def __call__(self, x: Tensor) -> Tensor:
x = self.d1(x).leakyrelu(0.2)
x = self.d2(x).leakyrelu(0.2)
x = self.d3(x).leakyrelu(0.2)
x = x.reshape(shape=(-1, self.flattened_size))
return self.fc(x)

File diff suppressed because one or more lines are too long

26
show.py Normal file
View File

@ -0,0 +1,26 @@
import matplotlib.pyplot as plt
import IPython.display as ipd
import librosa
SAMPLE_RATE = 22050
def showSpec(spec):
plt.figure(figsize=(10, 4))
librosa.display.specshow(spec, sr=SAMPLE_RATE,
x_axis='time', y_axis='mel',
cmap='viridis')
plt.colorbar(format='%+2.0f dB')
plt.title('Mel spectrogram')
plt.show()
def playSpec(spec):
S = librosa.feature.inverse.mel_to_stft(spec, sr=SAMPLE_RATE)
audio = librosa.griffinlim(S,n_iter=25,momentum=0.99)
plt.figure(figsize=(12,4))
plt.plot(audio)
plt.title('waveform')
plt.show()
display(ipd.Audio(audio,rate=SAMPLE_RATE))

163
train.ipynb Normal file
View File

@ -0,0 +1,163 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "0",
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"\n",
"import data\n",
"import show\n",
"import model"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1",
"metadata": {},
"outputs": [],
"source": [
"x,y = data.detaset(data.load())\n",
"len(x)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": []
},
"outputs": [],
"source": [
"sample=x[420]\n",
"show.showSpec(sample)\n",
"show.playSpec(sample)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3",
"metadata": {},
"outputs": [],
"source": [
"from tinygrad import nn\n",
"gen = model.Gen()\n",
"optimizer = nn.optim.AdamW(nn.state.get_parameters(gen), lr=1e-4)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4",
"metadata": {},
"outputs": [],
"source": [
"@TinyJit\n",
"def jit_step(X: Tensor, Y: Tensor,epoch) -> Tensor:\n",
" Tensor.training = True\n",
" optimizer.zero_grad()\n",
" _, loss = gen.__Lcall__(X,Y,epoch)\n",
" loss.backward()\n",
" optimizer.step()\n",
" return loss.realize()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"x_np, y_np = np.stack(x), np.stack(y)\n",
"x_np = np.expand_dims(x_np, axis=1)\n",
"y_np = np.expand_dims(y_np, axis=1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": []
},
"outputs": [],
"source": [
"size=len(x)\n",
"BACH_SIZE=32\n",
"eshape = (BACH_SIZE, 1, 128, 216)\n",
"\n",
"for e in range(0,12):\n",
" print(f\"\\n--- Starting Epoch {e} ---\\n\")\n",
" l=0\n",
" \n",
" permutation = np.random.permutation(size)\n",
" x_np = x_np[permutation]\n",
" y_np = y_np[permutation]\n",
" \n",
" for i in range(0,size,BACH_SIZE):\n",
" tx,ty=Tensor(x_np[i:i+BACH_SIZE]),Tensor(y_np[i:i+BACH_SIZE])\n",
" if(tx.shape != eshape or ty.shape != eshape):\n",
" continue\n",
" l+=jit_step(tx,ty,e).numpy()\n",
" \n",
" l /= (size/BACH_SIZE)\n",
" print(f\"loss of {l}\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": []
},
"outputs": [],
"source": [
"from tinygrad.nn.state import safe_save, get_state_dict\n",
"safe_save(get_state_dict(gen),\"music.safetensors\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.11"
}
},
"nbformat": 4,
"nbformat_minor": 5
}