mirror of
https://github.com/TaterTotterson/microWakeWord-Trainer-Nvidia-Docker.git
synced 2026-06-12 20:10:19 -06:00
933 lines
38 KiB
Plaintext
933 lines
38 KiB
Plaintext
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# 🥔 MicroWakeWord Trainer — Tater Totterson Edition\n",
|
||
"# ==================================================\n",
|
||
"# Welcome, friend! 👋 This notebook will help you train your very own wake word model.\n",
|
||
"# Think of it like teaching Tater Totterson to recognize when you say a special word.\n",
|
||
"#\n",
|
||
"# By the end, you'll have:\n",
|
||
"# ✅ A trained TensorFlow Lite model ready for on-device detection.\n",
|
||
"# ✅ A matching JSON manifest you can drop straight into ESPHome.\n",
|
||
"#\n",
|
||
"# This flow is optimized for Python 3.10 and NVIDIA GPUs (but should work elsewhere too).\n",
|
||
"# You can customize the wake word, play with training parameters, and experiment with\n",
|
||
"# different datasets until you get something that feels just right. 💪\n",
|
||
"#\n",
|
||
"# ⚡ Quick Tips:\n",
|
||
"# • Change TARGET_WORD below to whatever you want your wake word to be.\n",
|
||
"# • Rerun the notebook from the top if you change it (to regenerate everything).\n",
|
||
"# • Expect to experiment — tweaking hyperparameters is part of the fun!\n",
|
||
"#\n",
|
||
"# When you’re done, you’ll get two files:\n",
|
||
"# 1️⃣ <wakeword>.tflite — your trained model.\n",
|
||
"# 2️⃣ <wakeword>.json — a manifest for ESPHome integration.\n",
|
||
"#\n",
|
||
"# More info & examples:\n",
|
||
"# 🔗 https://github.com/TaterTotterson/microWakeWord-Trainer-Nvidia-Docker\n",
|
||
"\n",
|
||
"# --- Set your wake word here ---\n",
|
||
"TARGET_WORD = \"tater\" # 🗣️ Change this to whatever phrase you want!\n",
|
||
"print(f\"🥔 Tater Totterson is listening for: '{TARGET_WORD}'\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"id": "BFf6511E65ff"
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"# Installs microWakeWord. Be sure to restart the session after this is finished.\n",
|
||
"import platform\n",
|
||
"import sys\n",
|
||
"import os\n",
|
||
"\n",
|
||
"if platform.system() == \"Darwin\":\n",
|
||
" # `pymicro-features` is installed from a fork to support building on macOS\n",
|
||
" !\"{sys.executable}\" -m pip install 'git+https://github.com/puddly/pymicro-features@puddly/minimum-cpp-version' --root-user-action=ignore\n",
|
||
"\n",
|
||
"# `audio-metadata` is installed from a fork to unpin `attrs` from a version that breaks Jupyter\n",
|
||
"!\"{sys.executable}\" -m pip install 'git+https://github.com/whatsnowplaying/audio-metadata@d4ebb238e6a401bb1a5aaaac60c9e2b3cb30929f' --root-user-action=ignore\n",
|
||
"\n",
|
||
"# Clone the microWakeWord repository (TaterTotterson fork)\n",
|
||
"repo_path = \"./microWakeWord\"\n",
|
||
"if not os.path.exists(repo_path):\n",
|
||
" print(\"⬇️ Cloning microWakeWord repository...\")\n",
|
||
" !git clone https://github.com/TaterTotterson/micro-wake-word.git {repo_path}\n",
|
||
"\n",
|
||
"# Pin to a specific commit for reproducibility (the checkout below always runs)\n",
|
||
"os.system(f\"cd {repo_path} && git checkout ac6502bf48b5e372c47ed509f5f5ca181e6d50bb\")\n",
|
||
"\n",
|
||
"# Install editable\n",
|
||
"if os.path.exists(repo_path):\n",
|
||
" print(\"📦 Installing microWakeWord...\")\n",
|
||
" !\"{sys.executable}\" -m pip install -e {repo_path} --root-user-action=ignore\n",
|
||
"else:\n",
|
||
" print(f\"❌ Repository not found at {repo_path}. Clone might have failed.\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"id": "BFf6511E65ff"
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"# --- GPU Check (Torch + ONNX Runtime) ---\n",
|
||
"\n",
|
||
"import torch\n",
|
||
"import onnxruntime as ort\n",
|
||
"\n",
|
||
"print(\"🔧 Torch CUDA Available:\", torch.cuda.is_available())\n",
|
||
"if torch.cuda.is_available():\n",
|
||
" print(\" • Device count:\", torch.cuda.device_count())\n",
|
||
" print(\" • Current device:\", torch.cuda.current_device())\n",
|
||
" print(\" • Device name:\", torch.cuda.get_device_name(torch.cuda.current_device()))\n",
|
||
"else:\n",
|
||
" print(\"⚠️ Torch cannot see a GPU — check Docker runtime (--gpus all) and nvidia-container-toolkit\")\n",
|
||
"\n",
|
||
"print(\"\\n🔧 ONNX Runtime Providers:\")\n",
|
||
"try:\n",
|
||
" providers = ort.get_available_providers()\n",
|
||
" print(\" •\", providers)\n",
|
||
" if \"CUDAExecutionProvider\" not in providers:\n",
|
||
" print(\"⚠️ CUDAExecutionProvider not available — ONNX will fall back to CPU.\")\n",
|
||
"except Exception as e:\n",
|
||
" print(\"⚠️ Could not query ONNX Runtime providers:\", e)\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"id": "dEluu7nL7ywd"
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"# NVIDIA Linux Docker: generate 1 sample of the target word (robust + CUDA check)\n",
|
||
"\n",
|
||
"import os, sys, shutil, subprocess, time, platform\n",
|
||
"from pathlib import Path\n",
|
||
"from IPython.display import Audio, display\n",
|
||
"\n",
|
||
"REPO_URL = \"https://github.com/rhasspy/piper-sample-generator\"\n",
|
||
"REPO_DIR = Path.cwd() / \"piper-sample-generator\"\n",
|
||
"MODELS_DIR = REPO_DIR / \"models\"\n",
|
||
"MODEL_NAME = \"en_US-libritts_r-medium.pt\"\n",
|
||
"MODEL_URL = f\"https://github.com/rhasspy/piper-sample-generator/releases/download/v2.0.0/{MODEL_NAME}\"\n",
|
||
"AUDIO_OUT_DIR = Path.cwd() / \"generated_samples\"\n",
|
||
"AUDIO_PATH = AUDIO_OUT_DIR / \"0.wav\"\n",
|
||
"\n",
|
||
"def run(cmd, check=True):\n",
|
||
" print(\"→\", \" \".join(cmd))\n",
|
||
" proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)\n",
|
||
" for line in proc.stdout:\n",
|
||
" print(line, end=\"\")\n",
|
||
" rc = proc.wait()\n",
|
||
" if check and rc != 0:\n",
|
||
" raise RuntimeError(f\"Command failed with exit code {rc}: {' '.join(cmd)}\")\n",
|
||
" return rc\n",
|
||
"\n",
|
||
"def pip_install(*pkgs):\n",
|
||
" run([sys.executable, \"-m\", \"pip\", \"install\", \"--upgrade\", \"pip\"], check=False)\n",
|
||
" run([sys.executable, \"-m\", \"pip\", \"install\", *pkgs])\n",
|
||
"\n",
|
||
"def safe_clone(repo_url, branch=None, dest=REPO_DIR, retries=2):\n",
|
||
" if dest.exists() and not (dest / \".git\").exists():\n",
|
||
" print(\"⚠️ Found partial clone. Removing…\")\n",
|
||
" shutil.rmtree(dest, ignore_errors=True)\n",
|
||
" if not dest.exists():\n",
|
||
" for i in range(retries + 1):\n",
|
||
" try:\n",
|
||
" cmd = [\"git\", \"clone\", \"--depth\", \"1\", repo_url, str(dest)]\n",
|
||
" if branch:\n",
|
||
" cmd = [\"git\", \"clone\", \"--depth\", \"1\", \"--branch\", branch, repo_url, str(dest)]\n",
|
||
" run(cmd)\n",
|
||
" break\n",
|
||
" except Exception as e:\n",
|
||
" if i == retries:\n",
|
||
" raise\n",
|
||
" print(f\"Clone failed ({i+1}/{retries+1}). Retrying in 2s… [{e}]\")\n",
|
||
" time.sleep(2)\n",
|
||
"\n",
|
||
"def ensure_model():\n",
|
||
" MODELS_DIR.mkdir(parents=True, exist_ok=True)\n",
|
||
" mp = MODELS_DIR / MODEL_NAME\n",
|
||
" if not mp.exists() or mp.stat().st_size == 0:\n",
|
||
" import urllib.request\n",
|
||
" print(f\"Downloading model to {mp} …\")\n",
|
||
" with urllib.request.urlopen(MODEL_URL) as r, open(mp, \"wb\") as f:\n",
|
||
" shutil.copyfileobj(r, f)\n",
|
||
" if mp.stat().st_size < 100 * 1024:\n",
|
||
" raise RuntimeError(\"Downloaded model looks too small; download may have failed.\")\n",
|
||
" print(f\"✅ Model ready: {mp}\")\n",
|
||
"\n",
|
||
"# 1) Clone main repo (Linux/NVIDIA)\n",
|
||
"print(\"Linux/NVIDIA detected — using main piper-sample-generator repo.\")\n",
|
||
"safe_clone(REPO_URL)\n",
|
||
"\n",
|
||
"# 2) Install deps\n",
|
||
"# - piper-tts provides the `piper` module (required by generate_samples.py)\n",
|
||
"# - piper-phonemize-cross does the phonemization\n",
|
||
"# - onnxruntime-gpu enables CUDA (container must have NVIDIA runtime)\n",
|
||
"deps = [\n",
|
||
" \"piper-tts>=1.2.0\",\n",
|
||
" \"piper-phonemize-cross==1.2.1\",\n",
|
||
" \"soundfile\",\n",
|
||
" \"numpy\",\n",
|
||
" \"onnxruntime-gpu>=1.16.0\",\n",
|
||
"]\n",
|
||
"pip_install(*deps)\n",
|
||
"\n",
|
||
"# 3) Verify CUDA provider is available\n",
|
||
"try:\n",
|
||
" import onnxruntime as ort\n",
|
||
" providers = ort.get_available_providers()\n",
|
||
" print(f\"ONNX Runtime providers: {providers}\")\n",
|
||
" if \"CUDAExecutionProvider\" not in providers:\n",
|
||
" print(\"⚠️ CUDAExecutionProvider not available. \"\n",
|
||
" \"The sample will still run on CPU, but check your NVIDIA container setup \"\n",
|
||
" \"(nvidia-container-toolkit, runtime, and driver).\")\n",
|
||
"except Exception as e:\n",
|
||
" print(\"⚠️ Could not import onnxruntime to verify providers:\", e)\n",
|
||
"\n",
|
||
"# 4) Ensure model present\n",
|
||
"ensure_model()\n",
|
||
"\n",
|
||
"# 5) Generate one sample\n",
|
||
"AUDIO_OUT_DIR.mkdir(parents=True, exist_ok=True)\n",
|
||
"gen_script = REPO_DIR / \"generate_samples.py\"\n",
|
||
"if not gen_script.exists():\n",
|
||
" raise FileNotFoundError(f\"Missing generator: {gen_script}\")\n",
|
||
"\n",
|
||
"cmd = [\n",
|
||
" sys.executable, str(gen_script),\n",
|
||
" TARGET_WORD,\n",
|
||
" \"--model\", str(MODELS_DIR / MODEL_NAME), # ← pass the generator .pt explicitly\n",
|
||
" \"--max-samples\", \"1\",\n",
|
||
" \"--batch-size\", \"1\",\n",
|
||
" \"--output-dir\", str(AUDIO_OUT_DIR),\n",
|
||
"]\n",
|
||
"run(cmd)\n",
|
||
"\n",
|
||
"# 6) Play the audio (if the notebook frontend supports it)\n",
|
||
"if AUDIO_PATH.exists():\n",
|
||
" print(f\"🎧 Playing {AUDIO_PATH}\")\n",
|
||
" display(Audio(str(AUDIO_PATH), autoplay=True))\n",
|
||
"else:\n",
|
||
" print(f\"Audio file not found at {AUDIO_PATH}\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"id": "-SvGtCCM9akR"
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"# Generate a large number of wake word samples for training\n",
|
||
"import sys, subprocess\n",
|
||
"from pathlib import Path\n",
|
||
"\n",
|
||
"REPO_DIR = Path.cwd() / \"piper-sample-generator\"\n",
|
||
"MODELS_DIR = REPO_DIR / \"models\"\n",
|
||
"MODEL_NAME = \"en_US-libritts_r-medium.pt\"\n",
|
||
"\n",
|
||
"cmd = [\n",
|
||
" sys.executable,\n",
|
||
" str(REPO_DIR / \"generate_samples.py\"),\n",
|
||
" TARGET_WORD,\n",
|
||
" \"--model\", str(MODELS_DIR / MODEL_NAME),\n",
|
||
" \"--max-samples\", \"50000\",\n",
|
||
" \"--batch-size\", \"100\",\n",
|
||
" \"--output-dir\", \"generated_samples\",\n",
|
||
"]\n",
|
||
"\n",
|
||
"print(\"→\", \" \".join(cmd))\n",
|
||
"subprocess.run(cmd, check=True)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"id": "YJRG4Qvo9nXG"
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"# NVIDIA/Linux dataset prep to match the Apple behavior, but with pinned AudioSet\n",
|
||
"# MIT RIR -> resample to 16 kHz\n",
|
||
"# AudioSet -> fetch from a working HF revision, convert to 16 kHz mono, skip bad\n",
|
||
"# FMA -> resample to 16 kHz mono\n",
|
||
"\n",
|
||
"import os, sys, subprocess, scipy.io.wavfile, numpy as np\n",
|
||
"from pathlib import Path\n",
|
||
"from tqdm import tqdm\n",
|
||
"import soundfile as sf\n",
|
||
"import librosa\n",
|
||
"from datasets import load_dataset\n",
|
||
"\n",
|
||
"# -------------------------------------------------\n",
|
||
"# small shell helpers (for curl/tar probing)\n",
|
||
"# -------------------------------------------------\n",
|
||
"def sh(cmd: str) -> int:\n",
|
||
" return subprocess.call(cmd, shell=True)\n",
|
||
"\n",
|
||
"def curl(url: str, out: Path) -> int:\n",
|
||
" # -L follow, -s silent, --fail to get nonzero on 404\n",
|
||
" return subprocess.call(f\"curl -L -s --fail '{url}' -o '{out}'\", shell=True)\n",
|
||
"\n",
|
||
"def write_wav(dst: Path, data: np.ndarray, sr: int):\n",
|
||
" x = np.clip(data, -1.0, 1.0)\n",
|
||
" scipy.io.wavfile.write(dst, sr, (x * 32767).astype(np.int16))\n",
|
||
"\n",
|
||
"# -----------------------------\n",
|
||
"# MIT RIR (resample to 16 kHz)\n",
|
||
"# -----------------------------\n",
|
||
"print(\"=== MIT RIR ===\")\n",
|
||
"rir_out = Path(\"mit_rirs\")\n",
|
||
"rir_out.mkdir(exist_ok=True)\n",
|
||
"if not any(rir_out.rglob(\"*.wav\")):\n",
|
||
" ok = 0\n",
|
||
" try:\n",
|
||
" # Avoid datasets.Audio to keep TorchCodec out:\n",
|
||
" # Use streaming=True + manual decode with librosa\n",
|
||
" print(\"⬇️ MIT RIR (streaming + manual decode)…\")\n",
|
||
" ds = load_dataset(\n",
|
||
" \"davidscripka/MIT_environmental_impulse_responses\",\n",
|
||
" split=\"train\",\n",
|
||
" streaming=True\n",
|
||
" )\n",
|
||
" for i, row in enumerate(tqdm(ds)):\n",
|
||
" try:\n",
|
||
" audio_path = row[\"audio\"][\"path\"]\n",
|
||
" y, sr = librosa.load(audio_path, sr=16000, mono=True)\n",
|
||
" write_wav(rir_out / f\"rir_{i:04d}.wav\", y, 16000)\n",
|
||
" ok += 1\n",
|
||
" except Exception:\n",
|
||
" pass\n",
|
||
" print(f\"✅ MIT RIR saved: {ok} files\")\n",
|
||
" except Exception as e:\n",
|
||
" print(f\"⚠️ MIT RIR download failed: {e}\")\n",
|
||
" # Fallback ZIP route\n",
|
||
" try:\n",
|
||
" print(\"⬇️ MIT RIR (fallback ZIP)…\")\n",
|
||
" zip_url = \"https://mcdermottlab.mit.edu/Reverb/IRMAudio/Audio.zip\"\n",
|
||
" zip_path = rir_out.parent / \"MIT_RIR_Audio.zip\"\n",
|
||
" if not zip_path.exists():\n",
|
||
" os.system(f\"wget -q -O '{zip_path}' '{zip_url}'\")\n",
|
||
" os.system(f'unzip -q -o \"{zip_path}\" -d \"{rir_out}\"')\n",
|
||
" # Normalize to 16k mono\n",
|
||
" for p in tqdm(list(rir_out.rglob(\"*.wav\")), desc=\"Normalize MIT RIR\"):\n",
|
||
" a, sr = sf.read(p, always_2d=False)\n",
|
||
" if a.ndim > 1:\n",
|
||
" a = a[:, 0]\n",
|
||
" if sr != 16000:\n",
|
||
" a, _ = librosa.load(p, sr=16000, mono=True)\n",
|
||
" write_wav(p, a, 16000)\n",
|
||
" print(\"✅ MIT RIR fallback complete\")\n",
|
||
" except Exception as e2:\n",
|
||
" print(f\"❌ MIT RIR fallback failed: {e2}\")\n",
|
||
"else:\n",
|
||
" print(\"✅ mit_rirs exists; skipping.\")\n",
|
||
"\n",
|
||
"# ============================================================\n",
|
||
"# AudioSet (pinned FLAC .tar → 16k mono, skip bad files)\n",
|
||
"# ============================================================\n",
|
||
"print(\"\\n=== AudioSet subset (pinned FLAC .tar → 16k mono) ===\")\n",
|
||
"audioset_dir = Path(\"audioset\"); audioset_dir.mkdir(exist_ok=True)\n",
|
||
"audioset_out = Path(\"audioset_16k\"); audioset_out.mkdir(exist_ok=True)\n",
|
||
"\n",
|
||
"if any(audioset_out.rglob(\"*.wav\")):\n",
|
||
" print(\"✅ audioset_16k exists; skipping.\")\n",
|
||
"else:\n",
|
||
" # commits / refs we know about — we’ll probe them\n",
|
||
" REV_CANDIDATES = [\n",
|
||
" \"6762f044d1c88619c7f2006486036192128fb07e\",\n",
|
||
" \"0049167e89f259a010c3f070fe3666d9e5242836\",\n",
|
||
" \"ceb9eaaa7844c9ad7351e659c84a572e376ad06d\",\n",
|
||
" \"main\", # last resort\n",
|
||
" ]\n",
|
||
" # possible folder layouts\n",
|
||
" TAR_PATTERNS = [\n",
|
||
" \"data/bal_train0{idx}.tar\",\n",
|
||
" \"data/bal_train/bal_train0{idx}.tar\",\n",
|
||
" ]\n",
|
||
"\n",
|
||
" def find_working_rev():\n",
|
||
" for rev in REV_CANDIDATES:\n",
|
||
" for pat in TAR_PATTERNS:\n",
|
||
" probe = f\"https://huggingface.co/datasets/agkphysics/AudioSet/resolve/{rev}/{pat.format(idx=0)}\"\n",
|
||
" rc = sh(f\"curl -I -L --fail -s '{probe}' > /dev/null\")\n",
|
||
" if rc == 0:\n",
|
||
" return rev, pat\n",
|
||
" return None, None\n",
|
||
"\n",
|
||
" rev, pattern = find_working_rev()\n",
|
||
" if rev is None:\n",
|
||
" raise RuntimeError(\"Could not locate an AudioSet revision with FLAC tarballs still present on HF.\")\n",
|
||
"\n",
|
||
" print(f\"📌 Using AudioSet revision: {rev}\")\n",
|
||
" print(f\"🗂️ Tar layout pattern: {pattern}\")\n",
|
||
"\n",
|
||
" # download + extract bal_train00..09\n",
|
||
" for i in range(10):\n",
|
||
" rel = pattern.format(idx=i)\n",
|
||
" url = f\"https://huggingface.co/datasets/agkphysics/AudioSet/resolve/{rev}/{rel}\"\n",
|
||
" fname = rel.split(\"/\")[-1]\n",
|
||
" out_tar = audioset_dir / fname\n",
|
||
" if not out_tar.exists():\n",
|
||
" print(f\"⬇️ {fname}\")\n",
|
||
" rc = curl(url, out_tar)\n",
|
||
" if rc != 0:\n",
|
||
" print(f\"⚠️ Could not fetch {fname} at rev {rev}; continuing.\")\n",
|
||
" continue\n",
|
||
" print(f\"📦 Extract {fname}\")\n",
|
||
" rc = sh(f\"tar -xf '{out_tar}' -C '{audioset_dir}'\")\n",
|
||
" if rc != 0:\n",
|
||
" print(f\"⚠️ tar extract failed for {fname}; continuing.\")\n",
|
||
"\n",
|
||
" # convert FLAC → 16k mono WAV\n",
|
||
" flacs = list(audioset_dir.rglob(\"*.flac\"))\n",
|
||
" print(f\"🔎 FLAC files: {len(flacs)}\")\n",
|
||
" audioset_bad = []\n",
|
||
" ok = 0\n",
|
||
" for p in tqdm(flacs, desc=\"AudioSet→WAV (resample 16k mono)\"):\n",
|
||
" try:\n",
|
||
" y, _ = librosa.load(p, sr=16000, mono=True)\n",
|
||
" if y.size == 0:\n",
|
||
" raise ValueError(\"empty audio\")\n",
|
||
" write_wav(audioset_out / (p.stem + \".wav\"), y, 16000)\n",
|
||
" ok += 1\n",
|
||
" except Exception as e:\n",
|
||
" audioset_bad.append(f\"{p}:{e}\")\n",
|
||
"\n",
|
||
" if audioset_bad:\n",
|
||
" (audioset_out / \"audioset_corrupted_files.log\").write_text(\"\\n\".join(audioset_bad))\n",
|
||
" print(f\"✅ AudioSet complete ({ok} ok, {len(audioset_bad)} failed)\")\n",
|
||
"\n",
|
||
"# -----------------------------\n",
|
||
"# FMA xsmall (resample to 16 kHz mono)\n",
|
||
"# -----------------------------\n",
|
||
"print(\"\\n=== FMA xsmall ===\")\n",
|
||
"fma_zip_dir = Path(\"fma\"); fma_zip_dir.mkdir(exist_ok=True)\n",
|
||
"fma_out = Path(\"fma_16k\"); fma_out.mkdir(exist_ok=True)\n",
|
||
"\n",
|
||
"zipname = \"fma_xs.zip\"\n",
|
||
"zipurl = f\"https://huggingface.co/datasets/mchl914/fma_xsmall/resolve/main/{zipname}\"\n",
|
||
"zipout = fma_zip_dir / zipname\n",
|
||
"if not zipout.exists():\n",
|
||
" os.system(f\"wget -q -O '{zipout}' '{zipurl}'\")\n",
|
||
" os.system(f\"cd fma && unzip -q '{zipname}'\")\n",
|
||
"\n",
|
||
"mp3s = list(Path(\"fma/fma_small\").rglob(\"*.mp3\"))\n",
|
||
"print(f\"🎵 FMA mp3 count: {len(mp3s)}\")\n",
|
||
"corrupt = []\n",
|
||
"for p in tqdm(mp3s, desc=\"FMA→16k WAV\"):\n",
|
||
" try:\n",
|
||
" y, sr = librosa.load(p, sr=16000, mono=True)\n",
|
||
" if y.size == 0:\n",
|
||
" raise ValueError(\"empty audio\")\n",
|
||
" write_wav(fma_out / (p.stem + \".wav\"), y, 16000)\n",
|
||
" except Exception as e:\n",
|
||
" corrupt.append(f\"{p}:{e}\")\n",
|
||
"if corrupt:\n",
|
||
" Path(\"fma_corrupted_files.log\").write_text(\"\\n\".join(corrupt))\n",
|
||
"\n",
|
||
"print(\"\\n✅ Dataset prep complete!\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"id": "XW3bmbI5-JAz"
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"# Sets up the augmentations.\n",
|
||
"# To improve your model, experiment with these settings and use more sources of\n",
|
||
"# background clips.\n",
|
||
"\n",
|
||
"import os\n",
|
||
"from microwakeword.audio.augmentation import Augmentation\n",
|
||
"from microwakeword.audio.clips import Clips\n",
|
||
"from microwakeword.audio.spectrograms import SpectrogramGeneration\n",
|
||
"\n",
|
||
"def validate_directories(paths):\n",
|
||
" for path in paths:\n",
|
||
" if not os.path.exists(path):\n",
|
||
" print(f\"Error: Directory {path} does not exist. Please ensure preprocessing is complete.\")\n",
|
||
" return False\n",
|
||
" return True\n",
|
||
"\n",
|
||
"# Paths to augmentation source data (room impulse responses and background audio)\n",
|
||
"impulse_paths = ['mit_rirs']\n",
|
||
"background_paths = ['fma_16k', 'audioset_16k']\n",
|
||
"\n",
|
||
"if not validate_directories(impulse_paths + background_paths):\n",
|
||
" raise ValueError(\"One or more required directories are missing.\")\n",
|
||
"\n",
|
||
"clips = Clips(\n",
|
||
" input_directory='./generated_samples',\n",
|
||
" file_pattern='*.wav',\n",
|
||
" max_clip_duration_s=5,\n",
|
||
" remove_silence=True,\n",
|
||
" random_split_seed=10,\n",
|
||
" split_count=0.1,\n",
|
||
")\n",
|
||
"\n",
|
||
"augmenter = Augmentation(\n",
|
||
" augmentation_duration_s=3.2,\n",
|
||
" augmentation_probabilities={\n",
|
||
" \"SevenBandParametricEQ\": 0.1,\n",
|
||
" \"TanhDistortion\": 0.05,\n",
|
||
" \"PitchShift\": 0.15,\n",
|
||
" \"BandStopFilter\": 0.1,\n",
|
||
" \"AddColorNoise\": 0.1,\n",
|
||
" \"AddBackgroundNoise\": 0.7,\n",
|
||
" \"Gain\": 0.8,\n",
|
||
" \"RIR\": 0.7,\n",
|
||
" },\n",
|
||
" impulse_paths=impulse_paths,\n",
|
||
" background_paths=background_paths,\n",
|
||
" background_min_snr_db=5,\n",
|
||
" background_max_snr_db=10,\n",
|
||
" min_jitter_s=0.2,\n",
|
||
" max_jitter_s=0.3,\n",
|
||
")\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"id": "V5UsJfKKD1k9"
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"# Augment a random generated-sample WAV and play it back (pass ndarray to augmenter)\n",
|
||
"from pathlib import Path\n",
|
||
"from IPython.display import Audio, display\n",
|
||
"import numpy as np\n",
|
||
"import soundfile as sf\n",
|
||
"import librosa, random, glob\n",
|
||
"\n",
|
||
"output_dir = Path(\"./augmented_clips\")\n",
|
||
"output_dir.mkdir(exist_ok=True)\n",
|
||
"\n",
|
||
"# 1) Pick a random WAV from the Piper outputs\n",
|
||
"candidates = glob.glob(\"generated_samples/*.wav\")\n",
|
||
"if not candidates:\n",
|
||
" raise SystemExit(\"No files in generated_samples/. Run the TTS sample cell first.\")\n",
|
||
"src_path = random.choice(candidates)\n",
|
||
"\n",
|
||
"# 2) Load as 16 kHz mono float32\n",
|
||
"y, sr = librosa.load(src_path, sr=16000, mono=True)\n",
|
||
"y = y.astype(np.float32, copy=False)\n",
|
||
"\n",
|
||
"# 3) Augment — microwakeword Augmentation expects a 1-D numpy array\n",
|
||
"try:\n",
|
||
" y_aug = augmenter.augment_clip(y)\n",
|
||
"except Exception as e:\n",
|
||
" # some versions accept (samples, sr) — try that as a fallback\n",
|
||
" try:\n",
|
||
" y_aug = augmenter.augment_clip((y, sr))\n",
|
||
" except Exception:\n",
|
||
" raise\n",
|
||
"\n",
|
||
"# 4) Save and play\n",
|
||
"out_path = output_dir / \"augmented_clip.wav\"\n",
|
||
"sf.write(str(out_path), y_aug.astype(np.float32, copy=False), sr, subtype=\"PCM_16\")\n",
|
||
"print(f\"Augmented clip saved to {out_path}\")\n",
|
||
"display(Audio(str(out_path), autoplay=True))"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"id": "D7BHcY1mEGbK"
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"# Augment samples and save the training, validation, and testing sets.\n",
|
||
"# This version avoids datasets.Audio entirely by driving Clips from local WAVs.\n",
|
||
"\n",
|
||
"import os, glob, random\n",
|
||
"from pathlib import Path\n",
|
||
"import types\n",
|
||
"import numpy as np\n",
|
||
"import librosa\n",
|
||
"from mmap_ninja.ragged import RaggedMmap\n",
|
||
"from microwakeword.audio.spectrograms import SpectrogramGeneration\n",
|
||
"\n",
|
||
"# ---- Patch: drive clips from generated_samples/*.wav (no datasets.Audio, no torchcodec) ----\n",
|
||
"def audio_generator_from_wavs(self, split=\"train\", repeat=1):\n",
|
||
" \"\"\"\n",
|
||
" Yield 1-D float32 arrays loaded via librosa from generated_samples/*.wav.\n",
|
||
" Deterministic 80/10/10 split with seed 10 to mirror original Clips behavior.\n",
|
||
" \"\"\"\n",
|
||
" files = sorted(glob.glob(\"generated_samples/*.wav\"))\n",
|
||
" if not files:\n",
|
||
" raise SystemExit(\"❌ No WAVs in generated_samples/. Generate TTS samples first.\")\n",
|
||
"\n",
|
||
" rng = random.Random(10) # deterministic shuffling like Clips(random_split_seed=10)\n",
|
||
" files_shuf = files[:]\n",
|
||
" rng.shuffle(files_shuf)\n",
|
||
"\n",
|
||
" n = len(files_shuf)\n",
|
||
" n_val = max(1, int(0.10 * n))\n",
|
||
" n_test = max(1, int(0.10 * n))\n",
|
||
" n_train = max(0, n - n_val - n_test)\n",
|
||
" splits = {\n",
|
||
" \"train\": files_shuf[:n_train],\n",
|
||
" \"validation\": files_shuf[n_train:n_train + n_val],\n",
|
||
" \"test\": files_shuf[n_train + n_val:],\n",
|
||
" }\n",
|
||
" file_list = splits.get(split, [])\n",
|
||
" if not file_list:\n",
|
||
" return # nothing to yield\n",
|
||
"\n",
|
||
" for _ in range(max(1, int(repeat))):\n",
|
||
" for p in file_list:\n",
|
||
" y, sr = librosa.load(p, sr=16000, mono=True)\n",
|
||
" yield y.astype(np.float32, copy=False)\n",
|
||
"\n",
|
||
"# Bind the patched generator to your existing `clips` instance\n",
|
||
"clips.audio_generator = types.MethodType(audio_generator_from_wavs, clips)\n",
|
||
"print(\"✅ Patched clips.audio_generator to stream from generated_samples/*.wav (no torchcodec).\")\n",
|
||
"\n",
|
||
"# ---- Validate augmentation asset folders exist ----\n",
|
||
"def validate(paths):\n",
|
||
" for p in paths:\n",
|
||
" if not Path(p).exists():\n",
|
||
" raise SystemExit(f\"❌ Missing directory: {p}. Run dataset prep first.\")\n",
|
||
"\n",
|
||
"impulse_paths = [\"mit_rirs\"]\n",
|
||
"background_paths = [\"fma_16k\", \"audioset_16k\"]\n",
|
||
"validate(impulse_paths + background_paths)\n",
|
||
"\n",
|
||
"# ---- Output root ----\n",
|
||
"out_root = Path(\"generated_augmented_features\")\n",
|
||
"out_root.mkdir(exist_ok=True)\n",
|
||
"\n",
|
||
"# ---- Split config (same as before) ----\n",
|
||
"split_cfg = {\n",
|
||
" \"training\": {\"name\": \"train\", \"repetition\": 2, \"slide_frames\": 10},\n",
|
||
" \"validation\": {\"name\": \"validation\", \"repetition\": 1, \"slide_frames\": 10},\n",
|
||
" \"testing\": {\"name\": \"test\", \"repetition\": 1, \"slide_frames\": 1},\n",
|
||
"}\n",
|
||
"\n",
|
||
"# ---- Generate features ----\n",
|
||
"for split, cfg in split_cfg.items():\n",
|
||
" out_dir = out_root / split\n",
|
||
" out_dir.mkdir(parents=True, exist_ok=True)\n",
|
||
" print(f\"🧪 Processing {split} …\")\n",
|
||
"\n",
|
||
" spectros = SpectrogramGeneration(\n",
|
||
" clips=clips, # now backed by our WAV loader\n",
|
||
" augmenter=augmenter, # your existing augmenter\n",
|
||
" slide_frames=cfg[\"slide_frames\"],\n",
|
||
" step_ms=10,\n",
|
||
" )\n",
|
||
"\n",
|
||
" RaggedMmap.from_generator(\n",
|
||
" out_dir=str(out_dir / \"wakeword_mmap\"),\n",
|
||
" sample_generator=spectros.spectrogram_generator(\n",
|
||
" split=cfg[\"name\"], repeat=cfg[\"repetition\"]\n",
|
||
" ),\n",
|
||
" batch_size=100,\n",
|
||
" verbose=True,\n",
|
||
" )\n",
|
||
"\n",
|
||
"print(\"✅ Features ready (generated_augmented_features/*/wakeword_mmap)\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"id": "1pGuJDPyp3ax"
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"# Downloads pre-generated spectrogram features (made for microWakeWord in\n",
|
||
"# particular) for various negative datasets. This can be slow!\n",
|
||
"\n",
|
||
"import os\n",
|
||
"import requests\n",
|
||
"import zipfile\n",
|
||
"from pathlib import Path\n",
|
||
"from tqdm import tqdm\n",
|
||
"\n",
|
||
"# Function to download a file with progress bar\n",
|
||
"def download_file(url, output_path):\n",
|
||
" response = requests.get(url, stream=True)\n",
|
||
" total_size = int(response.headers.get('content-length', 0))\n",
|
||
" with open(output_path, \"wb\") as f, tqdm(\n",
|
||
" desc=f\"Downloading {output_path.name}\",\n",
|
||
" total=total_size,\n",
|
||
" unit=\"B\",\n",
|
||
" unit_scale=True,\n",
|
||
" unit_divisor=1024,\n",
|
||
" ) as bar:\n",
|
||
" for chunk in response.iter_content(chunk_size=1024):\n",
|
||
" f.write(chunk)\n",
|
||
" bar.update(len(chunk))\n",
|
||
" print(f\"Downloaded: {output_path}\")\n",
|
||
"\n",
|
||
"# Function to extract ZIP files\n",
|
||
"def extract_zip(zip_path, extract_to):\n",
|
||
" with zipfile.ZipFile(zip_path, 'r') as zip_ref:\n",
|
||
" zip_ref.extractall(extract_to)\n",
|
||
" print(f\"Extracted: {zip_path} to {extract_to}\")\n",
|
||
"\n",
|
||
"# Directory for negative datasets\n",
|
||
"output_dir = Path('./negative_datasets')\n",
|
||
"output_dir.mkdir(exist_ok=True)\n",
|
||
"\n",
|
||
"# Negative dataset URLs\n",
|
||
"link_root = \"https://huggingface.co/datasets/kahrendt/microwakeword/resolve/main/\"\n",
|
||
"filenames = ['dinner_party.zip', 'dinner_party_eval.zip', 'no_speech.zip', 'speech.zip']\n",
|
||
"\n",
|
||
"# Download and extract files\n",
|
||
"for fname in filenames:\n",
|
||
" link = link_root + fname\n",
|
||
" zip_path = output_dir / fname\n",
|
||
"\n",
|
||
" # Download only if the file doesn't already exist\n",
|
||
" if not zip_path.exists():\n",
|
||
" try:\n",
|
||
" download_file(link, zip_path)\n",
|
||
" except Exception as e:\n",
|
||
" print(f\"Error downloading {fname}: {e}\")\n",
|
||
" continue\n",
|
||
"\n",
|
||
" # Extract the ZIP file\n",
|
||
" try:\n",
|
||
" extract_zip(zip_path, output_dir)\n",
|
||
" except Exception as e:\n",
|
||
" print(f\"Error extracting {fname}: {e}\")\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"id": "Ii1A14GsGVQT"
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"# GPU memory config (set env BEFORE importing TF)\n",
|
||
"import os, sys, gc\n",
|
||
"\n",
|
||
"if \"tensorflow\" not in sys.modules:\n",
|
||
" os.environ[\"TF_FORCE_GPU_ALLOW_GROWTH\"] = \"true\" # grow as needed\n",
|
||
" os.environ[\"TF_GPU_ALLOCATOR\"] = \"cuda_malloc_async\" # modern CUDA allocator\n",
|
||
" os.environ[\"XLA_FLAGS\"] = \"--xla_gpu_cuda_data_dir=/usr/local/cuda\"\n",
|
||
" os.environ[\"TF_XLA_FLAGS\"] = \"--tf_xla_auto_jit=0\" # disable XLA JIT (more stable mem)\n",
|
||
"import tensorflow as tf\n",
|
||
"\n",
|
||
"# Per-device memory growth (belt + suspenders)\n",
|
||
"for g in tf.config.list_physical_devices(\"GPU\"):\n",
|
||
" try:\n",
|
||
" tf.config.experimental.set_memory_growth(g, True)\n",
|
||
" except Exception:\n",
|
||
" pass\n",
|
||
"print(\"GPUs:\", tf.config.list_physical_devices(\"GPU\"))\n",
|
||
"gc.collect()\n",
|
||
"\n",
|
||
"# Optional but recommended: mixed precision halves activation memory\n",
|
||
"try:\n",
|
||
" from tensorflow.keras import mixed_precision\n",
|
||
" mixed_precision.set_global_policy(\"mixed_float16\")\n",
|
||
" print(\"Mixed precision policy:\", mixed_precision.global_policy())\n",
|
||
"except Exception as e:\n",
|
||
" print(\"Mixed precision not enabled:\", e)\n",
|
||
"\n",
|
||
"# --- Save a yaml config that controls the training process ---\n",
|
||
"\n",
|
||
"import yaml\n",
|
||
"\n",
|
||
"config = {}\n",
|
||
"\n",
|
||
"config[\"window_step_ms\"] = 10\n",
|
||
"config[\"train_dir\"] = \"trained_models/wakeword\"\n",
|
||
"\n",
|
||
"config[\"features\"] = [\n",
|
||
" {\"features_dir\":\"generated_augmented_features\",\"sampling_weight\":2.0,\"penalty_weight\":1.0,\"truth\":True,\"truncation_strategy\":\"truncate_start\",\"type\":\"mmap\"},\n",
|
||
" {\"features_dir\":\"negative_datasets/speech\",\"sampling_weight\":12.0,\"penalty_weight\":1.0,\"truth\":False,\"truncation_strategy\":\"random\",\"type\":\"mmap\"},\n",
|
||
" {\"features_dir\":\"negative_datasets/dinner_party\",\"sampling_weight\":12.0,\"penalty_weight\":1.0,\"truth\":False,\"truncation_strategy\":\"random\",\"type\":\"mmap\"},\n",
|
||
" {\"features_dir\":\"negative_datasets/no_speech\",\"sampling_weight\":5.0,\"penalty_weight\":1.0,\"truth\":False,\"truncation_strategy\":\"random\",\"type\":\"mmap\"},\n",
|
||
" {\"features_dir\":\"negative_datasets/dinner_party_eval\",\"sampling_weight\":0.0,\"penalty_weight\":1.0,\"truth\":False,\"truncation_strategy\":\"split\",\"type\":\"mmap\"},\n",
|
||
"]\n",
|
||
"\n",
|
||
"config[\"training_steps\"] = [40000]\n",
|
||
"config[\"positive_class_weight\"] = [1]\n",
|
||
"config[\"negative_class_weight\"] = [20]\n",
|
||
"config[\"learning_rates\"] = [0.001]\n",
|
||
"\n",
|
||
"# Smaller batch to avoid GPU copy/alloc failures on 3070 laptop VRAM\n",
|
||
"config[\"batch_size\"] = 16\n",
|
||
"\n",
|
||
"# SpecAugment disabled: every time/frequency mask size and count below is 0\n",
|
||
"config[\"time_mask_max_size\"] = [0]\n",
|
||
"config[\"time_mask_count\"] = [0]\n",
|
||
"config[\"freq_mask_max_size\"] = [0]\n",
|
||
"config[\"freq_mask_count\"] = [0]\n",
|
||
"\n",
|
||
"config[\"eval_step_interval\"] = 500\n",
|
||
"config[\"clip_duration_ms\"] = 1500\n",
|
||
"config[\"target_minimization\"] = 0.9\n",
|
||
"config[\"minimization_metric\"] = None\n",
|
||
"config[\"maximization_metric\"] = \"average_viable_recall\"\n",
|
||
"\n",
|
||
"with open(\"training_parameters.yaml\", \"w\") as f:\n",
|
||
" yaml.dump(config, f)\n",
|
||
"\n",
|
||
"print(\"✅ Wrote training_parameters.yaml (batch_size=16) with allow_growth, cuda_malloc_async, XLA JIT OFF, mixed precision ON.\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"id": "WoEXJBaiC9mf"
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"# Train + export (GPU-friendly env + stable flags)\n",
|
||
"\n",
|
||
"import os, sys\n",
|
||
"\n",
|
||
"# --- Runtime env, inherited by the training subprocess launched below; setdefault() leaves already-set variables (e.g. an existing LD_LIBRARY_PATH) unchanged ---\n",
|
||
"os.environ.setdefault(\"LD_LIBRARY_PATH\",\n",
|
||
" \"/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/lib/x86_64-linux-gnu:\" +\n",
|
||
" os.environ.get(\"LD_LIBRARY_PATH\",\"\")\n",
|
||
")\n",
|
||
"os.environ.setdefault(\"TF_CPP_MIN_LOG_LEVEL\", \"2\") # quieter logs\n",
|
||
"os.environ.setdefault(\"TF_FORCE_GPU_ALLOW_GROWTH\", \"true\") # grow VRAM as needed\n",
|
||
"os.environ.setdefault(\"TF_GPU_ALLOCATOR\", \"cuda_malloc_async\") # use the CUDA stream-ordered allocator (cudaMallocAsync)\n",
|
||
"os.environ.setdefault(\"XLA_FLAGS\", \"--xla_gpu_cuda_data_dir=/usr/local/cuda\")\n",
|
||
"os.environ.setdefault(\"TF_XLA_FLAGS\", \"--tf_xla_auto_jit=0\") # 0 = follow the session config (XLA auto-JIT is off by default); use -1 to force it off\n",
|
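||
"os.environ.setdefault(\"NVIDIA_TF32_OVERRIDE\", \"1\") # keep TF32 math enabled on Ampere+ (speed only, not a VRAM saving); NOTE(review): NVIDIA documents only =0 (disable), confirm =1 has any effect\n",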
|
||
"\n",
|
||
"# If you still hit GPU memory errors, uncomment to force a smaller workspace:\n",
|
||
"# os.environ[\"TF_CUDNN_WORKSPACE_LIMIT_IN_MB\"] = \"256\"\n",
|
||
"\n",
|
||
"# --- Kick off training ---\n",
|
||
"cmd = f'''\"{sys.executable}\" -m microwakeword.model_train_eval \\\n",
|
||
" --training_config=\"training_parameters.yaml\" \\\n",
|
||
" --train 1 \\\n",
|
||
" --restore_checkpoint 1 \\\n",
|
||
" --test_tf_nonstreaming 0 \\\n",
|
||
" --test_tflite_nonstreaming 0 \\\n",
|
||
" --test_tflite_nonstreaming_quantized 0 \\\n",
|
||
" --test_tflite_streaming 0 \\\n",
|
||
" --test_tflite_streaming_quantized 1 \\\n",
|
||
" --use_weights \"best_weights\" \\\n",
|
||
" mixednet \\\n",
|
||
" --pointwise_filters \"64,64,64,64\" \\\n",
|
||
" --repeat_in_block \"1,1,1,1\" \\\n",
|
||
" --mixconv_kernel_sizes \"[5], [7,11], [9,15], [23]\" \\\n",
|
||
" --residual_connection \"0,0,0,0\" \\\n",
|
||
" --first_conv_filters 32 \\\n",
|
||
" --first_conv_kernel_size 5 \\\n",
|
||
" --stride 2'''\n",
|
||
"print(\"Running:\\n\", cmd)\n",
|
||
"!$cmd"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"id": "ex_UIWvwtjAN"
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"import shutil\n",
|
||
"import json\n",
|
||
"from IPython.display import display, HTML\n",
|
||
"\n",
|
||
"# Reuse the wake word chosen in TARGET_WORD (set in the notebook's first code cell)\n",
|
||
"wake_word = TARGET_WORD\n",
|
||
"\n",
|
||
"# --- Copy TFLite file to working dir with wake word name ---\n",
|
||
"source_path = \"trained_models/wakeword/tflite_stream_state_internal_quant/stream_state_internal_quant.tflite\"\n",
|
||
"tflite_filename = f\"{wake_word}.tflite\"\n",
|
||
"tflite_path = f\"./{tflite_filename}\"\n",
|
||
"shutil.copy(source_path, tflite_path)\n",
|
||
"\n",
|
||
"# --- Write JSON metadata file with matching model name ---\n",
|
||
"json_data = {\n",
|
||
" \"type\": \"micro\",\n",
|
||
" \"wake_word\": wake_word,\n",
|
||
" \"author\": \"Tater Totterson\",\n",
|
||
" \"website\": \"https://github.com/TaterTotterson/microWakeWord-Trainer-Nvidia-Docker.git\",\n",
|
||
" \"model\": tflite_filename,\n",
|
||
" \"trained_languages\": [\"en\"],\n",
|
||
" \"version\": 2,\n",
|
||
" \"micro\": {\n",
|
||
" \"probability_cutoff\": 0.97,\n",
|
||
" \"sliding_window_size\": 5,\n",
|
||
" \"feature_step_size\": 10,\n",
|
||
" \"tensor_arena_size\": 30000,\n",
|
||
" \"minimum_esphome_version\": \"2024.7.0\"\n",
|
||
" }\n",
|
||
"}\n",
|
||
"json_filename = f\"{wake_word}.json\"\n",
|
||
"json_path = f\"./{json_filename}\"\n",
|
||
"with open(json_path, \"w\") as json_file:\n",
|
||
" json.dump(json_data, json_file, indent=2)\n",
|
||
"\n",
|
||
"# --- Display nice download links ---\n",
|
||
"html = f\"\"\"\n",
|
||
"<h3>Download your files:</h3>\n",
|
||
"<ul>\n",
|
||
" <li><a href=\"{tflite_filename}\" download>⬇️ Download Model ({tflite_filename})</a></li>\n",
|
||
" <li><a href=\"{json_filename}\" download>⬇️ Download Metadata ({json_filename})</a></li>\n",
|
||
"</ul>\n",
|
||
"\"\"\"\n",
|
||
"display(HTML(html))"
|
||
]
|
||
}
|
||
],
|
||
"metadata": {
|
||
"accelerator": "GPU",
|
||
"colab": {
|
||
"gpuType": "T4",
|
||
"provenance": []
|
||
},
|
||
"kernelspec": {
|
||
"display_name": "Python 3 (ipykernel)",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "3.10.12"
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 4
|
||
}
|