mirror of
https://github.com/TaterTotterson/microWakeWord-Trainer-Nvidia-Docker.git
synced 2026-06-12 20:10:19 -06:00
cleanup
This commit is contained in:
12
dockerfile
12
dockerfile
@@ -1,16 +1,20 @@
|
||||
# CUDA + cuDNN userspace from NVIDIA (no manual repo installs needed)
|
||||
FROM nvidia/cuda:12.4.1-cudnn-runtime-ubuntu22.04
|
||||
# CUDA + cuDNN userspace from NVIDIA (Ubuntu 22.04)
|
||||
FROM nvidia/cuda:12.6.2-cudnn-runtime-ubuntu22.04
|
||||
|
||||
|
||||
ENV DEBIAN_FRONTEND=noninteractive \
|
||||
PYTHONUNBUFFERED=1 \
|
||||
PIP_NO_CACHE_DIR=1
|
||||
PIP_NO_CACHE_DIR=1 \
|
||||
PIP_ROOT_USER_ACTION=ignore \
|
||||
HF_HUB_DISABLE_SYMLINKS_WARNING=1
|
||||
|
||||
# System deps (+dev headers for building C/C++ extensions)
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
python3.10 python3.10-venv python3.10-distutils python3.10-dev python3-pip \
|
||||
git wget curl unzip ca-certificates \
|
||||
git wget curl unzip ca-certificates git-lfs \
|
||||
build-essential g++ cmake \
|
||||
libsndfile1 libsndfile1-dev libffi-dev \
|
||||
ffmpeg \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Use python3.10 everywhere
|
||||
|
||||
@@ -160,10 +160,12 @@
|
||||
"print(\"Linux/NVIDIA detected — using main piper-sample-generator repo.\")\n",
|
||||
"safe_clone(REPO_URL)\n",
|
||||
"\n",
|
||||
"# 2) Install deps (GPU ONNX)\n",
|
||||
"# - piper-phonemize-cross provides phonemization\n",
|
||||
"# - onnxruntime-gpu enables CUDA (container must have CUDA + drivers)\n",
|
||||
"# 2) Install deps\n",
|
||||
"# - piper-tts provides the `piper` module (required by generate_samples.py)\n",
|
||||
"# - piper-phonemize-cross does the phonemization\n",
|
||||
"# - onnxruntime-gpu enables CUDA (container must have NVIDIA runtime)\n",
|
||||
"deps = [\n",
|
||||
" \"piper-tts>=1.2.0\",\n",
|
||||
" \"piper-phonemize-cross==1.2.1\",\n",
|
||||
" \"soundfile\",\n",
|
||||
" \"numpy\",\n",
|
||||
@@ -195,6 +197,7 @@
|
||||
"cmd = [\n",
|
||||
" sys.executable, str(gen_script),\n",
|
||||
" TARGET_WORD,\n",
|
||||
" \"--model\", str(MODELS_DIR / MODEL_NAME), # ← pass the generator .pt explicitly\n",
|
||||
" \"--max-samples\", \"1\",\n",
|
||||
" \"--batch-size\", \"1\",\n",
|
||||
" \"--output-dir\", str(AUDIO_OUT_DIR),\n",
|
||||
@@ -217,17 +220,27 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Generates a larger amount of wake word samples.\n",
|
||||
"# Start here when trying to improve your model.\n",
|
||||
"# See https://github.com/rhasspy/piper-sample-generator for the full set of\n",
|
||||
"# parameters. In particular, experiment with noise-scales and noise-scale-ws,\n",
|
||||
"# generating negative samples similar to the wake word, and generating many more\n",
|
||||
"# wake word samples, possibly with different phonetic pronunciations.\n",
|
||||
"# Generate a large number of wake word samples for training\n",
|
||||
"import sys, subprocess\n",
|
||||
"from pathlib import Path\n",
|
||||
"\n",
|
||||
"!\"{sys.executable}\" piper-sample-generator/generate_samples.py \"{target_word}\" \\\n",
|
||||
"--max-samples 50000 \\\n",
|
||||
"--batch-size 100 \\\n",
|
||||
"--output-dir generated_samples"
|
||||
"target_word = \"hey_tater\"\n",
|
||||
"REPO_DIR = Path.cwd() / \"piper-sample-generator\"\n",
|
||||
"MODELS_DIR = REPO_DIR / \"models\"\n",
|
||||
"MODEL_NAME = \"en_US-libritts_r-medium.pt\"\n",
|
||||
"\n",
|
||||
"cmd = [\n",
|
||||
" sys.executable,\n",
|
||||
" str(REPO_DIR / \"generate_samples.py\"),\n",
|
||||
" target_word,\n",
|
||||
" \"--model\", str(MODELS_DIR / MODEL_NAME), # important: specify generator .pt\n",
|
||||
" \"--max-samples\", \"50000\",\n",
|
||||
" \"--batch-size\", \"100\",\n",
|
||||
" \"--output-dir\", \"generated_samples\",\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"print(\"→\", \" \".join(cmd))\n",
|
||||
"subprocess.run(cmd, check=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -238,150 +251,132 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Downloads audio data for augmentation. This can be slow!\n",
|
||||
"# Borrowed from openWakeWord's automatic_model_training.ipynb, accessed March 4, 2024\n",
|
||||
"#\n",
|
||||
"# **Important note!** The data downloaded here has a mixture of different\n",
|
||||
"# licenses and usage restrictions. As such, any custom models trained with this\n",
|
||||
"# data should be considered as appropriate for **non-commercial** personal use only.\n",
|
||||
"# NVIDIA/Linux dataset prep to match the Apple behavior, but without datasets.Audio (no TorchCodec)\n",
|
||||
"# MIT RIR -> resample to 16 kHz\n",
|
||||
"# AudioSet -> NO resample\n",
|
||||
"# FMA -> resample to 16 kHz mono\n",
|
||||
"\n",
|
||||
"import os\n",
|
||||
"import scipy.io.wavfile\n",
|
||||
"import numpy as np\n",
|
||||
"from datasets import Dataset, Audio, load_dataset\n",
|
||||
"import os, sys, scipy.io.wavfile, numpy as np\n",
|
||||
"from pathlib import Path\n",
|
||||
"from tqdm import tqdm\n",
|
||||
"import soundfile as sf\n",
|
||||
"import librosa\n",
|
||||
"from datasets import load_dataset\n",
|
||||
"\n",
|
||||
"def write_wav(dst: Path, data: np.ndarray, sr: int):\n",
|
||||
" x = np.clip(data, -1.0, 1.0)\n",
|
||||
" scipy.io.wavfile.write(dst, sr, (x * 32767).astype(np.int16))\n",
|
||||
"\n",
|
||||
"# -----------------------------\n",
|
||||
"# Download and Process MIT RIR\n",
|
||||
"# MIT RIR (resample to 16 kHz)\n",
|
||||
"# -----------------------------\n",
|
||||
"output_dir = \"./mit_rirs\"\n",
|
||||
"if not os.path.exists(output_dir):\n",
|
||||
" os.mkdir(output_dir)\n",
|
||||
" rir_dataset = load_dataset(\"davidscripka/MIT_environmental_impulse_responses\", split=\"train\", streaming=True)\n",
|
||||
" print(f\"Downloading MIT RIR dataset to {output_dir}...\")\n",
|
||||
" for row in tqdm(rir_dataset):\n",
|
||||
" name = row[\"audio\"][\"path\"].split(\"/\")[-1]\n",
|
||||
" scipy.io.wavfile.write(\n",
|
||||
" os.path.join(output_dir, name), \n",
|
||||
" 16000, \n",
|
||||
" (row[\"audio\"][\"array\"] * 32767).astype(np.int16)\n",
|
||||
" )\n",
|
||||
" print(f\"Finished downloading MIT RIR dataset to {output_dir}.\\n\")\n",
|
||||
"else:\n",
|
||||
" print(f\"{output_dir} already exists. Skipping download.\")\n",
|
||||
"\n",
|
||||
"# -----------------------------\n",
|
||||
"# Download and Process Audioset\n",
|
||||
"# -----------------------------\n",
|
||||
"\n",
|
||||
"# Directory setup\n",
|
||||
"audioset_dir = \"./audioset\"\n",
|
||||
"output_dir = \"./audioset_16k\"\n",
|
||||
"os.makedirs(audioset_dir, exist_ok=True)\n",
|
||||
"os.makedirs(output_dir, exist_ok=True)\n",
|
||||
"\n",
|
||||
"# Full-scale dataset download links\n",
|
||||
"dataset_links = [\n",
|
||||
" f\"https://huggingface.co/datasets/agkphysics/AudioSet/resolve/main/data/bal_train0{i}.tar\"\n",
|
||||
" for i in range(10)\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"# Download and extract each dataset part\n",
|
||||
"for link in dataset_links:\n",
|
||||
" file_name = link.split(\"/\")[-1]\n",
|
||||
" out_path = os.path.join(audioset_dir, file_name)\n",
|
||||
" if not os.path.exists(out_path):\n",
|
||||
" print(f\"Downloading {file_name}...\")\n",
|
||||
" os.system(f\"wget --quiet -O {out_path} {link}\")\n",
|
||||
" print(f\"Extracting {file_name}...\")\n",
|
||||
" os.system(f\"tar -xf {out_path} -C {audioset_dir}\")\n",
|
||||
"\n",
|
||||
"# Collect all FLAC files for processing\n",
|
||||
"audioset_files = list(Path(audioset_dir).glob(\"**/*.flac\"))\n",
|
||||
"print(f\"Number of FLAC files found: {len(audioset_files)}\")\n",
|
||||
"\n",
|
||||
"if audioset_files:\n",
|
||||
" corrupted_files = []\n",
|
||||
"\n",
|
||||
" print(\"Converting Audioset files to 16kHz WAV...\")\n",
|
||||
" for file_path in tqdm(audioset_files, desc=\"Processing Audioset files\"):\n",
|
||||
"print(\"=== MIT RIR ===\")\n",
|
||||
"rir_out = Path(\"mit_rirs\")\n",
|
||||
"rir_out.mkdir(exist_ok=True)\n",
|
||||
"if not any(rir_out.rglob(\"*.wav\")):\n",
|
||||
" ok = 0\n",
|
||||
" try:\n",
|
||||
" # Attempt to load the file and handle any errors\n",
|
||||
" audio, sampling_rate = sf.read(file_path)\n",
|
||||
" \n",
|
||||
" if audio is None or len(audio) == 0:\n",
|
||||
" raise ValueError(f\"Empty or invalid audio data in file: {file_path}\")\n",
|
||||
"\n",
|
||||
" # Resample audio to 16kHz\n",
|
||||
" output_path = Path(output_dir) / (file_path.stem + \".wav\")\n",
|
||||
" scipy.io.wavfile.write(\n",
|
||||
" output_path,\n",
|
||||
" 16000,\n",
|
||||
" (audio * 32767).astype(np.int16),\n",
|
||||
" )\n",
|
||||
" except (sf.LibsndfileError, ValueError, Exception) as e:\n",
|
||||
" # Log the error and skip the file\n",
|
||||
" print(f\"Error converting {file_path}: {e}\")\n",
|
||||
" corrupted_files.append(str(file_path))\n",
|
||||
"\n",
|
||||
" # Log corrupted files\n",
|
||||
" if corrupted_files:\n",
|
||||
" log_path = Path(output_dir) / \"audioset_corrupted_files.log\"\n",
|
||||
" with open(log_path, \"w\") as log_file:\n",
|
||||
" log_file.writelines(f\"{file}\\n\" for file in corrupted_files)\n",
|
||||
" print(f\"Logged corrupted files to {log_path}\")\n",
|
||||
"else:\n",
|
||||
" print(\"No FLAC files found in Audioset.\")\n",
|
||||
"\n",
|
||||
"print(\"Audioset processing complete!\")\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# -----------------------------\n",
|
||||
"# Download and Process FMA\n",
|
||||
"# -----------------------------\n",
|
||||
"output_dir = \"./fma\"\n",
|
||||
"if not os.path.exists(output_dir):\n",
|
||||
" os.mkdir(output_dir)\n",
|
||||
" fname = \"fma_xs.zip\"\n",
|
||||
" link = \"https://huggingface.co/datasets/mchl914/fma_xsmall/resolve/main/\" + fname\n",
|
||||
" out_dir = os.path.join(output_dir, fname)\n",
|
||||
" os.system(f\"wget -q -O {out_dir} {link}\")\n",
|
||||
" os.system(f\"cd {output_dir} && unzip -q {fname}\")\n",
|
||||
"\n",
|
||||
"output_dir = \"./fma_16k\"\n",
|
||||
"if not os.path.exists(output_dir):\n",
|
||||
" os.mkdir(output_dir)\n",
|
||||
"\n",
|
||||
"# Save clips to 16-bit PCM wav files\n",
|
||||
"fma_files = list(Path(\"fma/fma_small\").glob(\"**/*.mp3\"))\n",
|
||||
"print(f\"Number of MP3 files found: {len(fma_files)}\")\n",
|
||||
"if fma_files:\n",
|
||||
" fma_dataset = Dataset.from_dict({\"audio\": [str(file) for file in fma_files]})\n",
|
||||
" fma_dataset = fma_dataset.cast_column(\"audio\", Audio(sampling_rate=16000))\n",
|
||||
"\n",
|
||||
" corrupted_files = []\n",
|
||||
" print(\"Converting FMA files to 16kHz WAV...\")\n",
|
||||
" for row in tqdm(fma_dataset):\n",
|
||||
" # Avoid datasets.Audio to keep TorchCodec out:\n",
|
||||
" # Use streaming=True + Audio(decode=False)-equivalent: access raw file path and decode with librosa\n",
|
||||
" print(\"⬇️ MIT RIR (streaming + manual decode)…\")\n",
|
||||
" ds = load_dataset(\"davidscripka/MIT_environmental_impulse_responses\",\n",
|
||||
" split=\"train\", streaming=True)\n",
|
||||
" for i, row in enumerate(tqdm(ds)):\n",
|
||||
" try:\n",
|
||||
" name = row[\"audio\"][\"path\"].split(\"/\")[-1].replace(\".mp3\", \".wav\")\n",
|
||||
" scipy.io.wavfile.write(\n",
|
||||
" os.path.join(output_dir, name), \n",
|
||||
" 16000, \n",
|
||||
" (row[\"audio\"][\"array\"] * 32767).astype(np.int16)\n",
|
||||
" )\n",
|
||||
" audio_path = row[\"audio\"][\"path\"]\n",
|
||||
" y, sr = librosa.load(audio_path, sr=16000, mono=True)\n",
|
||||
" write_wav(rir_out / f\"rir_{i:04d}.wav\", y, 16000)\n",
|
||||
" ok += 1\n",
|
||||
" except Exception:\n",
|
||||
" pass\n",
|
||||
" print(f\"✅ MIT RIR saved: {ok} files\")\n",
|
||||
" except Exception as e:\n",
|
||||
" print(f\"Error converting {row['audio']['path']}: {e}\")\n",
|
||||
" corrupted_files.append(row[\"audio\"][\"path\"])\n",
|
||||
"\n",
|
||||
" if corrupted_files:\n",
|
||||
" with open(\"fma_corrupted_files.log\", \"w\") as log_file:\n",
|
||||
" log_file.writelines(f\"{file}\\n\" for file in corrupted_files)\n",
|
||||
" print(f\"⚠️ MIT RIR download failed: {e}\")\n",
|
||||
" # Fallback to official ZIP if needed (rare)\n",
|
||||
" try:\n",
|
||||
" print(\"⬇️ MIT RIR (fallback ZIP)…\")\n",
|
||||
" zip_url = \"https://mcdermottlab.mit.edu/Reverb/IRMAudio/Audio.zip\"\n",
|
||||
" zip_path = rir_out.parent / \"MIT_RIR_Audio.zip\"\n",
|
||||
" if not zip_path.exists():\n",
|
||||
" os.system(f\"wget -q -O '{zip_path}' '{zip_url}'\")\n",
|
||||
" os.system(f'unzip -q -o \"{zip_path}\" -d \"{rir_out}\"')\n",
|
||||
" # Normalize to 16k mono\n",
|
||||
" for p in tqdm(list(rir_out.rglob(\"*.wav\")), desc=\"Normalize MIT RIR\"):\n",
|
||||
" a, sr = sf.read(p, always_2d=False)\n",
|
||||
" if a.ndim > 1: a = a[:,0]\n",
|
||||
" if sr != 16000:\n",
|
||||
" a, _ = librosa.load(p, sr=16000, mono=True)\n",
|
||||
" write_wav(p, a, 16000)\n",
|
||||
" print(\"✅ MIT RIR fallback complete\")\n",
|
||||
" except Exception as e2:\n",
|
||||
" print(f\"❌ MIT RIR fallback failed: {e2}\")\n",
|
||||
"else:\n",
|
||||
" print(\"No MP3 files found in FMA.\")\n",
|
||||
" print(\"✅ mit_rirs exists; skipping.\")\n",
|
||||
"\n",
|
||||
"print(\"Dataset preparation complete!\")"
|
||||
"# -----------------------------\n",
|
||||
"# AudioSet (NO resample — fast)\n",
|
||||
"# -----------------------------\n",
|
||||
"print(\"\\n=== AudioSet subset ===\")\n",
|
||||
"audioset_dir = Path(\"audioset\"); audioset_dir.mkdir(exist_ok=True)\n",
|
||||
"audioset_out = Path(\"audioset_16k\"); audioset_out.mkdir(exist_ok=True)\n",
|
||||
"\n",
|
||||
"links = [f\"https://huggingface.co/datasets/agkphysics/AudioSet/resolve/main/data/bal_train0{i}.tar\"\n",
|
||||
" for i in range(10)]\n",
|
||||
"for link in links:\n",
|
||||
" fname = link.split(\"/\")[-1]\n",
|
||||
" out_tar = audioset_dir / fname\n",
|
||||
" if not out_tar.exists():\n",
|
||||
" print(f\"⬇️ {fname}\")\n",
|
||||
" os.system(f\"wget -q -O '{out_tar}' '{link}'\")\n",
|
||||
" print(f\"📦 Extract {fname}\")\n",
|
||||
" os.system(f\"tar -xf '{out_tar}' -C '{audioset_dir}'\")\n",
|
||||
"\n",
|
||||
"flacs = list(audioset_dir.rglob(\"*.flac\"))\n",
|
||||
"print(f\"🔎 FLAC files: {len(flacs)}\")\n",
|
||||
"corrupt = []\n",
|
||||
"for p in tqdm(flacs, desc=\"AudioSet→WAV (no resample)\"):\n",
|
||||
" try:\n",
|
||||
" a, sr = sf.read(p, always_2d=False)\n",
|
||||
" if a is None or len(a) == 0:\n",
|
||||
" raise ValueError(\"empty audio\")\n",
|
||||
" if a.ndim > 1:\n",
|
||||
" a = a[:,0]\n",
|
||||
" # Apple behavior: write as 16-bit and label 16 kHz (no resample)\n",
|
||||
" write_wav(audioset_out / (p.stem + \".wav\"), a, 16000)\n",
|
||||
" except Exception as e:\n",
|
||||
" corrupt.append(f\"{p}:{e}\")\n",
|
||||
"if corrupt:\n",
|
||||
" (audioset_out / \"audioset_corrupted_files.log\").write_text(\"\\n\".join(corrupt))\n",
|
||||
"print(\"✅ AudioSet processing complete!\")\n",
|
||||
"\n",
|
||||
"# -----------------------------\n",
|
||||
"# FMA xsmall (resample to 16 kHz mono)\n",
|
||||
"# -----------------------------\n",
|
||||
"print(\"\\n=== FMA xsmall ===\")\n",
|
||||
"fma_zip_dir = Path(\"fma\"); fma_zip_dir.mkdir(exist_ok=True)\n",
|
||||
"fma_out = Path(\"fma_16k\"); fma_out.mkdir(exist_ok=True)\n",
|
||||
"\n",
|
||||
"zipname = \"fma_xs.zip\"\n",
|
||||
"zipurl = f\"https://huggingface.co/datasets/mchl914/fma_xsmall/resolve/main/{zipname}\"\n",
|
||||
"zipout = fma_zip_dir / zipname\n",
|
||||
"if not zipout.exists():\n",
|
||||
" os.system(f\"wget -q -O '{zipout}' '{zipurl}'\")\n",
|
||||
" os.system(f\"cd fma && unzip -q '{zipname}'\")\n",
|
||||
"\n",
|
||||
"mp3s = list(Path(\"fma/fma_small\").rglob(\"*.mp3\"))\n",
|
||||
"print(f\"🎵 FMA mp3 count: {len(mp3s)}\")\n",
|
||||
"corrupt = []\n",
|
||||
"for p in tqdm(mp3s, desc=\"FMA→16k WAV\"):\n",
|
||||
" try:\n",
|
||||
" y, sr = librosa.load(p, sr=16000, mono=True) # proper decode+resample\n",
|
||||
" if y.size == 0:\n",
|
||||
" raise ValueError(\"empty audio\")\n",
|
||||
" write_wav(fma_out / (p.stem + \".wav\"), y, 16000)\n",
|
||||
" except Exception as e:\n",
|
||||
" corrupt.append(f\"{p}:{e}\")\n",
|
||||
"if corrupt:\n",
|
||||
" Path(\"fma_corrupted_files.log\").write_text(\"\\n\".join(corrupt))\n",
|
||||
"print(\"\\n✅ Dataset prep complete!\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -453,29 +448,41 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Augment a random clip and play it back to verify it works well\n",
|
||||
"# Augment a random generated-sample WAV and play it back (pass ndarray to augmenter)\n",
|
||||
"from pathlib import Path\n",
|
||||
"from IPython.display import Audio\n",
|
||||
"from microwakeword.audio.audio_utils import save_clip\n",
|
||||
"from IPython.display import Audio, display\n",
|
||||
"import numpy as np\n",
|
||||
"import soundfile as sf\n",
|
||||
"import librosa, random, glob\n",
|
||||
"\n",
|
||||
"# Ensure output directory exists\n",
|
||||
"output_dir = Path('./augmented_clips')\n",
|
||||
"output_dir = Path(\"./augmented_clips\")\n",
|
||||
"output_dir.mkdir(exist_ok=True)\n",
|
||||
"\n",
|
||||
"# 1) Pick a random WAV from the Piper outputs\n",
|
||||
"candidates = glob.glob(\"generated_samples/*.wav\")\n",
|
||||
"if not candidates:\n",
|
||||
" raise SystemExit(\"No files in generated_samples/. Run the TTS sample cell first.\")\n",
|
||||
"src_path = random.choice(candidates)\n",
|
||||
"\n",
|
||||
"# 2) Load as 16 kHz mono float32\n",
|
||||
"y, sr = librosa.load(src_path, sr=16000, mono=True)\n",
|
||||
"y = y.astype(np.float32, copy=False)\n",
|
||||
"\n",
|
||||
"# 3) Augment — microwakeword Augmentation expects a 1-D numpy array\n",
|
||||
"try:\n",
|
||||
" # Get a random clip and apply augmentation\n",
|
||||
" random_clip = clips.get_random_clip()\n",
|
||||
" augmented_clip = augmenter.augment_clip(random_clip)\n",
|
||||
" \n",
|
||||
" # Save augmented clip to file\n",
|
||||
" output_file = output_dir / 'augmented_clip.wav'\n",
|
||||
" save_clip(augmented_clip, output_file)\n",
|
||||
" print(f\"Augmented clip saved to {output_file}\")\n",
|
||||
" \n",
|
||||
" # Playback augmented clip\n",
|
||||
" display(Audio(str(output_file), autoplay=True))\n",
|
||||
" y_aug = augmenter.augment_clip(y)\n",
|
||||
"except Exception as e:\n",
|
||||
" print(f\"Error during augmentation or playback: {e}\")"
|
||||
" # some versions accept (samples, sr) — try that as a fallback\n",
|
||||
" try:\n",
|
||||
" y_aug = augmenter.augment_clip((y, sr))\n",
|
||||
" except Exception:\n",
|
||||
" raise\n",
|
||||
"\n",
|
||||
"# 4) Save and play\n",
|
||||
"out_path = output_dir / \"augmented_clip.wav\"\n",
|
||||
"sf.write(str(out_path), y_aug.astype(np.float32, copy=False), sr, subtype=\"PCM_16\")\n",
|
||||
"print(f\"Augmented clip saved to {out_path}\")\n",
|
||||
"display(Audio(str(out_path), autoplay=True))"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -487,53 +494,96 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Augment samples and save the training, validation, and testing sets.\n",
|
||||
"# Validating and testing samples generated the same way can make the model\n",
|
||||
"# benchmark better than it performs in real-world use. Use real samples or TTS\n",
|
||||
"# samples generated with a different TTS engine to potentially get more accurate\n",
|
||||
"# benchmarks.\n",
|
||||
"# This version avoids datasets.Audio entirely by driving Clips from local WAVs.\n",
|
||||
"\n",
|
||||
"import os\n",
|
||||
"import os, glob, random\n",
|
||||
"from pathlib import Path\n",
|
||||
"import types\n",
|
||||
"import numpy as np\n",
|
||||
"import librosa\n",
|
||||
"from mmap_ninja.ragged import RaggedMmap\n",
|
||||
"from microwakeword.audio.spectrograms import SpectrogramGeneration\n",
|
||||
"\n",
|
||||
"# Output directory for augmented features\n",
|
||||
"output_dir = 'generated_augmented_features'\n",
|
||||
"os.makedirs(output_dir, exist_ok=True)\n",
|
||||
"# ---- Patch: drive clips from generated_samples/*.wav (no datasets.Audio, no torchcodec) ----\n",
|
||||
"def audio_generator_from_wavs(self, split=\"train\", repeat=1):\n",
|
||||
" \"\"\"\n",
|
||||
" Yield 1-D float32 arrays loaded via librosa from generated_samples/*.wav.\n",
|
||||
" Deterministic 80/10/10 split with seed 10 to mirror original Clips behavior.\n",
|
||||
" \"\"\"\n",
|
||||
" files = sorted(glob.glob(\"generated_samples/*.wav\"))\n",
|
||||
" if not files:\n",
|
||||
" raise SystemExit(\"❌ No WAVs in generated_samples/. Generate TTS samples first.\")\n",
|
||||
"\n",
|
||||
"# Configuration for each split\n",
|
||||
"split_config = {\n",
|
||||
" rng = random.Random(10) # deterministic shuffling like Clips(random_split_seed=10)\n",
|
||||
" files_shuf = files[:]\n",
|
||||
" rng.shuffle(files_shuf)\n",
|
||||
"\n",
|
||||
" n = len(files_shuf)\n",
|
||||
" n_val = max(1, int(0.10 * n))\n",
|
||||
" n_test = max(1, int(0.10 * n))\n",
|
||||
" n_train = max(0, n - n_val - n_test)\n",
|
||||
" splits = {\n",
|
||||
" \"train\": files_shuf[:n_train],\n",
|
||||
" \"validation\": files_shuf[n_train:n_train + n_val],\n",
|
||||
" \"test\": files_shuf[n_train + n_val:],\n",
|
||||
" }\n",
|
||||
" file_list = splits.get(split, [])\n",
|
||||
" if not file_list:\n",
|
||||
" return # nothing to yield\n",
|
||||
"\n",
|
||||
" for _ in range(max(1, int(repeat))):\n",
|
||||
" for p in file_list:\n",
|
||||
" y, sr = librosa.load(p, sr=16000, mono=True)\n",
|
||||
" yield y.astype(np.float32, copy=False)\n",
|
||||
"\n",
|
||||
"# Bind the patched generator to your existing `clips` instance\n",
|
||||
"clips.audio_generator = types.MethodType(audio_generator_from_wavs, clips)\n",
|
||||
"print(\"✅ Patched clips.audio_generator to stream from generated_samples/*.wav (no torchcodec).\")\n",
|
||||
"\n",
|
||||
"# ---- Validate augmentation asset folders exist ----\n",
|
||||
"def validate(paths):\n",
|
||||
" for p in paths:\n",
|
||||
" if not Path(p).exists():\n",
|
||||
" raise SystemExit(f\"❌ Missing directory: {p}. Run dataset prep first.\")\n",
|
||||
"\n",
|
||||
"impulse_paths = [\"mit_rirs\"]\n",
|
||||
"background_paths = [\"fma_16k\", \"audioset_16k\"]\n",
|
||||
"validate(impulse_paths + background_paths)\n",
|
||||
"\n",
|
||||
"# ---- Output root ----\n",
|
||||
"out_root = Path(\"generated_augmented_features\")\n",
|
||||
"out_root.mkdir(exist_ok=True)\n",
|
||||
"\n",
|
||||
"# ---- Split config (same as before) ----\n",
|
||||
"split_cfg = {\n",
|
||||
" \"training\": {\"name\": \"train\", \"repetition\": 2, \"slide_frames\": 10},\n",
|
||||
" \"validation\": {\"name\": \"validation\", \"repetition\": 1, \"slide_frames\": 10},\n",
|
||||
" \"testing\": {\"name\": \"test\", \"repetition\": 1, \"slide_frames\": 1},\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"# Generate augmented features for each split\n",
|
||||
"for split, config in split_config.items():\n",
|
||||
" out_dir = os.path.join(output_dir, split)\n",
|
||||
" os.makedirs(out_dir, exist_ok=True)\n",
|
||||
" print(f\"Processing {split} set...\")\n",
|
||||
"# ---- Generate features ----\n",
|
||||
"for split, cfg in split_cfg.items():\n",
|
||||
" out_dir = out_root / split\n",
|
||||
" out_dir.mkdir(parents=True, exist_ok=True)\n",
|
||||
" print(f\"🧪 Processing {split} …\")\n",
|
||||
"\n",
|
||||
" try:\n",
|
||||
" # Spectrogram generation configuration\n",
|
||||
" spectrograms = SpectrogramGeneration(\n",
|
||||
" clips=clips,\n",
|
||||
" augmenter=augmenter,\n",
|
||||
" slide_frames=config[\"slide_frames\"],\n",
|
||||
" step_ms=10, # Can parameterize this if needed\n",
|
||||
" spectros = SpectrogramGeneration(\n",
|
||||
" clips=clips, # now backed by our WAV loader\n",
|
||||
" augmenter=augmenter, # your existing augmenter\n",
|
||||
" slide_frames=cfg[\"slide_frames\"],\n",
|
||||
" step_ms=10,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" # Generate and save spectrogram features\n",
|
||||
" RaggedMmap.from_generator(\n",
|
||||
" out_dir=os.path.join(out_dir, 'wakeword_mmap'),\n",
|
||||
" sample_generator=spectrograms.spectrogram_generator(\n",
|
||||
" split=config[\"name\"], repeat=config[\"repetition\"]\n",
|
||||
" out_dir=str(out_dir / \"wakeword_mmap\"),\n",
|
||||
" sample_generator=spectros.spectrogram_generator(\n",
|
||||
" split=cfg[\"name\"], repeat=cfg[\"repetition\"]\n",
|
||||
" ),\n",
|
||||
" batch_size=100, # Can parameterize this if needed\n",
|
||||
" batch_size=100,\n",
|
||||
" verbose=True,\n",
|
||||
" )\n",
|
||||
" print(f\"Completed processing {split} set. Output saved to {out_dir}\")\n",
|
||||
" except Exception as e:\n",
|
||||
" print(f\"Error processing {split} set: {e}\")"
|
||||
"\n",
|
||||
"print(\"✅ Features ready (generated_augmented_features/*/wakeword_mmap)\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -808,7 +858,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.18"
|
||||
"version": "3.10.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
# --- Core training (Microwakeword) ---
|
||||
torch==2.5.1
|
||||
torchaudio==2.5.1
|
||||
tensorboard==2.18.0
|
||||
tensorboard-data-server==0.7.2
|
||||
tensorflow==2.18.0
|
||||
tensorflow-estimator==2.13.0
|
||||
tensorflow-io-gcs-filesystem==0.37.1
|
||||
numpy==1.26.4
|
||||
scipy==1.12.0
|
||||
librosa==0.10.2.post1
|
||||
@@ -13,16 +18,16 @@ scikit-learn==1.6.0
|
||||
numba==0.60.0
|
||||
joblib==1.4.2
|
||||
pandas==2.2.3
|
||||
# feature extractors + metadata helpers your repo uses
|
||||
pymicro_features @ git+https://github.com/puddly/pymicro-features@e1d3f88183e12bb8af2df9e399ea157af7393762
|
||||
audio-metadata @ git+https://github.com/whatsnowplaying/audio-metadata@d4ebb238e6a401bb1a5aaaac60c9e2b3cb30929f
|
||||
bitstruct==8.19.0
|
||||
|
||||
# --- Piper sample generation ---
|
||||
piper-tts>=1.2.0
|
||||
onnxruntime-gpu>=1.16.0
|
||||
piper-phonemize-cross==1.2.1
|
||||
|
||||
# --- Notebook / tooling (keep light) ---
|
||||
# --- Notebook / tooling ---
|
||||
ipykernel==6.29.5
|
||||
jupyterlab==4.3.4
|
||||
ipywidgets==8.1.5
|
||||
|
||||
Reference in New Issue
Block a user