mirror of
https://github.com/TaterTotterson/microWakeWord-Trainer-Nvidia-Docker.git
synced 2026-06-12 20:10:19 -06:00
867 lines
34 KiB
Plaintext
867 lines
34 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "r11cNiLqvWC6"
|
|
},
|
|
"source": [
|
|
"<div align=\"center\">\n",
|
|
" <img src=\"https://raw.githubusercontent.com/MasterPhooey/MicroWakeWord-Trainer-Docker/refs/heads/main/mmw.png\" alt=\"MicroWakeWord Trainer Logo\" width=\"100\" />\n",
|
|
" <h1>MicroWakeWord Trainer Docker</h1>\n",
|
|
"</div>\n",
|
|
"\n",
|
|
"This notebook steps you through training a robust microWakeWord model. It is intended as a **starting point** for users looking to create a high-performance wake word detection model. This notebook is optimized for Python 3.10.\n",
|
|
"\n",
|
|
"**The model generated from this notebook is designed for practical use, but achieving optimal performance will require experimentation with various settings and datasets. The provided scripts and configurations aim to give you a strong foundation to build upon.**\n",
|
|
"\n",
|
|
"Throughout the notebook, you will find comments suggesting specific settings to modify and experiment with to enhance your model's performance.\n",
|
|
"\n",
|
|
"By the end of this notebook, you will have:\n",
|
|
"- A trained TensorFlow Lite model ready for deployment.\n",
|
|
"- A JSON manifest file to integrate the model with ESPHome.\n",
|
|
"\n",
|
|
"To use the generated model in ESPHome, refer to the [ESPHome documentation](https://esphome.io/components/micro_wake_word) for integration details. You can also explore example configurations in the [model repository](https://github.com/esphome/micro-wake-word-models/tree/main/models/v2)."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "BFf6511E65ff"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Installs microWakeWord. Be sure to restart the session after this is finished.\n",
|
|
"import platform\n",
|
|
"import sys\n",
|
|
"import os\n",
|
|
"\n",
|
|
"if platform.system() == \"Darwin\":\n",
|
|
" # `pymicro-features` is installed from a fork to support building on macOS\n",
|
|
" !\"{sys.executable}\" -m pip install 'git+https://github.com/puddly/pymicro-features@puddly/minimum-cpp-version' --root-user-action=ignore\n",
|
|
"\n",
|
|
"# `audio-metadata` is installed from a fork to unpin `attrs` from a version that breaks Jupyter\n",
|
|
"!\"{sys.executable}\" -m pip install 'git+https://github.com/whatsnowplaying/audio-metadata@d4ebb238e6a401bb1a5aaaac60c9e2b3cb30929f' --root-user-action=ignore\n",
|
|
"\n",
|
|
"# Clone the microWakeWord repository\n",
|
|
"repo_path = \"./microWakeWord\"\n",
|
|
"if not os.path.exists(repo_path):\n",
|
|
" print(\"Cloning microWakeWord repository...\")\n",
|
|
" !git clone https://github.com/kahrendt/microWakeWord.git {repo_path}\n",
|
|
"\n",
|
|
"# Ensure the repository exists before attempting to install\n",
|
|
"if os.path.exists(repo_path):\n",
|
|
" print(\"Installing microWakeWord...\")\n",
|
|
" !\"{sys.executable}\" -m pip install -e {repo_path} --root-user-action=ignore\n",
|
|
"else:\n",
|
|
" print(f\"Repository not found at {repo_path}. Cloning might have failed.\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "BFf6511E65ff"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# --- GPU Check (Torch + ONNX Runtime) ---\n",
|
|
"\n",
|
|
"import torch\n",
|
|
"import onnxruntime as ort\n",
|
|
"\n",
|
|
"print(\"🔧 Torch CUDA Available:\", torch.cuda.is_available())\n",
|
|
"if torch.cuda.is_available():\n",
|
|
" print(\" • Device count:\", torch.cuda.device_count())\n",
|
|
" print(\" • Current device:\", torch.cuda.current_device())\n",
|
|
" print(\" • Device name:\", torch.cuda.get_device_name(torch.cuda.current_device()))\n",
|
|
"else:\n",
|
|
" print(\"⚠️ Torch cannot see a GPU — check Docker runtime (--gpus all) and nvidia-container-toolkit\")\n",
|
|
"\n",
|
|
"print(\"\\n🔧 ONNX Runtime Providers:\")\n",
|
|
"try:\n",
|
|
" providers = ort.get_available_providers()\n",
|
|
" print(\" •\", providers)\n",
|
|
" if \"CUDAExecutionProvider\" not in providers:\n",
|
|
" print(\"⚠️ CUDAExecutionProvider not available — ONNX will fall back to CPU.\")\n",
|
|
"except Exception as e:\n",
|
|
" print(\"⚠️ Could not query ONNX Runtime providers:\", e)\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "dEluu7nL7ywd"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# NVIDIA Linux Docker: generate 1 sample of the target word (robust + CUDA check)\n",
|
|
"\n",
|
|
"import os, sys, shutil, subprocess, time, platform\n",
|
|
"from pathlib import Path\n",
|
|
"from IPython.display import Audio, display\n",
|
|
"\n",
|
|
"TARGET_WORD = \"hey_tater\"\n",
|
|
"REPO_URL = \"https://github.com/rhasspy/piper-sample-generator\"\n",
|
|
"REPO_DIR = Path.cwd() / \"piper-sample-generator\"\n",
|
|
"MODELS_DIR = REPO_DIR / \"models\"\n",
|
|
"MODEL_NAME = \"en_US-libritts_r-medium.pt\"\n",
|
|
"MODEL_URL = f\"https://github.com/rhasspy/piper-sample-generator/releases/download/v2.0.0/{MODEL_NAME}\"\n",
|
|
"AUDIO_OUT_DIR = Path.cwd() / \"generated_samples\"\n",
|
|
"AUDIO_PATH = AUDIO_OUT_DIR / \"0.wav\"\n",
|
|
"\n",
|
|
"def run(cmd, check=True):\n",
|
|
" print(\"→\", \" \".join(cmd))\n",
|
|
" proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)\n",
|
|
" for line in proc.stdout:\n",
|
|
" print(line, end=\"\")\n",
|
|
" rc = proc.wait()\n",
|
|
" if check and rc != 0:\n",
|
|
" raise RuntimeError(f\"Command failed with exit code {rc}: {' '.join(cmd)}\")\n",
|
|
" return rc\n",
|
|
"\n",
|
|
"def pip_install(*pkgs):\n",
|
|
" run([sys.executable, \"-m\", \"pip\", \"install\", \"--upgrade\", \"pip\"], check=False)\n",
|
|
" run([sys.executable, \"-m\", \"pip\", \"install\", *pkgs])\n",
|
|
"\n",
|
|
"def safe_clone(repo_url, branch=None, dest=REPO_DIR, retries=2):\n",
|
|
" if dest.exists() and not (dest / \".git\").exists():\n",
|
|
" print(\"⚠️ Found partial clone. Removing…\")\n",
|
|
" shutil.rmtree(dest, ignore_errors=True)\n",
|
|
" if not dest.exists():\n",
|
|
" for i in range(retries + 1):\n",
|
|
" try:\n",
|
|
" cmd = [\"git\", \"clone\", \"--depth\", \"1\", repo_url, str(dest)]\n",
|
|
" if branch:\n",
|
|
" cmd = [\"git\", \"clone\", \"--depth\", \"1\", \"--branch\", branch, repo_url, str(dest)]\n",
|
|
" run(cmd)\n",
|
|
" break\n",
|
|
" except Exception as e:\n",
|
|
" if i == retries:\n",
|
|
" raise\n",
|
|
" print(f\"Clone failed ({i+1}/{retries+1}). Retrying in 2s… [{e}]\")\n",
|
|
" time.sleep(2)\n",
|
|
"\n",
|
|
"def ensure_model():\n",
|
|
" MODELS_DIR.mkdir(parents=True, exist_ok=True)\n",
|
|
" mp = MODELS_DIR / MODEL_NAME\n",
|
|
" if not mp.exists() or mp.stat().st_size == 0:\n",
|
|
" import urllib.request\n",
|
|
" print(f\"Downloading model to {mp} …\")\n",
|
|
" with urllib.request.urlopen(MODEL_URL) as r, open(mp, \"wb\") as f:\n",
|
|
" shutil.copyfileobj(r, f)\n",
|
|
" if mp.stat().st_size < 100 * 1024:\n",
|
|
" raise RuntimeError(\"Downloaded model looks too small; download may have failed.\")\n",
|
|
" print(f\"✅ Model ready: {mp}\")\n",
|
|
"\n",
|
|
"# 1) Clone main repo (Linux/NVIDIA)\n",
|
|
"print(\"Linux/NVIDIA detected — using main piper-sample-generator repo.\")\n",
|
|
"safe_clone(REPO_URL)\n",
|
|
"\n",
|
|
"# 2) Install deps\n",
|
|
"# - piper-tts provides the `piper` module (required by generate_samples.py)\n",
|
|
"# - piper-phonemize-cross does the phonemization\n",
|
|
"# - onnxruntime-gpu enables CUDA (container must have NVIDIA runtime)\n",
|
|
"deps = [\n",
|
|
" \"piper-tts>=1.2.0\",\n",
|
|
" \"piper-phonemize-cross==1.2.1\",\n",
|
|
" \"soundfile\",\n",
|
|
" \"numpy\",\n",
|
|
" \"onnxruntime-gpu>=1.16.0\",\n",
|
|
"]\n",
|
|
"pip_install(*deps)\n",
|
|
"\n",
|
|
"# 3) Verify CUDA provider is available\n",
|
|
"try:\n",
|
|
" import onnxruntime as ort\n",
|
|
" providers = ort.get_available_providers()\n",
|
|
" print(f\"ONNX Runtime providers: {providers}\")\n",
|
|
" if \"CUDAExecutionProvider\" not in providers:\n",
|
|
" print(\"⚠️ CUDAExecutionProvider not available. \"\n",
|
|
" \"The sample will still run on CPU, but check your NVIDIA container setup \"\n",
|
|
" \"(nvidia-container-toolkit, runtime, and driver).\")\n",
|
|
"except Exception as e:\n",
|
|
" print(\"⚠️ Could not import onnxruntime to verify providers:\", e)\n",
|
|
"\n",
|
|
"# 4) Ensure model present\n",
|
|
"ensure_model()\n",
|
|
"\n",
|
|
"# 5) Generate one sample\n",
|
|
"AUDIO_OUT_DIR.mkdir(parents=True, exist_ok=True)\n",
|
|
"gen_script = REPO_DIR / \"generate_samples.py\"\n",
|
|
"if not gen_script.exists():\n",
|
|
" raise FileNotFoundError(f\"Missing generator: {gen_script}\")\n",
|
|
"\n",
|
|
"cmd = [\n",
|
|
" sys.executable, str(gen_script),\n",
|
|
" TARGET_WORD,\n",
|
|
" \"--model\", str(MODELS_DIR / MODEL_NAME), # ← pass the generator .pt explicitly\n",
|
|
" \"--max-samples\", \"1\",\n",
|
|
" \"--batch-size\", \"1\",\n",
|
|
" \"--output-dir\", str(AUDIO_OUT_DIR),\n",
|
|
"]\n",
|
|
"run(cmd)\n",
|
|
"\n",
|
|
"# 6) Play the audio (if the notebook frontend supports it)\n",
|
|
"if AUDIO_PATH.exists():\n",
|
|
" print(f\"🎧 Playing {AUDIO_PATH}\")\n",
|
|
" display(Audio(str(AUDIO_PATH), autoplay=True))\n",
|
|
"else:\n",
|
|
" print(f\"Audio file not found at {AUDIO_PATH}\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "-SvGtCCM9akR"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Generate a large number of wake word samples for training\n",
|
|
"import sys, subprocess\n",
|
|
"from pathlib import Path\n",
|
|
"\n",
|
|
"target_word = \"hey_tater\"\n",
|
|
"REPO_DIR = Path.cwd() / \"piper-sample-generator\"\n",
|
|
"MODELS_DIR = REPO_DIR / \"models\"\n",
|
|
"MODEL_NAME = \"en_US-libritts_r-medium.pt\"\n",
|
|
"\n",
|
|
"cmd = [\n",
|
|
" sys.executable,\n",
|
|
" str(REPO_DIR / \"generate_samples.py\"),\n",
|
|
" target_word,\n",
|
|
" \"--model\", str(MODELS_DIR / MODEL_NAME), # important: specify generator .pt\n",
|
|
" \"--max-samples\", \"50000\",\n",
|
|
" \"--batch-size\", \"100\",\n",
|
|
" \"--output-dir\", \"generated_samples\",\n",
|
|
"]\n",
|
|
"\n",
|
|
"print(\"→\", \" \".join(cmd))\n",
|
|
"subprocess.run(cmd, check=True)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "YJRG4Qvo9nXG"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# NVIDIA/Linux dataset prep to match the Apple behavior, but without datasets.Audio (no TorchCodec)\n",
|
|
"# MIT RIR -> resample to 16 kHz\n",
|
|
"# AudioSet -> NO resample\n",
|
|
"# FMA -> resample to 16 kHz mono\n",
|
|
"\n",
|
|
"import os, sys, scipy.io.wavfile, numpy as np\n",
|
|
"from pathlib import Path\n",
|
|
"from tqdm import tqdm\n",
|
|
"import soundfile as sf\n",
|
|
"import librosa\n",
|
|
"from datasets import load_dataset\n",
|
|
"\n",
|
|
"def write_wav(dst: Path, data: np.ndarray, sr: int):\n",
|
|
" x = np.clip(data, -1.0, 1.0)\n",
|
|
" scipy.io.wavfile.write(dst, sr, (x * 32767).astype(np.int16))\n",
|
|
"\n",
|
|
"# -----------------------------\n",
|
|
"# MIT RIR (resample to 16 kHz)\n",
|
|
"# -----------------------------\n",
|
|
"print(\"=== MIT RIR ===\")\n",
|
|
"rir_out = Path(\"mit_rirs\")\n",
|
|
"rir_out.mkdir(exist_ok=True)\n",
|
|
"if not any(rir_out.rglob(\"*.wav\")):\n",
|
|
" ok = 0\n",
|
|
" try:\n",
|
|
" # Avoid datasets.Audio to keep TorchCodec out:\n",
|
|
" # Use streaming=True + Audio(decode=False)-equivalent: access raw file path and decode with librosa\n",
|
|
" print(\"⬇️ MIT RIR (streaming + manual decode)…\")\n",
|
|
" ds = load_dataset(\"davidscripka/MIT_environmental_impulse_responses\",\n",
|
|
" split=\"train\", streaming=True)\n",
|
|
" for i, row in enumerate(tqdm(ds)):\n",
|
|
" try:\n",
|
|
" audio_path = row[\"audio\"][\"path\"]\n",
|
|
" y, sr = librosa.load(audio_path, sr=16000, mono=True)\n",
|
|
" write_wav(rir_out / f\"rir_{i:04d}.wav\", y, 16000)\n",
|
|
" ok += 1\n",
|
|
" except Exception:\n",
|
|
" pass\n",
|
|
" print(f\"✅ MIT RIR saved: {ok} files\")\n",
|
|
" except Exception as e:\n",
|
|
" print(f\"⚠️ MIT RIR download failed: {e}\")\n",
|
|
" # Fallback to official ZIP if needed (rare)\n",
|
|
" try:\n",
|
|
" print(\"⬇️ MIT RIR (fallback ZIP)…\")\n",
|
|
" zip_url = \"https://mcdermottlab.mit.edu/Reverb/IRMAudio/Audio.zip\"\n",
|
|
" zip_path = rir_out.parent / \"MIT_RIR_Audio.zip\"\n",
|
|
" if not zip_path.exists():\n",
|
|
" os.system(f\"wget -q -O '{zip_path}' '{zip_url}'\")\n",
|
|
" os.system(f'unzip -q -o \"{zip_path}\" -d \"{rir_out}\"')\n",
|
|
" # Normalize to 16k mono\n",
|
|
" for p in tqdm(list(rir_out.rglob(\"*.wav\")), desc=\"Normalize MIT RIR\"):\n",
|
|
" a, sr = sf.read(p, always_2d=False)\n",
|
|
" if a.ndim > 1: a = a[:,0]\n",
|
|
" if sr != 16000:\n",
|
|
" a, _ = librosa.load(p, sr=16000, mono=True)\n",
|
|
" write_wav(p, a, 16000)\n",
|
|
" print(\"✅ MIT RIR fallback complete\")\n",
|
|
" except Exception as e2:\n",
|
|
" print(f\"❌ MIT RIR fallback failed: {e2}\")\n",
|
|
"else:\n",
|
|
" print(\"✅ mit_rirs exists; skipping.\")\n",
|
|
"\n",
|
|
"# -----------------------------\n",
|
|
"# AudioSet (NO resample — fast)\n",
|
|
"# -----------------------------\n",
|
|
"print(\"\\n=== AudioSet subset ===\")\n",
|
|
"audioset_dir = Path(\"audioset\"); audioset_dir.mkdir(exist_ok=True)\n",
|
|
"audioset_out = Path(\"audioset_16k\"); audioset_out.mkdir(exist_ok=True)\n",
|
|
"\n",
|
|
"links = [f\"https://huggingface.co/datasets/agkphysics/AudioSet/resolve/main/data/bal_train0{i}.tar\"\n",
|
|
" for i in range(10)]\n",
|
|
"for link in links:\n",
|
|
" fname = link.split(\"/\")[-1]\n",
|
|
" out_tar = audioset_dir / fname\n",
|
|
" if not out_tar.exists():\n",
|
|
" print(f\"⬇️ {fname}\")\n",
|
|
" os.system(f\"wget -q -O '{out_tar}' '{link}'\")\n",
|
|
" print(f\"📦 Extract {fname}\")\n",
|
|
" os.system(f\"tar -xf '{out_tar}' -C '{audioset_dir}'\")\n",
|
|
"\n",
|
|
"flacs = list(audioset_dir.rglob(\"*.flac\"))\n",
|
|
"print(f\"🔎 FLAC files: {len(flacs)}\")\n",
|
|
"corrupt = []\n",
|
|
"for p in tqdm(flacs, desc=\"AudioSet→WAV (no resample)\"):\n",
|
|
" try:\n",
|
|
" a, sr = sf.read(p, always_2d=False)\n",
|
|
" if a is None or len(a) == 0:\n",
|
|
" raise ValueError(\"empty audio\")\n",
|
|
" if a.ndim > 1:\n",
|
|
" a = a[:,0]\n",
|
|
" # Apple behavior: write as 16-bit and label the file 16 kHz without resampling (the source rate `sr` read above is ignored)\n",
|
|
" write_wav(audioset_out / (p.stem + \".wav\"), a, 16000)\n",
|
|
" except Exception as e:\n",
|
|
" corrupt.append(f\"{p}:{e}\")\n",
|
|
"if corrupt:\n",
|
|
" (audioset_out / \"audioset_corrupted_files.log\").write_text(\"\\n\".join(corrupt))\n",
|
|
"print(\"✅ AudioSet processing complete!\")\n",
|
|
"\n",
|
|
"# -----------------------------\n",
|
|
"# FMA xsmall (resample to 16 kHz mono)\n",
|
|
"# -----------------------------\n",
|
|
"print(\"\\n=== FMA xsmall ===\")\n",
|
|
"fma_zip_dir = Path(\"fma\"); fma_zip_dir.mkdir(exist_ok=True)\n",
|
|
"fma_out = Path(\"fma_16k\"); fma_out.mkdir(exist_ok=True)\n",
|
|
"\n",
|
|
"zipname = \"fma_xs.zip\"\n",
|
|
"zipurl = f\"https://huggingface.co/datasets/mchl914/fma_xsmall/resolve/main/{zipname}\"\n",
|
|
"zipout = fma_zip_dir / zipname\n",
|
|
"if not zipout.exists():\n",
|
|
" os.system(f\"wget -q -O '{zipout}' '{zipurl}'\")\n",
|
|
" os.system(f\"cd fma && unzip -q '{zipname}'\")\n",
|
|
"\n",
|
|
"mp3s = list(Path(\"fma/fma_small\").rglob(\"*.mp3\"))\n",
|
|
"print(f\"🎵 FMA mp3 count: {len(mp3s)}\")\n",
|
|
"corrupt = []\n",
|
|
"for p in tqdm(mp3s, desc=\"FMA→16k WAV\"):\n",
|
|
" try:\n",
|
|
" y, sr = librosa.load(p, sr=16000, mono=True) # proper decode+resample\n",
|
|
" if y.size == 0:\n",
|
|
" raise ValueError(\"empty audio\")\n",
|
|
" write_wav(fma_out / (p.stem + \".wav\"), y, 16000)\n",
|
|
" except Exception as e:\n",
|
|
" corrupt.append(f\"{p}:{e}\")\n",
|
|
"if corrupt:\n",
|
|
" Path(\"fma_corrupted_files.log\").write_text(\"\\n\".join(corrupt))\n",
|
|
"print(\"\\n✅ Dataset prep complete!\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "XW3bmbI5-JAz"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Sets up the augmentations.\n",
|
|
"# To improve your model, experiment with these settings and use more sources of\n",
|
|
"# background clips.\n",
|
|
"\n",
|
|
"import os\n",
|
|
"from microwakeword.audio.augmentation import Augmentation\n",
|
|
"from microwakeword.audio.clips import Clips\n",
|
|
"from microwakeword.audio.spectrograms import SpectrogramGeneration\n",
|
|
"\n",
|
|
"def validate_directories(paths):\n",
|
|
" for path in paths:\n",
|
|
" if not os.path.exists(path):\n",
|
|
" print(f\"Error: Directory {path} does not exist. Please ensure preprocessing is complete.\")\n",
|
|
" return False\n",
|
|
" return True\n",
|
|
"\n",
|
|
"# Paths to augmentation source data (impulse responses and background audio)\n",
|
|
"impulse_paths = ['mit_rirs']\n",
|
|
"background_paths = ['fma_16k', 'audioset_16k']\n",
|
|
"\n",
|
|
"if not validate_directories(impulse_paths + background_paths):\n",
|
|
" raise ValueError(\"One or more required directories are missing.\")\n",
|
|
"\n",
|
|
"clips = Clips(\n",
|
|
" input_directory='./generated_samples',\n",
|
|
" file_pattern='*.wav',\n",
|
|
" max_clip_duration_s=5,\n",
|
|
" remove_silence=True,\n",
|
|
" random_split_seed=10,\n",
|
|
" split_count=0.1,\n",
|
|
")\n",
|
|
"\n",
|
|
"augmenter = Augmentation(\n",
|
|
" augmentation_duration_s=3.2,\n",
|
|
" augmentation_probabilities={\n",
|
|
" \"SevenBandParametricEQ\": 0.1,\n",
|
|
" \"TanhDistortion\": 0.05,\n",
|
|
" \"PitchShift\": 0.15,\n",
|
|
" \"BandStopFilter\": 0.1,\n",
|
|
" \"AddColorNoise\": 0.1,\n",
|
|
" \"AddBackgroundNoise\": 0.7,\n",
|
|
" \"Gain\": 0.8,\n",
|
|
" \"RIR\": 0.7,\n",
|
|
" },\n",
|
|
" impulse_paths=impulse_paths,\n",
|
|
" background_paths=background_paths,\n",
|
|
" background_min_snr_db=5,\n",
|
|
" background_max_snr_db=10,\n",
|
|
" min_jitter_s=0.2,\n",
|
|
" max_jitter_s=0.3,\n",
|
|
")\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "V5UsJfKKD1k9"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Augment a random generated-sample WAV and play it back (pass ndarray to augmenter)\n",
|
|
"from pathlib import Path\n",
|
|
"from IPython.display import Audio, display\n",
|
|
"import numpy as np\n",
|
|
"import soundfile as sf\n",
|
|
"import librosa, random, glob\n",
|
|
"\n",
|
|
"output_dir = Path(\"./augmented_clips\")\n",
|
|
"output_dir.mkdir(exist_ok=True)\n",
|
|
"\n",
|
|
"# 1) Pick a random WAV from the Piper outputs\n",
|
|
"candidates = glob.glob(\"generated_samples/*.wav\")\n",
|
|
"if not candidates:\n",
|
|
" raise SystemExit(\"No files in generated_samples/. Run the TTS sample cell first.\")\n",
|
|
"src_path = random.choice(candidates)\n",
|
|
"\n",
|
|
"# 2) Load as 16 kHz mono float32\n",
|
|
"y, sr = librosa.load(src_path, sr=16000, mono=True)\n",
|
|
"y = y.astype(np.float32, copy=False)\n",
|
|
"\n",
|
|
"# 3) Augment — microwakeword Augmentation expects a 1-D numpy array\n",
|
|
"try:\n",
|
|
" y_aug = augmenter.augment_clip(y)\n",
|
|
"except Exception as e:\n",
|
|
" # some versions accept (samples, sr) — try that as a fallback\n",
|
|
" try:\n",
|
|
" y_aug = augmenter.augment_clip((y, sr))\n",
|
|
" except Exception:\n",
|
|
" raise\n",
|
|
"\n",
|
|
"# 4) Save and play\n",
|
|
"out_path = output_dir / \"augmented_clip.wav\"\n",
|
|
"sf.write(str(out_path), y_aug.astype(np.float32, copy=False), sr, subtype=\"PCM_16\")\n",
|
|
"print(f\"Augmented clip saved to {out_path}\")\n",
|
|
"display(Audio(str(out_path), autoplay=True))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "D7BHcY1mEGbK"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Augment samples and save the training, validation, and testing sets.\n",
|
|
"# This version avoids datasets.Audio entirely by driving Clips from local WAVs.\n",
|
|
"\n",
|
|
"import os, glob, random\n",
|
|
"from pathlib import Path\n",
|
|
"import types\n",
|
|
"import numpy as np\n",
|
|
"import librosa\n",
|
|
"from mmap_ninja.ragged import RaggedMmap\n",
|
|
"from microwakeword.audio.spectrograms import SpectrogramGeneration\n",
|
|
"\n",
|
|
"# ---- Patch: drive clips from generated_samples/*.wav (no datasets.Audio, no torchcodec) ----\n",
|
|
"def audio_generator_from_wavs(self, split=\"train\", repeat=1):\n",
|
|
" \"\"\"\n",
|
|
" Yield 1-D float32 arrays loaded via librosa from generated_samples/*.wav.\n",
|
|
" Deterministic 80/10/10 split with seed 10 to mirror original Clips behavior.\n",
|
|
" \"\"\"\n",
|
|
" files = sorted(glob.glob(\"generated_samples/*.wav\"))\n",
|
|
" if not files:\n",
|
|
" raise SystemExit(\"❌ No WAVs in generated_samples/. Generate TTS samples first.\")\n",
|
|
"\n",
|
|
" rng = random.Random(10) # deterministic shuffling like Clips(random_split_seed=10)\n",
|
|
" files_shuf = files[:]\n",
|
|
" rng.shuffle(files_shuf)\n",
|
|
"\n",
|
|
" n = len(files_shuf)\n",
|
|
" n_val = max(1, int(0.10 * n))\n",
|
|
" n_test = max(1, int(0.10 * n))\n",
|
|
" n_train = max(0, n - n_val - n_test)\n",
|
|
" splits = {\n",
|
|
" \"train\": files_shuf[:n_train],\n",
|
|
" \"validation\": files_shuf[n_train:n_train + n_val],\n",
|
|
" \"test\": files_shuf[n_train + n_val:],\n",
|
|
" }\n",
|
|
" file_list = splits.get(split, [])\n",
|
|
" if not file_list:\n",
|
|
" return # nothing to yield\n",
|
|
"\n",
|
|
" for _ in range(max(1, int(repeat))):\n",
|
|
" for p in file_list:\n",
|
|
" y, sr = librosa.load(p, sr=16000, mono=True)\n",
|
|
" yield y.astype(np.float32, copy=False)\n",
|
|
"\n",
|
|
"# Bind the patched generator to your existing `clips` instance\n",
|
|
"clips.audio_generator = types.MethodType(audio_generator_from_wavs, clips)\n",
|
|
"print(\"✅ Patched clips.audio_generator to stream from generated_samples/*.wav (no torchcodec).\")\n",
|
|
"\n",
|
|
"# ---- Validate augmentation asset folders exist ----\n",
|
|
"def validate(paths):\n",
|
|
" for p in paths:\n",
|
|
" if not Path(p).exists():\n",
|
|
" raise SystemExit(f\"❌ Missing directory: {p}. Run dataset prep first.\")\n",
|
|
"\n",
|
|
"impulse_paths = [\"mit_rirs\"]\n",
|
|
"background_paths = [\"fma_16k\", \"audioset_16k\"]\n",
|
|
"validate(impulse_paths + background_paths)\n",
|
|
"\n",
|
|
"# ---- Output root ----\n",
|
|
"out_root = Path(\"generated_augmented_features\")\n",
|
|
"out_root.mkdir(exist_ok=True)\n",
|
|
"\n",
|
|
"# ---- Split config (same as before) ----\n",
|
|
"split_cfg = {\n",
|
|
" \"training\": {\"name\": \"train\", \"repetition\": 2, \"slide_frames\": 10},\n",
|
|
" \"validation\": {\"name\": \"validation\", \"repetition\": 1, \"slide_frames\": 10},\n",
|
|
" \"testing\": {\"name\": \"test\", \"repetition\": 1, \"slide_frames\": 1},\n",
|
|
"}\n",
|
|
"\n",
|
|
"# ---- Generate features ----\n",
|
|
"for split, cfg in split_cfg.items():\n",
|
|
" out_dir = out_root / split\n",
|
|
" out_dir.mkdir(parents=True, exist_ok=True)\n",
|
|
" print(f\"🧪 Processing {split} …\")\n",
|
|
"\n",
|
|
" spectros = SpectrogramGeneration(\n",
|
|
" clips=clips, # now backed by our WAV loader\n",
|
|
" augmenter=augmenter, # your existing augmenter\n",
|
|
" slide_frames=cfg[\"slide_frames\"],\n",
|
|
" step_ms=10,\n",
|
|
" )\n",
|
|
"\n",
|
|
" RaggedMmap.from_generator(\n",
|
|
" out_dir=str(out_dir / \"wakeword_mmap\"),\n",
|
|
" sample_generator=spectros.spectrogram_generator(\n",
|
|
" split=cfg[\"name\"], repeat=cfg[\"repetition\"]\n",
|
|
" ),\n",
|
|
" batch_size=100,\n",
|
|
" verbose=True,\n",
|
|
" )\n",
|
|
"\n",
|
|
"print(\"✅ Features ready (generated_augmented_features/*/wakeword_mmap)\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "1pGuJDPyp3ax"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Downloads pre-generated spectrogram features (made for microWakeWord in\n",
|
|
"# particular) for various negative datasets. This can be slow!\n",
|
|
"\n",
|
|
"import os\n",
|
|
"import requests\n",
|
|
"import zipfile\n",
|
|
"from pathlib import Path\n",
|
|
"from tqdm import tqdm\n",
|
|
"\n",
|
|
"# Function to download a file with progress bar\n",
|
|
"def download_file(url, output_path):\n",
|
|
" response = requests.get(url, stream=True)\n",
|
|
" total_size = int(response.headers.get('content-length', 0))\n",
|
|
" with open(output_path, \"wb\") as f, tqdm(\n",
|
|
" desc=f\"Downloading {output_path.name}\",\n",
|
|
" total=total_size,\n",
|
|
" unit=\"B\",\n",
|
|
" unit_scale=True,\n",
|
|
" unit_divisor=1024,\n",
|
|
" ) as bar:\n",
|
|
" for chunk in response.iter_content(chunk_size=1024):\n",
|
|
" f.write(chunk)\n",
|
|
" bar.update(len(chunk))\n",
|
|
" print(f\"Downloaded: {output_path}\")\n",
|
|
"\n",
|
|
"# Function to extract ZIP files\n",
|
|
"def extract_zip(zip_path, extract_to):\n",
|
|
" with zipfile.ZipFile(zip_path, 'r') as zip_ref:\n",
|
|
" zip_ref.extractall(extract_to)\n",
|
|
" print(f\"Extracted: {zip_path} to {extract_to}\")\n",
|
|
"\n",
|
|
"# Directory for negative datasets\n",
|
|
"output_dir = Path('./negative_datasets')\n",
|
|
"output_dir.mkdir(exist_ok=True)\n",
|
|
"\n",
|
|
"# Negative dataset URLs\n",
|
|
"link_root = \"https://huggingface.co/datasets/kahrendt/microwakeword/resolve/main/\"\n",
|
|
"filenames = ['dinner_party.zip', 'dinner_party_eval.zip', 'no_speech.zip', 'speech.zip']\n",
|
|
"\n",
|
|
"# Download and extract files\n",
|
|
"for fname in filenames:\n",
|
|
" link = link_root + fname\n",
|
|
" zip_path = output_dir / fname\n",
|
|
"\n",
|
|
" # Download only if the file doesn't already exist\n",
|
|
" if not zip_path.exists():\n",
|
|
" try:\n",
|
|
" download_file(link, zip_path)\n",
|
|
" except Exception as e:\n",
|
|
" print(f\"Error downloading {fname}: {e}\")\n",
|
|
" continue\n",
|
|
"\n",
|
|
" # Extract the ZIP file\n",
|
|
" try:\n",
|
|
" extract_zip(zip_path, output_dir)\n",
|
|
" except Exception as e:\n",
|
|
" print(f\"Error extracting {fname}: {e}\")\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "Ii1A14GsGVQT"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Save a yaml config that controls the training process\n",
|
|
"# These hyperparameters can make a huge difference in model quality.\n",
|
|
"# Experiment with sampling and penalty weights and increasing the number of\n",
|
|
"# training steps.\n",
|
|
"\n",
|
|
"import yaml\n",
|
|
"import os\n",
|
|
"\n",
|
|
"config = {}\n",
|
|
"\n",
|
|
"config[\"window_step_ms\"] = 10\n",
|
|
"\n",
|
|
"config[\"train_dir\"] = \"trained_models/wakeword\"\n",
|
|
"\n",
|
|
"config[\"features\"] = [\n",
|
|
" {\n",
|
|
" \"features_dir\": \"generated_augmented_features\",\n",
|
|
" \"sampling_weight\": 2.0, # Increased\n",
|
|
" \"penalty_weight\": 1.0,\n",
|
|
" \"truth\": True,\n",
|
|
" \"truncation_strategy\": \"truncate_start\",\n",
|
|
" \"type\": \"mmap\",\n",
|
|
" },\n",
|
|
" {\n",
|
|
" \"features_dir\": \"negative_datasets/speech\",\n",
|
|
" \"sampling_weight\": 12.0, # Adjusted\n",
|
|
" \"penalty_weight\": 1.0,\n",
|
|
" \"truth\": False,\n",
|
|
" \"truncation_strategy\": \"random\",\n",
|
|
" \"type\": \"mmap\",\n",
|
|
" },\n",
|
|
" {\n",
|
|
" \"features_dir\": \"negative_datasets/dinner_party\",\n",
|
|
" \"sampling_weight\": 12.0, # Adjusted\n",
|
|
" \"penalty_weight\": 1.0,\n",
|
|
" \"truth\": False,\n",
|
|
" \"truncation_strategy\": \"random\",\n",
|
|
" \"type\": \"mmap\",\n",
|
|
" },\n",
|
|
" {\n",
|
|
" \"features_dir\": \"negative_datasets/no_speech\",\n",
|
|
" \"sampling_weight\": 5.0, # Balanced\n",
|
|
" \"penalty_weight\": 1.0,\n",
|
|
" \"truth\": False,\n",
|
|
" \"truncation_strategy\": \"random\",\n",
|
|
" \"type\": \"mmap\",\n",
|
|
" },\n",
|
|
" {\n",
|
|
" \"features_dir\": \"negative_datasets/dinner_party_eval\",\n",
|
|
" \"sampling_weight\": 0.0,\n",
|
|
" \"penalty_weight\": 1.0,\n",
|
|
" \"truth\": False,\n",
|
|
" \"truncation_strategy\": \"split\",\n",
|
|
" \"type\": \"mmap\",\n",
|
|
" },\n",
|
|
"]\n",
|
|
"\n",
|
|
"config[\"training_steps\"] = [40000] # Increased\n",
|
|
"config[\"positive_class_weight\"] = [1]\n",
|
|
"config[\"negative_class_weight\"] = [20] # Adjusted\n",
|
|
"config[\"learning_rates\"] = [0.001] # Adjusted\n",
|
|
"config[\"batch_size\"] = 128\n",
|
|
"\n",
|
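|
"config[\"time_mask_max_size\"] = [0] # SpecAugment masking: all four mask sizes/counts are 0 here (presumably disabled, TODO confirm); raise them to experiment\n",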
|
|
"config[\"time_mask_count\"] = [0]\n",
|
|
"config[\"freq_mask_max_size\"] = [0]\n",
|
|
"config[\"freq_mask_count\"] = [0]\n",
|
|
"\n",
|
|
"config[\"eval_step_interval\"] = 500 # Adjusted\n",
|
|
"config[\"clip_duration_ms\"] = 1500 # Increased\n",
|
|
"\n",
|
|
"config[\"target_minimization\"] = 0.9\n",
|
|
"config[\"minimization_metric\"] = None # Updated\n",
|
|
"config[\"maximization_metric\"] = \"average_viable_recall\"\n",
|
|
"\n",
|
|
"with open(os.path.join(\"training_parameters.yaml\"), \"w\") as file:\n",
|
|
" documents = yaml.dump(config, file)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "WoEXJBaiC9mf"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Trains a model. When finished, it will quantize and convert the model to a\n",
|
|
"# streaming version suitable for on-device detection.\n",
|
|
"# It will resume if stopped, but it will start over at the configured training\n",
|
|
"# steps in the yaml file.\n",
|
|
"# Set --train 0 (instead of --train 1 below) to only convert and test the best-weighted model.\n",
|
|
"# On Google colab, it doesn't print the mini-batch results, so it may appear\n",
|
|
"# stuck for several minutes! Additionally, it is very slow compared to training\n",
|
|
"# on a local GPU.\n",
|
|
"\n",
|
|
"import os\n",
|
|
"import sys\n",
|
|
"\n",
|
|
"# Ensure the library path is correctly set\n",
|
|
"os.environ['LD_LIBRARY_PATH'] = \"/usr/lib/x86_64-linux-gnu:\" + os.environ.get('LD_LIBRARY_PATH', '')\n",
|
|
"\n",
|
|
"# Training command with optimized settings\n",
|
|
"!\"{sys.executable}\" -m microwakeword.model_train_eval \\\n",
|
|
"--training_config='training_parameters.yaml' \\\n",
|
|
"--train 1 \\\n",
|
|
"--restore_checkpoint 1 \\\n",
|
|
"--test_tf_nonstreaming 0 \\\n",
|
|
"--test_tflite_nonstreaming 0 \\\n",
|
|
"--test_tflite_nonstreaming_quantized 0 \\\n",
|
|
"--test_tflite_streaming 0 \\\n",
|
|
"--test_tflite_streaming_quantized 1 \\\n",
|
|
"--use_weights \"best_weights\" \\\n",
|
|
"mixednet \\\n",
|
|
"--pointwise_filters \"64,64,64,64\" \\\n",
|
|
"--repeat_in_block \"1,1,1,1\" \\\n",
|
|
"--mixconv_kernel_sizes '[5], [7,11], [9,15], [23]' \\\n",
|
|
"--residual_connection \"0,0,0,0\" \\\n",
|
|
"--first_conv_filters 32 \\\n",
|
|
"--first_conv_kernel_size 5 \\\n",
|
|
"--stride 2\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "ex_UIWvwtjAN"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"import shutil\n",
|
|
"import json\n",
|
|
"from IPython.display import FileLink\n",
|
|
"\n",
|
|
"# Define the source path and desired download location for the TFLite file\n",
|
|
"source_path = \"trained_models/wakeword/tflite_stream_state_internal_quant/stream_state_internal_quant.tflite\"\n",
|
|
"destination_path = \"./stream_state_internal_quant.tflite\"\n",
|
|
"\n",
|
|
"# Copy the TFLite file to the current working directory\n",
|
|
"shutil.copy(source_path, destination_path)\n",
|
|
"\n",
|
|
"# Define the JSON file content\n",
|
|
"json_data = {\n",
|
|
" \"type\": \"micro\",\n",
|
|
" \"wake_word\": \"hey_norman\", # NOTE(review): the sample-generation cells use hey_tater as the target word; update this value to match your wake word\n",
|
|
" \"author\": \"master phooey\",\n",
|
|
" \"website\": \"https://github.com/MasterPhooey/MicroWakeWord-Trainer-Docker\",\n",
|
|
" \"model\": \"stream_state_internal_quant.tflite\",\n",
|
|
" \"trained_languages\": [\"en\"],\n",
|
|
" \"version\": 2,\n",
|
|
" \"micro\": {\n",
|
|
" \"probability_cutoff\": 0.97,\n",
|
|
" \"sliding_window_size\": 5,\n",
|
|
" \"feature_step_size\": 10,\n",
|
|
" \"tensor_arena_size\": 30000,\n",
|
|
" \"minimum_esphome_version\": \"2024.7.0\"\n",
|
|
" }\n",
|
|
"}\n",
|
|
"\n",
|
|
"# Define the JSON file path\n",
|
|
"json_path = \"./stream_state_internal_quant.json\"\n",
|
|
"\n",
|
|
"# Write the JSON file\n",
|
|
"with open(json_path, \"w\") as json_file:\n",
|
|
" json.dump(json_data, json_file, indent=2)\n",
|
|
"\n",
|
|
"# Generate download links for both files\n",
|
|
"print(\"Download your files:\")\n",
|
|
"print(\"TFLite Model:\")\n",
|
|
"display(FileLink(destination_path))\n",
|
|
"print(\"\\nJSON Metadata:\")\n",
|
|
"display(FileLink(json_path))"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"accelerator": "GPU",
|
|
"colab": {
|
|
"gpuType": "T4",
|
|
"provenance": []
|
|
},
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.10.12"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 4
|
|
}
|