{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# 🥔 MicroWakeWord Trainer — Tater Totterson Edition\n",
"# ==================================================\n",
"# Welcome, friend! 👋 This notebook will help you train your very own wake word model.\n",
"# Think of it like teaching Tater Totterson to recognize when you say a special word.\n",
"#\n",
"# By the end, you'll have:\n",
"# ✅ A trained TensorFlow Lite model ready for on-device detection.\n",
"# ✅ A matching JSON manifest you can drop straight into ESPHome.\n",
"#\n",
"# This flow is optimized for Python 3.10 and NVIDIA GPUs (but should work elsewhere too).\n",
"# You can customize the wake word, play with training parameters, and experiment with\n",
"# different datasets until you get something that feels just right. 💪\n",
"#\n",
"# ⚡ Quick Tips:\n",
"# • Change TARGET_WORD below to whatever you want your wake word to be.\n",
"# • Rerun the notebook from the top if you change it (to regenerate everything).\n",
"# • Expect to experiment — tweaking hyperparameters is part of the fun!\n",
"#\n",
"# When you’re done, you’ll get two files:\n",
"# 1️⃣ .tflite — your trained model.\n",
"# 2️⃣ .json — a manifest for ESPHome integration.\n",
"#\n",
"# More info & examples:\n",
"# 🔗 https://github.com/TaterTotterson/microWakeWord-Trainer-Nvidia-Docker\n",
"\n",
"# --- Set your wake word here ---\n",
"TARGET_WORD = \"tater\" # 🗣️ Change this to whatever phrase you want!\n",
"print(f\"🥔 Tater Totterson is listening for: '{TARGET_WORD}'\")"
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "BFf6511E65ff" }, "outputs": [], "source": [
"import platform\n",
"import sys\n",
"import os\n",
"\n",
"# mac-only helper deps\n",
"if platform.system() == \"Darwin\":\n",
"    !\"{sys.executable}\" -m pip install 'git+https://github.com/puddly/pymicro-features@puddly/minimum-cpp-version' --root-user-action=ignore\n",
"\n",
"!\"{sys.executable}\" -m pip install 'git+https://github.com/whatsnowplaying/audio-metadata@d4ebb238e6a401bb1a5aaaac60c9e2b3cb30929f' --root-user-action=ignore\n",
"\n",
"# 👇 use the actual location in the container\n",
"repo_path = \"/data/microWakeWord\"\n",
"\n",
"if not os.path.exists(repo_path):\n",
"    print(\"⬇️ Cloning microWakeWord repository to /data…\")\n",
"    !git clone https://github.com/TaterTotterson/micro-wake-word.git {repo_path}\n",
"\n",
"# optional: pin to a commit\n",
"# !cd /data/microWakeWord && git checkout ac6502bf48b5e372c47ed509f5f5ca181e6d50bb\n",
"\n",
"if os.path.exists(repo_path):\n",
"    print(\"📦 Installing microWakeWord...\")\n",
"    !\"{sys.executable}\" -m pip install -e {repo_path} --root-user-action=ignore\n",
"else:\n",
"    print(f\"❌ Repository not found at {repo_path}. Clone might have failed.\")"
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "GpuChk7a2Q9x" }, "outputs": [], "source": [
"# --- GPU Check (Torch + ONNX Runtime) ---\n",
"\n",
"import torch\n",
"import onnxruntime as ort\n",
"\n",
"print(\"🔧 Torch CUDA Available:\", torch.cuda.is_available())\n",
"if torch.cuda.is_available():\n",
"    print(\" • Device count:\", torch.cuda.device_count())\n",
"    print(\" • Current device:\", torch.cuda.current_device())\n",
"    print(\" • Device name:\", torch.cuda.get_device_name(torch.cuda.current_device()))\n",
"else:\n",
"    print(\"⚠️ Torch cannot see a GPU — check Docker runtime (--gpus all) and nvidia-container-toolkit\")\n",
"\n",
"print(\"\\n🔧 ONNX Runtime Providers:\")\n",
"try:\n",
"    providers = ort.get_available_providers()\n",
"    print(\" •\", providers)\n",
"    if \"CUDAExecutionProvider\" not in providers:\n",
"        print(\"⚠️ CUDAExecutionProvider not available — ONNX will fall back to CPU.\")\n",
"except Exception as e:\n",
"    print(\"⚠️ Could not query ONNX Runtime providers:\", e)\n"
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "dEluu7nL7ywd" }, "outputs": [], "source": [
"# NVIDIA Linux Docker: generate 1 
sample of the target word (robust + CUDA check)\n",
"\n",
"import os, sys, shutil, subprocess, time, platform\n",
"from pathlib import Path\n",
"from IPython.display import Audio, display\n",
"\n",
"REPO_URL = \"https://github.com/rhasspy/piper-sample-generator\"\n",
"REPO_DIR = Path.cwd() / \"piper-sample-generator\"\n",
"MODELS_DIR = REPO_DIR / \"models\"\n",
"MODEL_NAME = \"en_US-libritts_r-medium.pt\"\n",
"MODEL_URL = f\"https://github.com/rhasspy/piper-sample-generator/releases/download/v2.0.0/{MODEL_NAME}\"\n",
"AUDIO_OUT_DIR = Path.cwd() / \"generated_samples\"\n",
"AUDIO_PATH = AUDIO_OUT_DIR / \"0.wav\"\n",
"MIN_MODEL_BYTES = 100 * 1024 # anything smaller is treated as a failed/partial download\n",
"\n",
"def run(cmd, check=True):\n",
"    \"\"\"Run `cmd`, streaming its combined stdout/stderr; raise RuntimeError on non-zero exit if `check`.\"\"\"\n",
"    print(\"→\", \" \".join(cmd))\n",
"    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)\n",
"    for line in proc.stdout:\n",
"        print(line, end=\"\")\n",
"    rc = proc.wait()\n",
"    if check and rc != 0:\n",
"        raise RuntimeError(f\"Command failed with exit code {rc}: {' '.join(cmd)}\")\n",
"    return rc\n",
"\n",
"def pip_install(*pkgs):\n",
"    \"\"\"Upgrade pip (best effort), then install `pkgs` into this kernel's interpreter.\"\"\"\n",
"    run([sys.executable, \"-m\", \"pip\", \"install\", \"--upgrade\", \"pip\"], check=False)\n",
"    run([sys.executable, \"-m\", \"pip\", \"install\", *pkgs])\n",
"\n",
"def safe_clone(repo_url, branch=None, dest=REPO_DIR, retries=2):\n",
"    \"\"\"Shallow-clone `repo_url` into `dest` unless a git checkout is already there.\n",
"\n",
"    A non-git leftover directory is removed first, and a failed attempt is cleaned up\n",
"    before retrying (git refuses to clone into a non-empty directory).\n",
"    \"\"\"\n",
"    if dest.exists() and not (dest / \".git\").exists():\n",
"        print(\"⚠️ Found partial clone. Removing…\")\n",
"        shutil.rmtree(dest, ignore_errors=True)\n",
"    if not dest.exists():\n",
"        for i in range(retries + 1):\n",
"            try:\n",
"                cmd = [\"git\", \"clone\", \"--depth\", \"1\", repo_url, str(dest)]\n",
"                if branch:\n",
"                    cmd = [\"git\", \"clone\", \"--depth\", \"1\", \"--branch\", branch, repo_url, str(dest)]\n",
"                run(cmd)\n",
"                break\n",
"            except Exception as e:\n",
"                if i == retries:\n",
"                    raise\n",
"                print(f\"Clone failed ({i+1}/{retries+1}). Retrying in 2s… [{e}]\")\n",
"                shutil.rmtree(dest, ignore_errors=True) # drop any half-written checkout\n",
"                time.sleep(2)\n",
"\n",
"def ensure_model():\n",
"    \"\"\"Download the Piper generator checkpoint unless a plausibly complete copy already exists.\n",
"\n",
"    The download goes to a `.part` file that is renamed into place only after it passes the\n",
"    size check, so a failed or interrupted download is retried on the next run instead of\n",
"    leaving a truncated file behind that later runs would skip.\n",
"    \"\"\"\n",
"    MODELS_DIR.mkdir(parents=True, exist_ok=True)\n",
"    mp = MODELS_DIR / MODEL_NAME\n",
"    if mp.exists() and mp.stat().st_size < MIN_MODEL_BYTES:\n",
"        print(f\"⚠️ {mp} looks truncated; re-downloading.\")\n",
"        mp.unlink()\n",
"    if not mp.exists():\n",
"        import urllib.request\n",
"        tmp = mp.with_name(mp.name + \".part\")\n",
"        print(f\"Downloading model to {mp} …\")\n",
"        try:\n",
"            with urllib.request.urlopen(MODEL_URL, timeout=60) as r, open(tmp, \"wb\") as f:\n",
"                shutil.copyfileobj(r, f)\n",
"            if tmp.stat().st_size < MIN_MODEL_BYTES:\n",
"                raise RuntimeError(\"Downloaded model looks too small; download may have failed.\")\n",
"            tmp.replace(mp)\n",
"        finally:\n",
"            tmp.unlink(missing_ok=True) # no-op after a successful rename\n",
"    print(f\"✅ Model ready: {mp}\")\n",
"\n",
"# 1) Clone main repo (Linux/NVIDIA)\n",
"print(\"Linux/NVIDIA detected — using main piper-sample-generator repo.\")\n",
"safe_clone(REPO_URL)\n",
"\n",
"# 2) Install deps\n",
"# - piper-tts provides the `piper` module (required by generate_samples.py)\n",
"# - piper-phonemize-cross does the phonemization\n",
"# - onnxruntime-gpu enables CUDA (container must have NVIDIA runtime)\n",
"deps = [\n",
"    \"piper-tts>=1.2.0\",\n",
"    \"piper-phonemize-cross==1.2.1\",\n",
"    \"soundfile\",\n",
"    \"numpy\",\n",
"    \"onnxruntime-gpu>=1.16.0\",\n",
"]\n",
"pip_install(*deps)\n",
"\n",
"# 3) Verify CUDA provider is available\n",
"try:\n",
"    import onnxruntime as ort\n",
"    providers = ort.get_available_providers()\n",
"    print(f\"ONNX Runtime providers: {providers}\")\n",
"    if \"CUDAExecutionProvider\" not in providers:\n",
"        print(\"⚠️ CUDAExecutionProvider not available. \"\n",
"              \"The sample will still run on CPU, but check your NVIDIA container setup \"\n",
"              \"(nvidia-container-toolkit, runtime, and driver).\")\n",
"except Exception as e:\n",
"    print(\"⚠️ Could not import onnxruntime to verify providers:\", e)\n",
"\n",
"# 4) Ensure model present\n",
"ensure_model()\n",
"\n",
"# 5) Generate one sample\n",
"AUDIO_OUT_DIR.mkdir(parents=True, exist_ok=True)\n",
"gen_script = REPO_DIR / \"generate_samples.py\"\n",
"if not gen_script.exists():\n",
"    raise FileNotFoundError(f\"Missing generator: {gen_script}\")\n",
"\n",
"cmd = [\n",
"    sys.executable, str(gen_script),\n",
"    TARGET_WORD,\n",
"    \"--model\", str(MODELS_DIR / MODEL_NAME), # ← pass the generator .pt explicitly\n",
"    \"--max-samples\", \"1\",\n",
"    \"--batch-size\", \"1\",\n",
"    \"--output-dir\", str(AUDIO_OUT_DIR),\n",
"]\n",
"run(cmd)\n",
"\n",
"# 6) Play the audio (if the notebook frontend supports it)\n",
"if AUDIO_PATH.exists():\n",
"    print(f\"🎧 Playing {AUDIO_PATH}\")\n",
"    display(Audio(str(AUDIO_PATH), autoplay=True))\n",
"else:\n",
"    print(f\"Audio file not found at {AUDIO_PATH}\")"
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-SvGtCCM9akR" }, "outputs": [], "source": [
"# Generate a large number of wake word samples for training (with length-scale sweep)\n",
"import sys, subprocess\n",
"from pathlib import Path\n",
"\n",
"REPO_DIR = Path.cwd() / \"piper-sample-generator\"\n",
"MODELS_DIR = REPO_DIR / \"models\"\n",
"MODEL_NAME = \"en_US-libritts_r-medium.pt\"\n",
"\n",
"MAX_SAMPLES = 50000\n",
"BATCH_SIZE = 100\n",
"\n",
"# Piper \"speed\" control via piper-sample-generator is length_scale(s)\n",
"LENGTH_SCALES = [\"0.85\", \"0.95\", \"1.00\", \"1.05\", \"1.15\"]\n",
"\n",
"cmd = [\n",
"    sys.executable,\n",
"    str(REPO_DIR / \"generate_samples.py\"),\n",
"    TARGET_WORD,\n",
"    \"--model\", str(MODELS_DIR / MODEL_NAME),\n",
"    \"--max-samples\", str(MAX_SAMPLES),\n",
"    \"--batch-size\", str(BATCH_SIZE),\n",
"    \"--output-dir\", \"generated_samples\",\n",
"    \"--length-scales\", *LENGTH_SCALES,\n",
"]\n",
"\n",
"print(\"→\", \" \".join(cmd))\n",
"subprocess.run(cmd, check=True)"
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "YJRG4Qvo9nXG" }, "outputs": [], "source": [
"# NVIDIA/Linux dataset prep to match the Apple behavior, but with pinned AudioSet\n",
"# MIT RIR -> resample to 16 kHz\n",
"# AudioSet -> fetch from a working HF revision, convert to 16 kHz mono, skip bad\n",
"# FMA -> resample to 16 kHz mono\n",
"\n",
"import os, sys, subprocess, scipy.io.wavfile, numpy as np\n",
"from pathlib import Path\n",
"from tqdm import tqdm\n",
"import soundfile as sf\n",
"import librosa\n",
"from datasets import load_dataset\n",
"\n",
"# -------------------------------------------------\n",
"# small shell helpers (for curl/tar probing)\n",
"# -------------------------------------------------\n",
"def sh(cmd: str) -> int:\n",
"    \"\"\"Run a shell command line and return its exit status.\"\"\"\n",
"    return subprocess.call(cmd, shell=True)\n",
"\n",
"def curl(url: str, out: Path) -> int:\n",
"    \"\"\"Download `url` to `out` with curl; returns curl's exit status (non-zero on HTTP errors).\"\"\"\n",
"    # -L follow, -s silent, --fail to get nonzero on 404\n",
"    return subprocess.call(f\"curl -L -s --fail '{url}' -o '{out}'\", shell=True)\n",
"\n",
"def write_wav(dst: Path, data: np.ndarray, sr: int):\n",
"    \"\"\"Write float audio to `dst` as 16-bit PCM, clipping samples to [-1, 1] first.\"\"\"\n",
"    x = np.clip(data, -1.0, 1.0)\n",
"    scipy.io.wavfile.write(dst, sr, (x * 32767).astype(np.int16))\n",
"\n",
"# -----------------------------\n",
"# MIT RIR (resample to 16 kHz)\n",
"# -----------------------------\n",
"print(\"=== MIT RIR ===\")\n",
"rir_out = Path(\"mit_rirs\")\n",
"rir_out.mkdir(exist_ok=True)\n",
"if not any(rir_out.rglob(\"*.wav\")):\n",
"    ok = 0\n",
"    try:\n",
"        # Avoid datasets.Audio to keep TorchCodec out:\n",
"        # Use streaming=True + manual decode with librosa\n",
"        print(\"⬇️ MIT RIR (streaming + manual decode)…\")\n",
"        ds = load_dataset(\n",
"            \"davidscripka/MIT_environmental_impulse_responses\",\n",
"            split=\"train\",\n",
"            streaming=True\n",
"        )\n",
"        for i, row in enumerate(tqdm(ds)):\n",
"            try:\n",
"                audio_path = 
row[\"audio\"][\"path\"]\n",
"                y, sr = librosa.load(audio_path, sr=16000, mono=True)\n",
"                write_wav(rir_out / f\"rir_{i:04d}.wav\", y, 16000)\n",
"                ok += 1\n",
"            except Exception:\n",
"                pass\n",
"        if ok == 0:\n",
"            # Every streamed row failed to decode; raise so the ZIP fallback below runs\n",
"            # instead of reporting success with an empty mit_rirs/ directory.\n",
"            raise RuntimeError(\"no impulse responses could be decoded from the streamed dataset\")\n",
"        print(f\"✅ MIT RIR saved: {ok} files\")\n",
"    except Exception as e:\n",
"        print(f\"⚠️ MIT RIR download failed: {e}\")\n",
"        # Fallback ZIP route\n",
"        try:\n",
"            print(\"⬇️ MIT RIR (fallback ZIP)…\")\n",
"            zip_url = \"https://mcdermottlab.mit.edu/Reverb/IRMAudio/Audio.zip\"\n",
"            zip_path = rir_out.parent / \"MIT_RIR_Audio.zip\"\n",
"            if not zip_path.exists():\n",
"                if os.system(f\"wget -q -O '{zip_path}' '{zip_url}'\") != 0:\n",
"                    # wget -O leaves an empty/partial file behind on failure; remove it so\n",
"                    # the next run retries the download instead of reusing a broken archive.\n",
"                    zip_path.unlink(missing_ok=True)\n",
"                    raise RuntimeError(f\"wget failed for {zip_url}\")\n",
"            os.system(f'unzip -q -o \"{zip_path}\" -d \"{rir_out}\"')\n",
"            if not any(rir_out.rglob(\"*.wav\")):\n",
"                zip_path.unlink(missing_ok=True) # unusable archive: force a re-download next run\n",
"                raise RuntimeError(f\"no .wav files extracted from {zip_path}\")\n",
"            # Normalize to 16k mono\n",
"            for p in tqdm(list(rir_out.rglob(\"*.wav\")), desc=\"Normalize MIT RIR\"):\n",
"                a, sr = sf.read(p, always_2d=False)\n",
"                if a.ndim > 1:\n",
"                    a = a[:, 0]\n",
"                if sr != 16000:\n",
"                    a, _ = librosa.load(p, sr=16000, mono=True)\n",
"                write_wav(p, a, 16000)\n",
"            print(\"✅ MIT RIR fallback complete\")\n",
"        except Exception as e2:\n",
"            print(f\"❌ MIT RIR fallback failed: {e2}\")\n",
"else:\n",
"    print(\"✅ mit_rirs exists; skipping.\")\n",
"\n",
"# ============================================================\n",
"# AudioSet (pinned FLAC .tar → 16k mono, skip bad files)\n",
"# ============================================================\n",
"print(\"\\n=== AudioSet subset (pinned FLAC .tar → 16k mono) ===\")\n",
"audioset_dir = Path(\"audioset\"); audioset_dir.mkdir(exist_ok=True)\n",
"audioset_out = Path(\"audioset_16k\"); audioset_out.mkdir(exist_ok=True)\n",
"\n",
"if any(audioset_out.rglob(\"*.wav\")):\n",
"    print(\"✅ audioset_16k exists; skipping.\")\n",
"else:\n",
"    # commits / refs we know about — we’ll probe them\n",
"    REV_CANDIDATES = [\n",
"        \"6762f044d1c88619c7f2006486036192128fb07e\",\n",
"        \"0049167e89f259a010c3f070fe3666d9e5242836\",\n",
"        \"ceb9eaaa7844c9ad7351e659c84a572e376ad06d\",\n",
"        \"main\", # last resort\n",
"    ]\n",
"    # possible folder layouts\n",
"    TAR_PATTERNS = [\n",
"        \"data/bal_train0{idx}.tar\",\n",
"        \"data/bal_train/bal_train0{idx}.tar\",\n",
"    ]\n",
"\n",
"    def find_working_rev():\n",
"        \"\"\"Return the first (revision, tar pattern) whose shard 0 answers a HEAD probe, else (None, None).\"\"\"\n",
"        for rev in REV_CANDIDATES:\n",
"            for pat in TAR_PATTERNS:\n",
"                probe = f\"https://huggingface.co/datasets/agkphysics/AudioSet/resolve/{rev}/{pat.format(idx=0)}\"\n",
"                rc = sh(f\"curl -I -L --fail -s '{probe}' > /dev/null\")\n",
"                if rc == 0:\n",
"                    return rev, pat\n",
"        return None, None\n",
"\n",
"    rev, pattern = find_working_rev()\n",
"    if rev is None:\n",
"        raise RuntimeError(\"Could not locate an AudioSet revision with FLAC tarballs still present on HF.\")\n",
"\n",
"    print(f\"📌 Using AudioSet revision: {rev}\")\n",
"    print(f\"🗂️ Tar layout pattern: {pattern}\")\n",
"\n",
"    # download + extract bal_train00..09\n",
"    for i in range(10):\n",
"        rel = pattern.format(idx=i)\n",
"        url = f\"https://huggingface.co/datasets/agkphysics/AudioSet/resolve/{rev}/{rel}\"\n",
"        fname = rel.split(\"/\")[-1]\n",
"        out_tar = audioset_dir / fname\n",
"        if not out_tar.exists():\n",
"            print(f\"⬇️ {fname}\")\n",
"            rc = curl(url, out_tar)\n",
"            if rc != 0:\n",
"                out_tar.unlink(missing_ok=True) # don't let a partial tar pass the exists() check next run\n",
"                print(f\"⚠️ Could not fetch {fname} at rev {rev}; continuing.\")\n",
"                continue\n",
"        print(f\"📦 Extract {fname}\")\n",
"        rc = sh(f\"tar -xf '{out_tar}' -C '{audioset_dir}'\")\n",
"        if rc != 0:\n",
"            print(f\"⚠️ tar extract failed for {fname}; continuing.\")\n",
"\n",
"    # convert FLAC → 16k mono WAV\n",
"    flacs = list(audioset_dir.rglob(\"*.flac\"))\n",
"    print(f\"🔎 FLAC files: {len(flacs)}\")\n",
"    audioset_bad = []\n",
"    ok = 0\n",
"    for p in tqdm(flacs, desc=\"AudioSet→WAV (resample 16k mono)\"):\n",
"        try:\n",
"            y, _ = librosa.load(p, sr=16000, mono=True)\n",
"            if y.size == 0:\n",
"                raise ValueError(\"empty audio\")\n",
"            write_wav(audioset_out / (p.stem + \".wav\"), y, 16000)\n",
"            ok += 1\n",
"        except Exception as e:\n",
"            audioset_bad.append(f\"{p}:{e}\")\n",
"\n",
"    if audioset_bad:\n",
"        (audioset_out / \"audioset_corrupted_files.log\").write_text(\"\\n\".join(audioset_bad))\n",
"    print(f\"✅ AudioSet complete ({ok} ok, {len(audioset_bad)} failed)\")\n",
"\n",
"# -----------------------------\n",
"# FMA xsmall (resample to 16 kHz mono)\n",
"# -----------------------------\n",
"print(\"\\n=== FMA xsmall ===\")\n",
"fma_zip_dir = Path(\"fma\"); fma_zip_dir.mkdir(exist_ok=True)\n",
"fma_out = Path(\"fma_16k\"); fma_out.mkdir(exist_ok=True)\n",
"\n",
"zipname = \"fma_xs.zip\"\n",
"zipurl = f\"https://huggingface.co/datasets/mchl914/fma_xsmall/resolve/main/{zipname}\"\n",
"zipout = fma_zip_dir / zipname\n",
"fma_src = fma_zip_dir / \"fma_small\" # extracted mp3 folder\n",
"if not zipout.exists():\n",
"    if os.system(f\"wget -q -O '{zipout}' '{zipurl}'\") != 0:\n",
"        # wget -O leaves an empty/partial file on failure; remove it so the next run retries.\n",
"        zipout.unlink(missing_ok=True)\n",
"        print(f\"⚠️ FMA download failed: {zipurl}\")\n",
"if zipout.exists() and not fma_src.exists():\n",
"    # Also recovers a previous run whose download succeeded but whose unzip did not.\n",
"    os.system(f\"unzip -q -o '{zipout}' -d '{fma_zip_dir}'\")\n",
"\n",
"mp3s = list(fma_src.rglob(\"*.mp3\"))\n",
"print(f\"🎵 FMA mp3 count: {len(mp3s)}\")\n",
"corrupt = []\n",
"for p in tqdm(mp3s, desc=\"FMA→16k WAV\"):\n",
"    try:\n",
"        y, sr = librosa.load(p, sr=16000, mono=True)\n",
"        if y.size == 0:\n",
"            raise ValueError(\"empty audio\")\n",
"        write_wav(fma_out / (p.stem + \".wav\"), y, 16000)\n",
"    except Exception as e:\n",
"        corrupt.append(f\"{p}:{e}\")\n",
"if corrupt:\n",
"    Path(\"fma_corrupted_files.log\").write_text(\"\\n\".join(corrupt))\n",
"\n",
"print(\"\\n✅ Dataset prep complete!\")"
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "XW3bmbI5-JAz" }, "outputs": [], "source": [
"# Sets up the augmentations.\n",
"# To improve your model, experiment with these settings and use more sources of\n",
"# background clips.\n",
"import sys, os\n",
"\n",
"# try the common places we’ve used\n",
"candidates = [\n",
"    \"/data/microWakeWord\", # what the last install log showed\n",
"    \"/data/microwakeword\", # lowercase variant\n",
"    \"./microwakeword\", # local clone\n",
"    \"./microWakeWord\", # camel case\n",
"]\n",
"\n",
"for base in candidates:\n",
"    if os.path.isdir(base):\n",
"        # add the repo root\n",
"        sys.path.insert(0, base)\n",
"        # add the 
actual package dir inside the repo\n",
"        if os.path.isdir(os.path.join(base, \"microwakeword\")):\n",
"            sys.path.insert(0, os.path.join(base, \"microwakeword\"))\n",
"        break\n",
"from microwakeword.audio.augmentation import Augmentation\n",
"from microwakeword.audio.clips import Clips\n",
"from microwakeword.audio.spectrograms import SpectrogramGeneration\n",
"\n",
"def validate_directories(paths):\n",
"    \"\"\"Return True only if every path exists and contains at least one .wav file.\n",
"\n",
"    The dataset-prep cell creates these folders up front, so an existence check alone\n",
"    would pass even when every download failed and the folder is empty.\n",
"    \"\"\"\n",
"    for path in paths:\n",
"        if not os.path.exists(path):\n",
"            print(f\"Error: Directory {path} does not exist. Please ensure preprocessing is complete.\")\n",
"            return False\n",
"        has_wav = any(\n",
"            name.lower().endswith(\".wav\")\n",
"            for _root, _dirs, names in os.walk(path)\n",
"            for name in names\n",
"        )\n",
"        if not has_wav:\n",
"            print(f\"Error: Directory {path} contains no .wav files. Please re-run the dataset prep cell.\")\n",
"            return False\n",
"    return True\n",
"\n",
"# Paths to augmented data\n",
"impulse_paths = ['mit_rirs']\n",
"background_paths = ['fma_16k', 'audioset_16k']\n",
"\n",
"if not validate_directories(impulse_paths + background_paths):\n",
"    raise ValueError(\"One or more required directories are missing.\")\n",
"\n",
"clips = Clips(\n",
"    input_directory='./generated_samples',\n",
"    file_pattern='*.wav',\n",
"    max_clip_duration_s=5,\n",
"    remove_silence=True,\n",
"    random_split_seed=10,\n",
"    split_count=0.1,\n",
")\n",
"\n",
"augmenter = Augmentation(\n",
"    augmentation_duration_s=3.2,\n",
"    augmentation_probabilities={\n",
"        \"SevenBandParametricEQ\": 0.1,\n",
"        \"TanhDistortion\": 0.05,\n",
"        \"PitchShift\": 0.15,\n",
"        \"BandStopFilter\": 0.1,\n",
"        \"AddColorNoise\": 0.1,\n",
"        \"AddBackgroundNoise\": 0.7,\n",
"        \"Gain\": 0.8,\n",
"        \"RIR\": 0.7,\n",
"    },\n",
"    impulse_paths=impulse_paths,\n",
"    background_paths=background_paths,\n",
"    background_min_snr_db=5,\n",
"    background_max_snr_db=10,\n",
"    min_jitter_s=0.2,\n",
"    max_jitter_s=0.3,\n",
")\n"
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "V5UsJfKKD1k9" }, "outputs": [], "source": [
"# Augment a random generated-sample WAV and play it back (pass ndarray to augmenter)\n",
"from pathlib import Path\n",
"from IPython.display import Audio, display\n",
"import numpy as np\n",
"import soundfile as sf\n",
"import librosa, random, glob\n",
"\n",
"output_dir = Path(\"./augmented_clips\")\n",
"output_dir.mkdir(exist_ok=True)\n",
"\n",
"# 1) Pick a random WAV from the Piper outputs\n",
"candidates = glob.glob(\"generated_samples/*.wav\")\n",
"if not candidates:\n",
"    raise SystemExit(\"No files in generated_samples/. Run the TTS sample cell first.\")\n",
"src_path = random.choice(candidates)\n",
"\n",
"# 2) Load as 16 kHz mono float32\n",
"y, sr = librosa.load(src_path, sr=16000, mono=True)\n",
"y = y.astype(np.float32, copy=False)\n",
"\n",
"# 3) Augment — microwakeword Augmentation expects a 1-D numpy array\n",
"try:\n",
"    y_aug = augmenter.augment_clip(y)\n",
"except Exception:\n",
"    # some versions accept (samples, sr) — try that as a fallback\n",
"    # (if this also fails, Python chains the first error as __context__)\n",
"    y_aug = augmenter.augment_clip((y, sr))\n",
"\n",
"# 4) Save and play\n",
"# Clip to [-1, 1] first: gain/noise augmentations can push samples past full scale, and\n",
"# libsndfile may not clip float -> PCM_16 conversion unless clipping is enabled.\n",
"y_aug = np.clip(np.asarray(y_aug, dtype=np.float32), -1.0, 1.0)\n",
"out_path = output_dir / \"augmented_clip.wav\"\n",
"sf.write(str(out_path), y_aug, sr, subtype=\"PCM_16\")\n",
"print(f\"Augmented clip saved to {out_path}\")\n",
"display(Audio(str(out_path), autoplay=True))"
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "D7BHcY1mEGbK" }, "outputs": [], "source": [
"# Augment samples and save the training, validation, and testing sets.\n",
"# This version avoids datasets.Audio entirely by driving Clips from local WAVs.\n",
"\n",
"import os, glob, random\n",
"from pathlib import Path\n",
"import types\n",
"import numpy as np\n",
"import librosa\n",
"from mmap_ninja.ragged import RaggedMmap\n",
"from microwakeword.audio.spectrograms import SpectrogramGeneration\n",
"\n",
"# ---- Patch: drive clips from generated_samples/*.wav (no datasets.Audio, no torchcodec) ----\n",
"def audio_generator_from_wavs(self, split=\"train\", repeat=1):\n",
"    \"\"\"\n",
"    Yield 1-D float32 arrays loaded via librosa from generated_samples/*.wav.\n",
"    Deterministic 80/10/10 split with seed 10 to mirror original Clips behavior.\n",
"    \"\"\"\n",
"    files = sorted(glob.glob(\"generated_samples/*.wav\"))\n",
"    if not files:\n",
"        raise SystemExit(\"❌ No WAVs in generated_samples/. Generate TTS samples first.\")\n",
"\n",
"    rng = random.Random(10) # deterministic shuffling like Clips(random_split_seed=10)\n",
"    files_shuf = files[:]\n",
"    rng.shuffle(files_shuf)\n",
"\n",
"    n = len(files_shuf)\n",
"    n_val = max(1, int(0.10 * n))\n",
"    n_test = max(1, int(0.10 * n))\n",
"    n_train = max(0, n - n_val - n_test)\n",
"    splits = {\n",
"        \"train\": files_shuf[:n_train],\n",
"        \"validation\": files_shuf[n_train:n_train + n_val],\n",
"        \"test\": files_shuf[n_train + n_val:],\n",
"    }\n",
"    file_list = splits.get(split, [])\n",
"    if not file_list:\n",
"        return # nothing to yield\n",
"\n",
"    for _ in range(max(1, int(repeat))):\n",
"        for p in file_list:\n",
"            y, sr = librosa.load(p, sr=16000, mono=True)\n",
"            yield y.astype(np.float32, copy=False)\n",
"\n",
"# Bind the patched generator to your existing `clips` instance\n",
"clips.audio_generator = types.MethodType(audio_generator_from_wavs, clips)\n",
"print(\"✅ Patched clips.audio_generator to stream from generated_samples/*.wav (no torchcodec).\")\n",
"\n",
"# ---- Validate augmentation asset folders exist ----\n",
"def validate(paths):\n",
"    \"\"\"Stop the cell with a clear message if any required asset directory is missing.\"\"\"\n",
"    for p in paths:\n",
"        if not Path(p).exists():\n",
"            raise SystemExit(f\"❌ Missing directory: {p}. Run dataset prep first.\")\n",
"\n",
"impulse_paths = [\"mit_rirs\"]\n",
"background_paths = [\"fma_16k\", \"audioset_16k\"]\n",
"validate(impulse_paths + background_paths)\n",
"\n",
"# ---- Output root ----\n",
"out_root = Path(\"generated_augmented_features\")\n",
"out_root.mkdir(exist_ok=True)\n",
"\n",
"# ---- Split config (same as before) ----\n",
"split_cfg = {\n",
"    \"training\": {\"name\": \"train\", \"repetition\": 2, \"slide_frames\": 10},\n",
"    \"validation\": {\"name\": \"validation\", \"repetition\": 1, \"slide_frames\": 10},\n",
"    \"testing\": {\"name\": \"test\", \"repetition\": 1, \"slide_frames\": 1},\n",
"}\n",
"\n",
"# ---- Generate features ----\n",
"for split, cfg in split_cfg.items():\n",
"    out_dir = out_root / split\n",
"    out_dir.mkdir(parents=True, exist_ok=True)\n",
"    print(f\"🧪 Processing {split} …\")\n",
"\n",
"    spectros = SpectrogramGeneration(\n",
"        clips=clips, # now backed by our WAV loader\n",
"        augmenter=augmenter, # your existing augmenter\n",
"        slide_frames=cfg[\"slide_frames\"],\n",
"        step_ms=10,\n",
"    )\n",
"\n",
"    RaggedMmap.from_generator(\n",
"        out_dir=str(out_dir / \"wakeword_mmap\"),\n",
"        sample_generator=spectros.spectrogram_generator(\n",
"            split=cfg[\"name\"], repeat=cfg[\"repetition\"]\n",
"        ),\n",
"        batch_size=100,\n",
"        verbose=True,\n",
"    )\n",
"\n",
"print(\"✅ Features ready (generated_augmented_features/*/wakeword_mmap)\")"
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1pGuJDPyp3ax" }, "outputs": [], "source": [
"# Downloads pre-generated spectrogram features (made for microWakeWord in\n",
"# particular) for various negative datasets. 
This can be slow!\n",
"\n",
"import os\n",
"import requests\n",
"import zipfile\n",
"from pathlib import Path\n",
"from tqdm import tqdm\n",
"\n",
"# Function to download a file with progress bar\n",
"def download_file(url, output_path, chunk_size=1024 * 1024, timeout=60):\n",
"    \"\"\"Stream `url` to `output_path` (a Path) with a progress bar.\n",
"\n",
"    HTTP errors raise instead of saving the error page as a .zip. Data goes to a `.part`\n",
"    file that is renamed only once the transfer completes, so a failed or interrupted\n",
"    download never leaves a file that the `exists()` check below would skip forever.\n",
"    \"\"\"\n",
"    tmp_path = output_path.with_name(output_path.name + \".part\")\n",
"    with requests.get(url, stream=True, timeout=timeout) as response:\n",
"        response.raise_for_status()\n",
"        total_size = int(response.headers.get('content-length', 0))\n",
"        try:\n",
"            with open(tmp_path, \"wb\") as f, tqdm(\n",
"                desc=f\"Downloading {output_path.name}\",\n",
"                total=total_size,\n",
"                unit=\"B\",\n",
"                unit_scale=True,\n",
"                unit_divisor=1024,\n",
"            ) as bar:\n",
"                for chunk in response.iter_content(chunk_size=chunk_size):\n",
"                    f.write(chunk)\n",
"                    bar.update(len(chunk))\n",
"            tmp_path.replace(output_path)\n",
"        finally:\n",
"            tmp_path.unlink(missing_ok=True) # no-op after a successful rename\n",
"    print(f\"Downloaded: {output_path}\")\n",
"\n",
"# Function to extract ZIP files\n",
"def extract_zip(zip_path, extract_to):\n",
"    \"\"\"Extract every member of `zip_path` into `extract_to`.\"\"\"\n",
"    with zipfile.ZipFile(zip_path, 'r') as zip_ref:\n",
"        zip_ref.extractall(extract_to)\n",
"    print(f\"Extracted: {zip_path} to {extract_to}\")\n",
"\n",
"# Directory for negative datasets\n",
"output_dir = Path('./negative_datasets')\n",
"output_dir.mkdir(exist_ok=True)\n",
"\n",
"# Negative dataset URLs\n",
"link_root = \"https://huggingface.co/datasets/kahrendt/microwakeword/resolve/main/\"\n",
"filenames = ['dinner_party.zip', 'dinner_party_eval.zip', 'no_speech.zip', 'speech.zip']\n",
"\n",
"# Download and extract files\n",
"for fname in filenames:\n",
"    link = link_root + fname\n",
"    zip_path = output_dir / fname\n",
"\n",
"    # Download only if the file doesn't already exist\n",
"    if not zip_path.exists():\n",
"        try:\n",
"            download_file(link, zip_path)\n",
"        except Exception as e:\n",
"            print(f\"Error downloading {fname}: {e}\")\n",
"            continue\n",
"\n",
"    # Extract the ZIP file\n",
"    try:\n",
"        extract_zip(zip_path, output_dir)\n",
"    except zipfile.BadZipFile as e:\n",
"        # A corrupt archive would otherwise be reused on every run; delete it so the next run re-downloads.\n",
"        print(f\"Error extracting {fname}: {e} — deleting it so the next run re-downloads.\")\n",
"        zip_path.unlink(missing_ok=True)\n",
"    except Exception as e:\n",
"        print(f\"Error extracting {fname}: {e}\")\n"
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Ii1A14GsGVQT" }, "outputs": [], "source": [
"# --- Save a yaml config that controls the training process ---\n",
"\n",
"import os, sys, yaml\n",
"\n",
"config = {}\n",
"\n",
"config[\"window_step_ms\"] = 10\n",
"config[\"train_dir\"] = \"trained_models/wakeword\"\n",
"\n",
"config[\"features\"] = [\n",
"    {\"features_dir\":\"generated_augmented_features\",\"sampling_weight\":2.0,\"penalty_weight\":1.0,\"truth\":True,\"truncation_strategy\":\"truncate_start\",\"type\":\"mmap\"},\n",
"    {\"features_dir\":\"negative_datasets/speech\",\"sampling_weight\":12.0,\"penalty_weight\":1.0,\"truth\":False,\"truncation_strategy\":\"random\",\"type\":\"mmap\"},\n",
"    {\"features_dir\":\"negative_datasets/dinner_party\",\"sampling_weight\":12.0,\"penalty_weight\":1.0,\"truth\":False,\"truncation_strategy\":\"random\",\"type\":\"mmap\"},\n",
"    {\"features_dir\":\"negative_datasets/no_speech\",\"sampling_weight\":5.0,\"penalty_weight\":1.0,\"truth\":False,\"truncation_strategy\":\"random\",\"type\":\"mmap\"},\n",
"    {\"features_dir\":\"negative_datasets/dinner_party_eval\",\"sampling_weight\":0.0,\"penalty_weight\":1.0,\"truth\":False,\"truncation_strategy\":\"split\",\"type\":\"mmap\"},\n",
"]\n",
"\n",
"config[\"training_steps\"] = [40000]\n",
"config[\"positive_class_weight\"] = [1]\n",
"config[\"negative_class_weight\"] = [20]\n",
"config[\"learning_rates\"] = [0.001]\n",
"\n",
"# Smaller batch to avoid GPU copy/alloc failures on 3070 laptop VRAM\n",
"config[\"batch_size\"] = 16\n",
"\n",
"# SpecAugment off (as before)\n",
"config[\"time_mask_max_size\"] = [0]\n",
"config[\"time_mask_count\"] = [0]\n",
"config[\"freq_mask_max_size\"] = [0]\n",
"config[\"freq_mask_count\"] = [0]\n",
"\n",
"config[\"eval_step_interval\"] = 500\n",
"config[\"clip_duration_ms\"] = 1500\n",
"config[\"target_minimization\"] = 0.9\n",
"config[\"minimization_metric\"] = None\n",
"config[\"maximization_metric\"] = \"average_viable_recall\"\n",
"\n",
"with open(\"training_parameters.yaml\", \"w\") as f:\n",
"    yaml.dump(config, f)\n",
"\n",
"print(f\"✅ Wrote training_parameters.yaml (batch_size={config['batch_size']})\")"
]
}, { "cell_type": "code", "execution_count": null, "metadata": { "id": "WoEXJBaiC9mf" }, "outputs": [], "source": [
"# Train + export (GPU-friendly env + stable flags)\n",
"\n",
"import os, sys, gc, runpy\n",
"\n",
"# The env vars below only take effect if set before TensorFlow is imported,\n",
"# so they are skipped when TF is already loaded in this kernel.\n",
"tf_env_applied = \"tensorflow\" not in sys.modules\n",
"if tf_env_applied:\n",
"    os.environ[\"TF_FORCE_GPU_ALLOW_GROWTH\"] = \"true\" # grow as needed\n",
"    os.environ[\"TF_GPU_ALLOCATOR\"] = \"cuda_malloc_async\" # modern CUDA allocator\n",
"    os.environ[\"TF_XLA_FLAGS\"] = \"--tf_xla_auto_jit=0\" # disable XLA JIT (more stable mem)\n",
"    os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"2\" # quieter logs\n",
"    os.environ[\"NVIDIA_TF32_OVERRIDE\"] = \"1\" # allow TF32 (perf/VRAM win on Ampere+)\n",
"# If you still hit GPU memory errors, uncomment to force a smaller workspace:\n",
"# os.environ[\"TF_CUDNN_WORKSPACE_LIMIT_IN_MB\"] = \"256\"\n",
"\n",
"import tensorflow as tf\n",
"\n",
"allow_growth = \"\"\n",
"# Per-device memory growth (belt + suspenders)\n",
"for g in tf.config.list_physical_devices(\"GPU\"):\n",
"    try:\n",
"        tf.config.experimental.set_memory_growth(g, True)\n",
"        allow_growth = \"gpu_allow_growth, \"\n",
"    except Exception:\n",
"        pass # e.g. when the GPU runtime was already initialized earlier in this kernel\n",
"print(\"GPUs:\", tf.config.list_physical_devices(\"GPU\"))\n",
"gc.collect()\n",
"\n",
"if tf_env_applied:\n",
"    print(f\"✅ Set environment with {allow_growth}cuda_malloc_async, xla_auto_jit=0, min_log_level=2, nvidia_tf32_override\")\n",
"else:\n",
"    print(\"⚠️ TensorFlow was already imported in this kernel, so the TF_* environment settings were NOT applied \"\n",
"          \"— restart the kernel and run this cell before anything else imports TensorFlow if you need them.\")\n",
"print(\" Starting training...\")\n",
"\n",
"original_argv = list(sys.argv)\n",
"try:\n",
"    sys.argv = [\n",
"        'model_train_eval.py',\n",
"        '--training_config', 'training_parameters.yaml',\n",
"        '--train', '1',\n",
"        '--restore_checkpoint', '1',\n",
"        '--test_tf_nonstreaming', '0',\n",
"        '--test_tflite_nonstreaming', '0',\n",
"        '--test_tflite_nonstreaming_quantized', '0',\n",
"        '--test_tflite_streaming', '0',\n",
"        '--test_tflite_streaming_quantized', '1',\n",
"        '--use_weights', 'best_weights',\n",
"        'mixednet',\n",
"        '--pointwise_filters', '64,64,64,64',\n",
"        '--repeat_in_block', '1,1,1,1',\n",
"        '--mixconv_kernel_sizes', '[5], [7,11], [9,15], [23]',\n",
"        '--residual_connection', '0,0,0,0',\n",
"        '--first_conv_filters', '32',\n",
"        '--first_conv_kernel_size', '5',\n",
"        '--stride', '2'\n",
"    ]\n",
"    runpy.run_module(\"microwakeword.model_train_eval\", run_name=\"__main__\", alter_sys=True)\n",
"finally:\n",
"    sys.argv = original_argv\n",
"print(\"✅ Training and testing complete.\")\n"
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ex_UIWvwtjAN" }, "outputs": [], "source": [
"import os\n",
"import shutil\n",
"import json\n",
"from IPython.display import display, HTML\n",
"\n",
"# Use the wake word set in the first cell\n",
"wake_word = TARGET_WORD\n",
"\n",
"# --- Copy TFLite file to working dir with wake word name ---\n",
"source_path = \"trained_models/wakeword/tflite_stream_state_internal_quant/stream_state_internal_quant.tflite\"\n",
"if not os.path.exists(source_path):\n",
"    raise FileNotFoundError(f\"Trained model not found at {source_path}; make sure the training cell finished successfully.\")\n",
"tflite_filename = f\"{wake_word}.tflite\"\n",
"tflite_path = f\"./{tflite_filename}\"\n",
"shutil.copy(source_path, tflite_path)\n",
"\n",
"# --- Write JSON metadata file with matching model name ---\n",
"json_data = {\n",
"    \"type\": \"micro\",\n",
"    \"wake_word\": wake_word,\n",
"    \"author\": \"Tater Totterson\",\n",
"    \"website\": \"https://github.com/TaterTotterson/microWakeWord-Trainer-Nvidia-Docker.git\",\n",
"    \"model\": tflite_filename,\n",
"    \"trained_languages\": [\"en\"],\n",
"    \"version\": 2,\n",
"    \"micro\": {\n",
"        \"probability_cutoff\": 0.97,\n",
"        \"sliding_window_size\": 5,\n",
"        \"feature_step_size\": 10,\n",
"        \"tensor_arena_size\": 30000,\n",
"        \"minimum_esphome_version\": \"2024.7.0\"\n",
"    }\n",
"}\n",
"json_filename = f\"{wake_word}.json\"\n",
"json_path = f\"./{json_filename}\"\n",
"with open(json_path, \"w\") as json_file:\n",
"    json.dump(json_data, json_file, indent=2)\n",
"\n",
"# --- Display nice download links ---\n",
"html = f\"\"\"\n", "

Download your files:

\n", "\n", "\"\"\"\n", "display(HTML(html))" ] } ], "metadata": { "accelerator": "GPU", "colab": { "gpuType": "T4", "provenance": [] }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 4 }