{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# 🥔 MicroWakeWord Trainer — Tater Totterson Edition\n", "# ==================================================\n", "# Welcome, friend! 👋 This notebook will help you train your very own wake word model.\n", "# Think of it like teaching Tater Totterson to recognize when you say a special word.\n", "#\n", "# By the end, you'll have:\n", "# ✅ A trained TensorFlow Lite model ready for on-device detection.\n", "# ✅ A matching JSON manifest you can drop straight into ESPHome.\n", "#\n", "# This flow is optimized for Python 3.10 and NVIDIA GPUs (but should work elsewhere too).\n", "# You can customize the wake word, play with training parameters, and experiment with\n", "# different datasets until you get something that feels just right. 💪\n", "#\n", "# ⚡ Quick Tips:\n", "# • Change TARGET_WORD below to whatever you want your wake word to be.\n", "# • Rerun the notebook from the top if you change it (to regenerate everything).\n", "# • Expect to experiment — tweaking hyperparameters is part of the fun!\n", "#\n", "# When you’re done, you’ll get two files:\n", "# 1️⃣ .tflite — your trained model.\n", "# 2️⃣ .json — a manifest for ESPHome integration.\n", "#\n", "# More info & examples:\n", "# 🔗 https://github.com/TaterTotterson/microWakeWord-Trainer-Nvidia-Docker\n", "\n", "# --- Set your wake word here ---\n", "TARGET_WORD = \"hey_tater\" # 🗣️ Change this to whatever phrase you want!\n", "print(f\"🥔 Tater Totterson is listening for: '{TARGET_WORD}'\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "BFf6511E65ff" }, "outputs": [], "source": [ "# Installs microWakeWord. 
Be sure to restart the session after this is finished.\n", "import platform\n", "import sys\n", "import os\n", "\n", "if platform.system() == \"Darwin\":\n", " # `pymicro-features` is installed from a fork to support building on macOS\n", " !\"{sys.executable}\" -m pip install 'git+https://github.com/puddly/pymicro-features@puddly/minimum-cpp-version' --root-user-action=ignore\n", "\n", "# `audio-metadata` is installed from a fork to unpin `attrs` from a version that breaks Jupyter\n", "!\"{sys.executable}\" -m pip install 'git+https://github.com/whatsnowplaying/audio-metadata@d4ebb238e6a401bb1a5aaaac60c9e2b3cb30929f' --root-user-action=ignore\n", "\n", "# Clone the microWakeWord repository\n", "repo_path = \"./microWakeWord\"\n", "if not os.path.exists(repo_path):\n", " print(\"Cloning microWakeWord repository...\")\n", " !git clone https://github.com/kahrendt/microWakeWord.git {repo_path}\n", "\n", "# Ensure the repository exists before attempting to install\n", "if os.path.exists(repo_path):\n", " print(\"Installing microWakeWord...\")\n", " !\"{sys.executable}\" -m pip install -e {repo_path} --root-user-action=ignore\n", "else:\n", " print(f\"Repository not found at {repo_path}. 
Cloning might have failed.\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "BFf6511E65ff" }, "outputs": [], "source": [ "# --- GPU Check (Torch + ONNX Runtime) ---\n", "\n", "import torch\n", "import onnxruntime as ort\n", "\n", "print(\"🔧 Torch CUDA Available:\", torch.cuda.is_available())\n", "if torch.cuda.is_available():\n", " print(\" • Device count:\", torch.cuda.device_count())\n", " print(\" • Current device:\", torch.cuda.current_device())\n", " print(\" • Device name:\", torch.cuda.get_device_name(torch.cuda.current_device()))\n", "else:\n", " print(\"⚠️ Torch cannot see a GPU — check Docker runtime (--gpus all) and nvidia-container-toolkit\")\n", "\n", "print(\"\\n🔧 ONNX Runtime Providers:\")\n", "try:\n", " providers = ort.get_available_providers()\n", " print(\" •\", providers)\n", " if \"CUDAExecutionProvider\" not in providers:\n", " print(\"⚠️ CUDAExecutionProvider not available — ONNX will fall back to CPU.\")\n", "except Exception as e:\n", " print(\"⚠️ Could not query ONNX Runtime providers:\", e)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "dEluu7nL7ywd" }, "outputs": [], "source": [ "# NVIDIA Linux Docker: generate 1 sample of the target word (robust + CUDA check)\n", "\n", "import os, sys, shutil, subprocess, time, platform\n", "from pathlib import Path\n", "from IPython.display import Audio, display\n", "\n", "REPO_URL = \"https://github.com/rhasspy/piper-sample-generator\"\n", "REPO_DIR = Path.cwd() / \"piper-sample-generator\"\n", "MODELS_DIR = REPO_DIR / \"models\"\n", "MODEL_NAME = \"en_US-libritts_r-medium.pt\"\n", "MODEL_URL = f\"https://github.com/rhasspy/piper-sample-generator/releases/download/v2.0.0/{MODEL_NAME}\"\n", "AUDIO_OUT_DIR = Path.cwd() / \"generated_samples\"\n", "AUDIO_PATH = AUDIO_OUT_DIR / \"0.wav\"\n", "\n", "def run(cmd, check=True):\n", " print(\"→\", \" \".join(cmd))\n", " proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, 
text=True)\n", " for line in proc.stdout:\n", " print(line, end=\"\")\n", " rc = proc.wait()\n", " if check and rc != 0:\n", " raise RuntimeError(f\"Command failed with exit code {rc}: {' '.join(cmd)}\")\n", " return rc\n", "\n", "def pip_install(*pkgs):\n", " run([sys.executable, \"-m\", \"pip\", \"install\", \"--upgrade\", \"pip\"], check=False)\n", " run([sys.executable, \"-m\", \"pip\", \"install\", *pkgs])\n", "\n", "def safe_clone(repo_url, branch=None, dest=REPO_DIR, retries=2):\n", " if dest.exists() and not (dest / \".git\").exists():\n", " print(\"⚠️ Found partial clone. Removing…\")\n", " shutil.rmtree(dest, ignore_errors=True)\n", " if not dest.exists():\n", " for i in range(retries + 1):\n", " try:\n", " cmd = [\"git\", \"clone\", \"--depth\", \"1\", repo_url, str(dest)]\n", " if branch:\n", " cmd = [\"git\", \"clone\", \"--depth\", \"1\", \"--branch\", branch, repo_url, str(dest)]\n", " run(cmd)\n", " break\n", " except Exception as e:\n", " if i == retries:\n", " raise\n", " print(f\"Clone failed ({i+1}/{retries+1}). 
Retrying in 2s… [{e}]\")\n",
    "                time.sleep(2)\n",
    "\n",
    "def ensure_model():\n",
    "    \"\"\"Download the Piper generator checkpoint once and sanity-check its size.\n",
    "\n",
    "    The file is downloaded to a temporary `.part` path and only renamed into\n",
    "    place after it passes the size check, so an interrupted or truncated\n",
    "    download is never mistaken for a finished model on a later run. An\n",
    "    existing file below the size threshold is treated as broken and re-downloaded.\n",
    "    \"\"\"\n",
    "    MODELS_DIR.mkdir(parents=True, exist_ok=True)\n",
    "    mp = MODELS_DIR / MODEL_NAME\n",
    "    min_size = 100 * 1024  # below this, treat the file as a failed/truncated download\n",
    "    if mp.exists() and mp.stat().st_size < min_size:\n",
    "        print(f\"⚠️ Existing model at {mp} looks truncated; re-downloading…\")\n",
    "        mp.unlink()\n",
    "    if not mp.exists():\n",
    "        import urllib.request\n",
    "        tmp = mp.with_name(mp.name + \".part\")\n",
    "        print(f\"Downloading model to {mp} …\")\n",
    "        with urllib.request.urlopen(MODEL_URL) as r, open(tmp, \"wb\") as f:\n",
    "            shutil.copyfileobj(r, f)\n",
    "        if tmp.stat().st_size < min_size:\n",
    "            tmp.unlink()\n",
    "            raise RuntimeError(\"Downloaded model looks too small; download may have failed.\")\n",
    "        tmp.replace(mp)  # atomic rename: the final path only ever holds a complete file\n",
    "    print(f\"✅ Model ready: {mp}\")\n",
    "\n",
    "# 1) Clone main repo (Linux/NVIDIA)\n",
    "print(\"Linux/NVIDIA detected — using main piper-sample-generator repo.\")\n",
    "safe_clone(REPO_URL)\n",
    "\n",
    "# 2) Install deps\n",
    "# - piper-tts provides the `piper` module (required by generate_samples.py)\n",
    "# - piper-phonemize-cross does the phonemization\n",
    "# - onnxruntime-gpu enables CUDA (container must have NVIDIA runtime)\n",
    "deps = [\n",
    "    \"piper-tts>=1.2.0\",\n",
    "    \"piper-phonemize-cross==1.2.1\",\n",
    "    \"soundfile\",\n",
    "    \"numpy\",\n",
    "    \"onnxruntime-gpu>=1.16.0\",\n",
    "]\n",
    "pip_install(*deps)\n",
    "\n",
    "# 3) Verify CUDA provider is available\n",
    "try:\n",
    "    import onnxruntime as ort\n",
    "    providers = ort.get_available_providers()\n",
    "    print(f\"ONNX Runtime providers: {providers}\")\n",
    "    if \"CUDAExecutionProvider\" not in providers:\n",
    "        print(\"⚠️ CUDAExecutionProvider not available. 
\"\n", " \"The sample will still run on CPU, but check your NVIDIA container setup \"\n", " \"(nvidia-container-toolkit, runtime, and driver).\")\n", "except Exception as e:\n", " print(\"⚠️ Could not import onnxruntime to verify providers:\", e)\n", "\n", "# 4) Ensure model present\n", "ensure_model()\n", "\n", "# 5) Generate one sample\n", "AUDIO_OUT_DIR.mkdir(parents=True, exist_ok=True)\n", "gen_script = REPO_DIR / \"generate_samples.py\"\n", "if not gen_script.exists():\n", " raise FileNotFoundError(f\"Missing generator: {gen_script}\")\n", "\n", "cmd = [\n", " sys.executable, str(gen_script),\n", " TARGET_WORD,\n", " \"--model\", str(MODELS_DIR / MODEL_NAME), # ← pass the generator .pt explicitly\n", " \"--max-samples\", \"1\",\n", " \"--batch-size\", \"1\",\n", " \"--output-dir\", str(AUDIO_OUT_DIR),\n", "]\n", "run(cmd)\n", "\n", "# 6) Play the audio (if the notebook frontend supports it)\n", "if AUDIO_PATH.exists():\n", " print(f\"🎧 Playing {AUDIO_PATH}\")\n", " display(Audio(str(AUDIO_PATH), autoplay=True))\n", "else:\n", " print(f\"Audio file not found at {AUDIO_PATH}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-SvGtCCM9akR" }, "outputs": [], "source": [ "# Generate a large number of wake word samples for training\n", "import sys, subprocess\n", "from pathlib import Path\n", "\n", "REPO_DIR = Path.cwd() / \"piper-sample-generator\"\n", "MODELS_DIR = REPO_DIR / \"models\"\n", "MODEL_NAME = \"en_US-libritts_r-medium.pt\"\n", "\n", "cmd = [\n", " sys.executable,\n", " str(REPO_DIR / \"generate_samples.py\"),\n", " TARGET_WORD,\n", " \"--model\", str(MODELS_DIR / MODEL_NAME),\n", " \"--max-samples\", \"50000\",\n", " \"--batch-size\", \"100\",\n", " \"--output-dir\", \"generated_samples\",\n", "]\n", "\n", "print(\"→\", \" \".join(cmd))\n", "subprocess.run(cmd, check=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "YJRG4Qvo9nXG" }, "outputs": [], "source": [ "# NVIDIA/Linux dataset prep to 
match the Apple behavior, but without datasets.Audio (no TorchCodec)\n",
    "# MIT RIR -> resample to 16 kHz\n",
    "# AudioSet -> resample to 16 kHz mono (only when the file is not already 16 kHz)\n",
    "# FMA -> resample to 16 kHz mono\n",
    "\n",
    "import os, sys, scipy.io.wavfile, numpy as np\n",
    "from pathlib import Path\n",
    "from tqdm import tqdm\n",
    "import soundfile as sf\n",
    "import librosa\n",
    "from datasets import load_dataset\n",
    "\n",
    "def write_wav(dst: Path, data: np.ndarray, sr: int):\n",
    "    \"\"\"Clip float audio to [-1, 1] and write it to `dst` as 16-bit PCM labelled with rate `sr`.\"\"\"\n",
    "    x = np.clip(data, -1.0, 1.0)\n",
    "    scipy.io.wavfile.write(dst, sr, (x * 32767).astype(np.int16))\n",
    "\n",
    "# -----------------------------\n",
    "# MIT RIR (resample to 16 kHz)\n",
    "# -----------------------------\n",
    "print(\"=== MIT RIR ===\")\n",
    "rir_out = Path(\"mit_rirs\")\n",
    "rir_out.mkdir(exist_ok=True)\n",
    "if not any(rir_out.rglob(\"*.wav\")):\n",
    "    ok = 0\n",
    "    last_err = None\n",
    "    try:\n",
    "        # Avoid datasets.Audio to keep TorchCodec out:\n",
    "        # Use streaming=True + Audio(decode=False)-equivalent: access raw file path and decode with librosa\n",
    "        print(\"⬇️ MIT RIR (streaming + manual decode)…\")\n",
    "        ds = load_dataset(\"davidscripka/MIT_environmental_impulse_responses\",\n",
    "                          split=\"train\", streaming=True)\n",
    "        for i, row in enumerate(tqdm(ds)):\n",
    "            try:\n",
    "                audio_path = row[\"audio\"][\"path\"]\n",
    "                y, sr = librosa.load(audio_path, sr=16000, mono=True)\n",
    "                write_wav(rir_out / f\"rir_{i:04d}.wav\", y, 16000)\n",
    "                ok += 1\n",
    "            except Exception as row_err:\n",
    "                last_err = row_err  # keep going; reported below if nothing decoded\n",
    "        if ok == 0:\n",
    "            # Every row failed (or the stream was empty). Without this the loop\n",
    "            # \"succeeds\" with 0 files and the ZIP fallback below never runs.\n",
    "            raise RuntimeError(f\"no RIR files could be decoded (last error: {last_err})\")\n",
    "        print(f\"✅ MIT RIR saved: {ok} files\")\n",
    "    except Exception as e:\n",
    "        print(f\"⚠️ MIT RIR download failed: {e}\")\n",
    "        # Fallback to the official ZIP\n",
    "        try:\n",
    "            print(\"⬇️ MIT RIR (fallback ZIP)…\")\n",
    "            zip_url = \"https://mcdermottlab.mit.edu/Reverb/IRMAudio/Audio.zip\"\n",
    "            zip_path = rir_out.parent / \"MIT_RIR_Audio.zip\"\n",
    "            # A failed wget leaves a 0-byte file behind; don't treat that as downloaded.\n",
    "            if not zip_path.exists() or zip_path.stat().st_size == 0:\n",
    "                os.system(f\"wget -q -O '{zip_path}' '{zip_url}'\")\n",
    "            os.system(f'unzip -q -o \"{zip_path}\" -d \"{rir_out}\"')\n",
    "            # Normalize to 16k mono\n",
    "            rir_files = list(rir_out.rglob(\"*.wav\"))\n",
    "            if not rir_files:\n",
    "                raise RuntimeError(f\"no .wav files found in {rir_out} after unzipping {zip_path}\")\n",
    "            for p in tqdm(rir_files, desc=\"Normalize MIT RIR\"):\n",
    "                a, sr = sf.read(p, always_2d=False)\n",
    "                if a.ndim > 1: a = a[:,0]\n",
    "                if sr != 16000:\n",
    "                    a, _ = librosa.load(p, sr=16000, mono=True)\n",
    "                write_wav(p, a, 16000)\n",
    "            print(\"✅ MIT RIR fallback complete\")\n",
    "        except Exception as e2:\n",
    "            print(f\"❌ MIT RIR fallback failed: {e2}\")\n",
    "else:\n",
    "    print(\"✅ mit_rirs exists; skipping.\")\n",
    "\n",
    "# -----------------------------\n",
    "# AudioSet (resample to 16 kHz mono)\n",
    "# -----------------------------\n",
    "print(\"\\n=== AudioSet subset ===\")\n",
    "audioset_dir = Path(\"audioset\"); audioset_dir.mkdir(exist_ok=True)\n",
    "audioset_out = Path(\"audioset_16k\"); audioset_out.mkdir(exist_ok=True)\n",
    "\n",
    "links = [f\"https://huggingface.co/datasets/agkphysics/AudioSet/resolve/main/data/bal_train0{i}.tar\"\n",
    "         for i in range(10)]\n",
    "for link in links:\n",
    "    fname = link.split(\"/\")[-1]\n",
    "    out_tar = audioset_dir / fname\n",
    "    if not out_tar.exists():\n",
    "        print(f\"⬇️ {fname}\")\n",
    "        os.system(f\"wget -q -O '{out_tar}' '{link}'\")\n",
    "    print(f\"📦 Extract {fname}\")\n",
    "    os.system(f\"tar -xf '{out_tar}' -C '{audioset_dir}'\")\n",
    "\n",
    "flacs = list(audioset_dir.rglob(\"*.flac\"))\n",
    "print(f\"🔎 FLAC files: {len(flacs)}\")\n",
    "corrupt = []\n",
    "for p in tqdm(flacs, desc=\"AudioSet→16k WAV\"):\n",
    "    try:\n",
    "        a, sr = sf.read(p, always_2d=False)\n",
    "        if a is None or len(a) == 0:\n",
    "            raise ValueError(\"empty audio\")\n",
    "        if a.ndim > 1:\n",
    "            a = a[:,0]\n",
    "        # The WAV is written with a 16 kHz header, so the samples must really be\n",
    "        # 16 kHz. Writing e.g. 48 kHz samples unchanged would make the background\n",
    "        # noise play 3x slower and much lower in pitch than the real recording.\n",
    "        if sr != 16000:\n",
    "            a = librosa.resample(a, orig_sr=sr, target_sr=16000)\n",
    "        write_wav(audioset_out / (p.stem + \".wav\"), a, 16000)\n",
    "    except Exception as e:\n",
    "        corrupt.append(f\"{p}:{e}\")\n",
    "if corrupt:\n",
    "    (audioset_out / \"audioset_corrupted_files.log\").write_text(\"\\n\".join(corrupt))\n",
    "print(\"✅ AudioSet processing complete!\")\n",
    "\n",
    "# -----------------------------\n",
    "# FMA xsmall (resample to 16 kHz mono)\n",
    "# 
-----------------------------\n",
    "print(\"\\n=== FMA xsmall ===\")\n",
    "fma_zip_dir = Path(\"fma\"); fma_zip_dir.mkdir(exist_ok=True)\n",
    "fma_out = Path(\"fma_16k\"); fma_out.mkdir(exist_ok=True)\n",
    "\n",
    "zipname = \"fma_xs.zip\"\n",
    "zipurl = f\"https://huggingface.co/datasets/mchl914/fma_xsmall/resolve/main/{zipname}\"\n",
    "zipout = fma_zip_dir / zipname\n",
    "fma_mp3_dir = fma_zip_dir / \"fma_small\"  # folder the archive unpacks into\n",
    "# A failed wget leaves a 0-byte file behind; don't treat that as downloaded.\n",
    "if not zipout.exists() or zipout.stat().st_size == 0:\n",
    "    os.system(f\"wget -q -O '{zipout}' '{zipurl}'\")\n",
    "# Extract whenever the mp3 folder is missing, not only right after a fresh\n",
    "# download, so a failed unzip (or a deleted fma/fma_small) is retried next run.\n",
    "if not fma_mp3_dir.exists():\n",
    "    os.system(f\"unzip -q -o '{zipout}' -d '{fma_zip_dir}'\")\n",
    "\n",
    "mp3s = list(fma_mp3_dir.rglob(\"*.mp3\"))\n",
    "print(f\"🎵 FMA mp3 count: {len(mp3s)}\")\n",
    "if not mp3s:\n",
    "    print(f\"⚠️ No mp3 files under {fma_mp3_dir}; check the FMA download/unzip.\")\n",
    "corrupt = []\n",
    "for p in tqdm(mp3s, desc=\"FMA→16k WAV\"):\n",
    "    try:\n",
    "        y, sr = librosa.load(p, sr=16000, mono=True)  # proper decode+resample\n",
    "        if y.size == 0:\n",
    "            raise ValueError(\"empty audio\")\n",
    "        write_wav(fma_out / (p.stem + \".wav\"), y, 16000)\n",
    "    except Exception as e:\n",
    "        corrupt.append(f\"{p}:{e}\")\n",
    "if corrupt:\n",
    "    Path(\"fma_corrupted_files.log\").write_text(\"\\n\".join(corrupt))\n",
    "print(\"\\n✅ Dataset prep complete!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "XW3bmbI5-JAz"
   },
   "outputs": [],
   "source": [
    "# Sets up the augmentations.\n",
    "# To improve your model, experiment with these settings and use more sources of\n",
    "# background clips.\n",
    "\n",
    "import os\n",
    "from pathlib import Path\n",
    "from microwakeword.audio.augmentation import Augmentation\n",
    "from microwakeword.audio.clips import Clips\n",
    "from microwakeword.audio.spectrograms import SpectrogramGeneration\n",
    "\n",
    "def validate_directories(paths):\n",
    "    \"\"\"Return True if every path exists and contains at least one .wav file.\n",
    "\n",
    "    Prints an error and returns False for the first path that is missing or\n",
    "    empty. Checking for files (not just the folder) matters because the dataset\n",
    "    prep cell creates these folders up front, so a download that produced\n",
    "    nothing still leaves an existing-but-empty directory behind.\n",
    "    \"\"\"\n",
    "    for path in paths:\n",
    "        if not os.path.exists(path):\n",
    "            print(f\"Error: Directory {path} does not exist. Please ensure preprocessing is complete.\")\n",
    "            return False\n",
    "        if not any(Path(path).rglob(\"*.wav\")):\n",
    "            print(f\"Error: Directory {path} contains no .wav files. Please ensure preprocessing is complete.\")\n",
    "            return False\n",
    "    return True\n",
    "\n",
    "# Paths to augmented data\n",
    "impulse_paths = ['mit_rirs']\n",
    "background_paths = ['fma_16k', 'audioset_16k']\n",
    "\n",
    "if not validate_directories(impulse_paths + background_paths):\n",
    "    raise ValueError(\"One or more required directories are missing or empty.\")\n",
    "\n",
    "clips = Clips(\n",
    "    input_directory='./generated_samples',\n",
    "    file_pattern='*.wav',\n",
    "    max_clip_duration_s=5,\n",
    "    remove_silence=True,\n",
    "    random_split_seed=10,\n",
    "    split_count=0.1,\n",
    ")\n",
    "\n",
    "augmenter = Augmentation(\n",
    "    augmentation_duration_s=3.2,\n",
    "    augmentation_probabilities={\n",
    "        \"SevenBandParametricEQ\": 0.1,\n",
    "        \"TanhDistortion\": 0.05,\n",
    "        \"PitchShift\": 0.15,\n",
    "        \"BandStopFilter\": 0.1,\n",
    "        \"AddColorNoise\": 0.1,\n",
    "        \"AddBackgroundNoise\": 0.7,\n",
    "        \"Gain\": 0.8,\n",
    "        \"RIR\": 0.7,\n",
    "    },\n",
    "    impulse_paths=impulse_paths,\n",
    "    background_paths=background_paths,\n",
    "    background_min_snr_db=5,\n",
    "    background_max_snr_db=10,\n",
    "    min_jitter_s=0.2,\n",
    "    max_jitter_s=0.3,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "V5UsJfKKD1k9"
   },
   "outputs": [],
   "source": [
    "# Augment a random generated-sample WAV and play it back (pass ndarray to augmenter)\n",
    "from pathlib import Path\n",
    "from IPython.display import Audio, display\n",
    "import numpy as np\n",
    "import soundfile as sf\n",
    "import librosa, random, glob\n",
    "\n",
    "output_dir = Path(\"./augmented_clips\")\n",
    "output_dir.mkdir(exist_ok=True)\n",
    "\n",
    "# 1) Pick a random WAV from the Piper outputs\n",
    "candidates = glob.glob(\"generated_samples/*.wav\")\n",
    "if not candidates:\n",
    "    raise SystemExit(\"No files in generated_samples/. 
Run the TTS sample cell first.\")\n", "src_path = random.choice(candidates)\n", "\n", "# 2) Load as 16 kHz mono float32\n", "y, sr = librosa.load(src_path, sr=16000, mono=True)\n", "y = y.astype(np.float32, copy=False)\n", "\n", "# 3) Augment — microwakeword Augmentation expects a 1-D numpy array\n", "try:\n", " y_aug = augmenter.augment_clip(y)\n", "except Exception as e:\n", " # some versions accept (samples, sr) — try that as a fallback\n", " try:\n", " y_aug = augmenter.augment_clip((y, sr))\n", " except Exception:\n", " raise\n", "\n", "# 4) Save and play\n", "out_path = output_dir / \"augmented_clip.wav\"\n", "sf.write(str(out_path), y_aug.astype(np.float32, copy=False), sr, subtype=\"PCM_16\")\n", "print(f\"Augmented clip saved to {out_path}\")\n", "display(Audio(str(out_path), autoplay=True))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "D7BHcY1mEGbK" }, "outputs": [], "source": [ "# Augment samples and save the training, validation, and testing sets.\n", "# This version avoids datasets.Audio entirely by driving Clips from local WAVs.\n", "\n", "import os, glob, random\n", "from pathlib import Path\n", "import types\n", "import numpy as np\n", "import librosa\n", "from mmap_ninja.ragged import RaggedMmap\n", "from microwakeword.audio.spectrograms import SpectrogramGeneration\n", "\n", "# ---- Patch: drive clips from generated_samples/*.wav (no datasets.Audio, no torchcodec) ----\n", "def audio_generator_from_wavs(self, split=\"train\", repeat=1):\n", " \"\"\"\n", " Yield 1-D float32 arrays loaded via librosa from generated_samples/*.wav.\n", " Deterministic 80/10/10 split with seed 10 to mirror original Clips behavior.\n", " \"\"\"\n", " files = sorted(glob.glob(\"generated_samples/*.wav\"))\n", " if not files:\n", " raise SystemExit(\"❌ No WAVs in generated_samples/. 
Generate TTS samples first.\")\n", "\n", " rng = random.Random(10) # deterministic shuffling like Clips(random_split_seed=10)\n", " files_shuf = files[:]\n", " rng.shuffle(files_shuf)\n", "\n", " n = len(files_shuf)\n", " n_val = max(1, int(0.10 * n))\n", " n_test = max(1, int(0.10 * n))\n", " n_train = max(0, n - n_val - n_test)\n", " splits = {\n", " \"train\": files_shuf[:n_train],\n", " \"validation\": files_shuf[n_train:n_train + n_val],\n", " \"test\": files_shuf[n_train + n_val:],\n", " }\n", " file_list = splits.get(split, [])\n", " if not file_list:\n", " return # nothing to yield\n", "\n", " for _ in range(max(1, int(repeat))):\n", " for p in file_list:\n", " y, sr = librosa.load(p, sr=16000, mono=True)\n", " yield y.astype(np.float32, copy=False)\n", "\n", "# Bind the patched generator to your existing `clips` instance\n", "clips.audio_generator = types.MethodType(audio_generator_from_wavs, clips)\n", "print(\"✅ Patched clips.audio_generator to stream from generated_samples/*.wav (no torchcodec).\")\n", "\n", "# ---- Validate augmentation asset folders exist ----\n", "def validate(paths):\n", " for p in paths:\n", " if not Path(p).exists():\n", " raise SystemExit(f\"❌ Missing directory: {p}. 
Run dataset prep first.\")\n", "\n", "impulse_paths = [\"mit_rirs\"]\n", "background_paths = [\"fma_16k\", \"audioset_16k\"]\n", "validate(impulse_paths + background_paths)\n", "\n", "# ---- Output root ----\n", "out_root = Path(\"generated_augmented_features\")\n", "out_root.mkdir(exist_ok=True)\n", "\n", "# ---- Split config (same as before) ----\n", "split_cfg = {\n", " \"training\": {\"name\": \"train\", \"repetition\": 2, \"slide_frames\": 10},\n", " \"validation\": {\"name\": \"validation\", \"repetition\": 1, \"slide_frames\": 10},\n", " \"testing\": {\"name\": \"test\", \"repetition\": 1, \"slide_frames\": 1},\n", "}\n", "\n", "# ---- Generate features ----\n", "for split, cfg in split_cfg.items():\n", " out_dir = out_root / split\n", " out_dir.mkdir(parents=True, exist_ok=True)\n", " print(f\"🧪 Processing {split} …\")\n", "\n", " spectros = SpectrogramGeneration(\n", " clips=clips, # now backed by our WAV loader\n", " augmenter=augmenter, # your existing augmenter\n", " slide_frames=cfg[\"slide_frames\"],\n", " step_ms=10,\n", " )\n", "\n", " RaggedMmap.from_generator(\n", " out_dir=str(out_dir / \"wakeword_mmap\"),\n", " sample_generator=spectros.spectrogram_generator(\n", " split=cfg[\"name\"], repeat=cfg[\"repetition\"]\n", " ),\n", " batch_size=100,\n", " verbose=True,\n", " )\n", "\n", "print(\"✅ Features ready (generated_augmented_features/*/wakeword_mmap)\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1pGuJDPyp3ax" }, "outputs": [], "source": [ "# Downloads pre-generated spectrogram features (made for microWakeWord in\n", "# particular) for various negative datasets. 
This can be slow!\n",
    "\n",
    "import os\n",
    "import requests\n",
    "import zipfile\n",
    "from pathlib import Path\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Function to download a file with progress bar\n",
    "def download_file(url, output_path):\n",
    "    \"\"\"Stream `url` to `output_path` (a Path) with a tqdm progress bar.\n",
    "\n",
    "    Raises requests.HTTPError on a non-2xx response instead of saving the\n",
    "    error page as if it were the archive. Bytes go to `<name>.part` and are\n",
    "    renamed into place only after the transfer completes, so an interrupted\n",
    "    download never leaves a file that the \"already downloaded\" check accepts.\n",
    "    \"\"\"\n",
    "    tmp_path = output_path.with_name(output_path.name + \".part\")\n",
    "    with requests.get(url, stream=True, timeout=60) as response:\n",
    "        response.raise_for_status()\n",
    "        total_size = int(response.headers.get('content-length', 0))\n",
    "        with open(tmp_path, \"wb\") as f, tqdm(\n",
    "            desc=f\"Downloading {output_path.name}\",\n",
    "            total=total_size,\n",
    "            unit=\"B\",\n",
    "            unit_scale=True,\n",
    "            unit_divisor=1024,\n",
    "        ) as bar:\n",
    "            for chunk in response.iter_content(chunk_size=1024 * 1024):\n",
    "                f.write(chunk)\n",
    "                bar.update(len(chunk))\n",
    "    tmp_path.replace(output_path)\n",
    "    print(f\"Downloaded: {output_path}\")\n",
    "\n",
    "# Function to extract ZIP files\n",
    "def extract_zip(zip_path, extract_to):\n",
    "    \"\"\"Extract every member of `zip_path` into the directory `extract_to`.\"\"\"\n",
    "    with zipfile.ZipFile(zip_path, 'r') as zip_ref:\n",
    "        zip_ref.extractall(extract_to)\n",
    "    print(f\"Extracted: {zip_path} to {extract_to}\")\n",
    "\n",
    "# Directory for negative datasets\n",
    "output_dir = Path('./negative_datasets')\n",
    "output_dir.mkdir(exist_ok=True)\n",
    "\n",
    "# Negative dataset URLs\n",
    "link_root = \"https://huggingface.co/datasets/kahrendt/microwakeword/resolve/main/\"\n",
    "filenames = ['dinner_party.zip', 'dinner_party_eval.zip', 'no_speech.zip', 'speech.zip']\n",
    "\n",
    "# Download and extract files\n",
    "for fname in filenames:\n",
    "    link = link_root + fname\n",
    "    zip_path = output_dir / fname\n",
    "\n",
    "    # Download only if the file doesn't already exist\n",
    "    if not zip_path.exists():\n",
    "        try:\n",
    "            download_file(link, zip_path)\n",
    "        except Exception as e:\n",
    "            print(f\"Error downloading {fname}: {e}\")\n",
    "            continue\n",
    "\n",
    "    # Extract the ZIP file\n",
    "    try:\n",
    "        extract_zip(zip_path, output_dir)\n",
    "    except zipfile.BadZipFile as e:\n",
    "        # A corrupt archive would otherwise be reused forever by the exists() check above.\n",
    "        print(f\"Error extracting {fname}: {e} — deleting it so the next run re-downloads.\")\n",
    "        zip_path.unlink(missing_ok=True)\n",
    "    except Exception as e:\n",
    "        print(f\"Error extracting {fname}: {e}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Ii1A14GsGVQT"
   },
   "outputs": [],
   "source": [
    "# GPU memory config (set env BEFORE importing TF)\n",
"import os, sys, gc\n", "\n", "if \"tensorflow\" not in sys.modules:\n", " os.environ[\"TF_FORCE_GPU_ALLOW_GROWTH\"] = \"true\" # grow as needed\n", " os.environ[\"TF_GPU_ALLOCATOR\"] = \"cuda_malloc_async\" # modern CUDA allocator\n", " os.environ[\"XLA_FLAGS\"] = \"--xla_gpu_cuda_data_dir=/usr/local/cuda\"\n", " os.environ[\"TF_XLA_FLAGS\"] = \"--tf_xla_auto_jit=0\" # disable XLA JIT (more stable mem)\n", "import tensorflow as tf\n", "\n", "# Per-device memory growth (belt + suspenders)\n", "for g in tf.config.list_physical_devices(\"GPU\"):\n", " try:\n", " tf.config.experimental.set_memory_growth(g, True)\n", " except Exception:\n", " pass\n", "print(\"GPUs:\", tf.config.list_physical_devices(\"GPU\"))\n", "gc.collect()\n", "\n", "# Optional but recommended: mixed precision halves activation memory\n", "try:\n", " from tensorflow.keras import mixed_precision\n", " mixed_precision.set_global_policy(\"mixed_float16\")\n", " print(\"Mixed precision policy:\", mixed_precision.global_policy())\n", "except Exception as e:\n", " print(\"Mixed precision not enabled:\", e)\n", "\n", "# --- Save a yaml config that controls the training process ---\n", "\n", "import yaml\n", "\n", "config = {}\n", "\n", "config[\"window_step_ms\"] = 10\n", "config[\"train_dir\"] = \"trained_models/wakeword\"\n", "\n", "config[\"features\"] = [\n", " {\"features_dir\":\"generated_augmented_features\",\"sampling_weight\":2.0,\"penalty_weight\":1.0,\"truth\":True,\"truncation_strategy\":\"truncate_start\",\"type\":\"mmap\"},\n", " {\"features_dir\":\"negative_datasets/speech\",\"sampling_weight\":12.0,\"penalty_weight\":1.0,\"truth\":False,\"truncation_strategy\":\"random\",\"type\":\"mmap\"},\n", " {\"features_dir\":\"negative_datasets/dinner_party\",\"sampling_weight\":12.0,\"penalty_weight\":1.0,\"truth\":False,\"truncation_strategy\":\"random\",\"type\":\"mmap\"},\n", " 
{\"features_dir\":\"negative_datasets/no_speech\",\"sampling_weight\":5.0,\"penalty_weight\":1.0,\"truth\":False,\"truncation_strategy\":\"random\",\"type\":\"mmap\"},\n", " {\"features_dir\":\"negative_datasets/dinner_party_eval\",\"sampling_weight\":0.0,\"penalty_weight\":1.0,\"truth\":False,\"truncation_strategy\":\"split\",\"type\":\"mmap\"},\n", "]\n", "\n", "config[\"training_steps\"] = [40000]\n", "config[\"positive_class_weight\"] = [1]\n", "config[\"negative_class_weight\"] = [20]\n", "config[\"learning_rates\"] = [0.001]\n", "\n", "# Smaller batch to avoid GPU copy/alloc failures on 3070 laptop VRAM\n", "config[\"batch_size\"] = 16\n", "\n", "# SpecAugment off (as before)\n", "config[\"time_mask_max_size\"] = [0]\n", "config[\"time_mask_count\"] = [0]\n", "config[\"freq_mask_max_size\"] = [0]\n", "config[\"freq_mask_count\"] = [0]\n", "\n", "config[\"eval_step_interval\"] = 500\n", "config[\"clip_duration_ms\"] = 1500\n", "config[\"target_minimization\"] = 0.9\n", "config[\"minimization_metric\"] = None\n", "config[\"maximization_metric\"] = \"average_viable_recall\"\n", "\n", "with open(\"training_parameters.yaml\", \"w\") as f:\n", " yaml.dump(config, f)\n", "\n", "print(\"✅ Wrote training_parameters.yaml (batch_size=16) with allow_growth, cuda_malloc_async, XLA JIT OFF, mixed precision ON.\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "WoEXJBaiC9mf" }, "outputs": [], "source": [ "# Train + export (GPU-friendly env + stable flags)\n", "\n", "import os, sys\n", "\n", "# --- Runtime env (inherited by the subprocess we're about to launch) ---\n", "os.environ.setdefault(\"LD_LIBRARY_PATH\",\n", " \"/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/lib/x86_64-linux-gnu:\" +\n", " os.environ.get(\"LD_LIBRARY_PATH\",\"\")\n", ")\n", "os.environ.setdefault(\"TF_CPP_MIN_LOG_LEVEL\", \"2\") # quieter logs\n", "os.environ.setdefault(\"TF_FORCE_GPU_ALLOW_GROWTH\", \"true\") # grow VRAM as needed\n", 
"os.environ.setdefault(\"TF_GPU_ALLOCATOR\", \"cuda_malloc_async\")# modern allocator\n", "os.environ.setdefault(\"XLA_FLAGS\", \"--xla_gpu_cuda_data_dir=/usr/local/cuda\")\n", "os.environ.setdefault(\"TF_XLA_FLAGS\", \"--tf_xla_auto_jit=0\") # disable XLA JIT (more stable)\n", "os.environ.setdefault(\"NVIDIA_TF32_OVERRIDE\", \"1\") # allow TF32 (perf/VRAM win on Ampere+)\n", "\n", "# If you still hit GPU memory errors, uncomment to force a smaller workspace:\n", "# os.environ[\"TF_CUDNN_WORKSPACE_LIMIT_IN_MB\"] = \"256\"\n", "\n", "# --- Kick off training ---\n", "cmd = f'''\"{sys.executable}\" -m microwakeword.model_train_eval \\\n", " --training_config=\"training_parameters.yaml\" \\\n", " --train 1 \\\n", " --restore_checkpoint 1 \\\n", " --test_tf_nonstreaming 0 \\\n", " --test_tflite_nonstreaming 0 \\\n", " --test_tflite_nonstreaming_quantized 0 \\\n", " --test_tflite_streaming 0 \\\n", " --test_tflite_streaming_quantized 1 \\\n", " --use_weights \"best_weights\" \\\n", " mixednet \\\n", " --pointwise_filters \"64,64,64,64\" \\\n", " --repeat_in_block \"1,1,1,1\" \\\n", " --mixconv_kernel_sizes \"[5], [7,11], [9,15], [23]\" \\\n", " --residual_connection \"0,0,0,0\" \\\n", " --first_conv_filters 32 \\\n", " --first_conv_kernel_size 5 \\\n", " --stride 2'''\n", "print(\"Running:\\n\", cmd)\n", "!$cmd" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ex_UIWvwtjAN" }, "outputs": [], "source": [ "import shutil\n", "import json\n", "from IPython.display import display, HTML\n", "\n", "# Use the wake word from Cell 3\n", "wake_word = TARGET_WORD\n", "\n", "# --- Copy TFLite file to working dir with wake word name ---\n", "source_path = \"trained_models/wakeword/tflite_stream_state_internal_quant/stream_state_internal_quant.tflite\"\n", "tflite_filename = f\"{wake_word}.tflite\"\n", "tflite_path = f\"./{tflite_filename}\"\n", "shutil.copy(source_path, tflite_path)\n", "\n", "# --- Write JSON metadata file with matching model name ---\n", 
"json_data = {\n", " \"type\": \"micro\",\n", " \"wake_word\": wake_word,\n", " \"author\": \"Tater Totterson\",\n", " \"website\": \"https://github.com/TaterTotterson/microWakeWord-Trainer-Nvidia-Docker.git\",\n", " \"model\": tflite_filename,\n", " \"trained_languages\": [\"en\"],\n", " \"version\": 2,\n", " \"micro\": {\n", " \"probability_cutoff\": 0.97,\n", " \"sliding_window_size\": 5,\n", " \"feature_step_size\": 10,\n", " \"tensor_arena_size\": 30000,\n", " \"minimum_esphome_version\": \"2024.7.0\"\n", " }\n", "}\n", "json_filename = f\"{wake_word}.json\"\n", "json_path = f\"./{json_filename}\"\n", "with open(json_path, \"w\") as json_file:\n", " json.dump(json_data, json_file, indent=2)\n", "\n", "# --- Display nice download links ---\n", "html = f\"\"\"\n", "

Download your files:

\n", "\n", "\"\"\"\n", "display(HTML(html))" ] } ], "metadata": { "accelerator": "GPU", "colab": { "gpuType": "T4", "provenance": [] }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 4 }