mirror of
https://github.com/TaterTotterson/microWakeWord-Trainer-Nvidia-Docker.git
synced 2026-06-12 20:10:19 -06:00
778 lines
30 KiB
Plaintext
778 lines
30 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "r11cNiLqvWC6"
|
|
},
|
|
"source": [
|
|
"<div align=\"center\">\n",
|
|
" <img src=\"https://raw.githubusercontent.com/MasterPhooey/MicroWakeWord-Trainer-Docker/refs/heads/main/mmw.png\" alt=\"MicroWakeWord Trainer Logo\" width=\"100\" />\n",
|
|
" <h1>MicroWakeWord Trainer Docker</h1>\n",
|
|
"</div>\n",
|
|
"\n",
|
|
"This notebook steps you through training a robust microWakeWord model. It is intended as a **starting point** for users looking to create a high-performance wake word detection model. This notebook is optimized for Python 3.10.\n",
|
|
"\n",
|
|
"**The model generated from this notebook is designed for practical use, but achieving optimal performance will require experimentation with various settings and datasets. The provided scripts and configurations aim to give you a strong foundation to build upon.**\n",
|
|
"\n",
|
|
"Throughout the notebook, you will find comments suggesting specific settings to modify and experiment with to enhance your model's performance.\n",
|
|
"\n",
|
|
"By the end of this notebook, you will have:\n",
|
|
"- A trained TensorFlow Lite model ready for deployment.\n",
|
|
"- A JSON manifest file to integrate the model with ESPHome.\n",
|
|
"\n",
|
|
"To use the generated model in ESPHome, refer to the [ESPHome documentation](https://esphome.io/components/micro_wake_word) for integration details. You can also explore example configurations in the [model repository](https://github.com/esphome/micro-wake-word-models/tree/main/models/v2)."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "BFf6511E65ff"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Installs microWakeWord. Be sure to restart the session after this is finished.\n",
|
|
"import platform\n",
|
|
"import sys\n",
|
|
"import os\n",
|
|
"\n",
|
|
"if platform.system() == \"Darwin\":\n",
|
|
" # `pymicro-features` is installed from a fork to support building on macOS\n",
|
|
" !\"{sys.executable}\" -m pip install 'git+https://github.com/puddly/pymicro-features@puddly/minimum-cpp-version' --root-user-action=ignore\n",
|
|
"\n",
|
|
"# `audio-metadata` is installed from a fork to unpin `attrs` from a version that breaks Jupyter\n",
|
|
"!\"{sys.executable}\" -m pip install 'git+https://github.com/whatsnowplaying/audio-metadata@d4ebb238e6a401bb1a5aaaac60c9e2b3cb30929f' --root-user-action=ignore\n",
|
|
"\n",
|
|
"# Clone the microWakeWord repository\n",
|
|
"repo_path = \"./microWakeWord\"\n",
|
|
"if not os.path.exists(repo_path):\n",
|
|
" print(\"Cloning microWakeWord repository...\")\n",
|
|
" !git clone https://github.com/kahrendt/microWakeWord.git {repo_path}\n",
|
|
"\n",
|
|
"# Ensure the repository exists before attempting to install\n",
|
|
"if os.path.exists(repo_path):\n",
|
|
" print(\"Installing microWakeWord...\")\n",
|
|
" !\"{sys.executable}\" -m pip install -e {repo_path} --root-user-action=ignore\n",
|
|
"else:\n",
|
|
" print(f\"Repository not found at {repo_path}. Cloning might have failed.\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "dEluu7nL7ywd"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Generates 1 sample of the target word for manual verification.\n",
|
|
"\n",
|
|
"target_word = 'hey_norman' # Phonetic spellings may produce better samples\n",
|
|
"\n",
|
|
"import os\n",
|
|
"import sys\n",
|
|
"import platform\n",
|
|
"\n",
|
|
"from IPython.display import Audio\n",
|
|
"\n",
|
|
"# Ensure the repository is cloned correctly\n",
|
|
"if not os.path.exists(\"./piper-sample-generator\"):\n",
|
|
" if platform.system() == \"Darwin\":\n",
|
|
" !git clone -b mps-support https://github.com/kahrendt/piper-sample-generator\n",
|
|
" else:\n",
|
|
" !git clone https://github.com/rhasspy/piper-sample-generator\n",
|
|
"\n",
|
|
"# Download the required model\n",
|
|
"if not os.path.exists(\"piper-sample-generator/models/en_US-libritts_r-medium.pt\"):\n",
|
|
" !wget -O piper-sample-generator/models/en_US-libritts_r-medium.pt 'https://github.com/rhasspy/piper-sample-generator/releases/download/v2.0.0/en_US-libritts_r-medium.pt'\n",
|
|
"\n",
|
|
"# Install Python dependencies (pip packages) needed by the sample generator\n",
|
|
"!\"{sys.executable}\" -m pip install torch torchaudio piper-phonemize-cross==1.2.1\n",
|
|
"\n",
|
|
"# Ensure the repository path is in sys.path\n",
|
|
"if \"piper-sample-generator/\" not in sys.path:\n",
|
|
" sys.path.append(\"piper-sample-generator/\")\n",
|
|
"\n",
|
|
"# Generate sample\n",
|
|
"!\"{sys.executable}\" piper-sample-generator/generate_samples.py \"{target_word}\" \\\n",
|
|
"--max-samples 1 \\\n",
|
|
"--batch-size 1 \\\n",
|
|
"--output-dir generated_samples\n",
|
|
"\n",
|
|
"# Play the generated audio sample\n",
|
|
"audio_path = \"generated_samples/0.wav\"\n",
|
|
"if os.path.exists(audio_path):\n",
|
|
" display(Audio(audio_path, autoplay=True))\n",
|
|
"else:\n",
|
|
" print(f\"Audio file not found at {audio_path}\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "-SvGtCCM9akR"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Generates a larger amount of wake word samples.\n",
|
|
"# Start here when trying to improve your model.\n",
|
|
"# See https://github.com/rhasspy/piper-sample-generator for the full set of\n",
|
|
"# parameters. In particular, experiment with noise-scales and noise-scale-ws,\n",
|
|
"# generating negative samples similar to the wake word, and generating many more\n",
|
|
"# wake word samples, possibly with different phonetic pronunciations.\n",
|
|
"\n",
|
|
"!\"{sys.executable}\" piper-sample-generator/generate_samples.py \"{target_word}\" \\\n",
|
|
"--max-samples 10000 \\\n",
|
|
"--batch-size 100 \\\n",
|
|
"--output-dir generated_samples"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "YJRG4Qvo9nXG"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Downloads audio data for augmentation. This can be slow!\n",
|
|
"# Borrowed from openWakeWord's automatic_model_training.ipynb, accessed March 4, 2024\n",
|
|
"#\n",
|
|
"# **Important note!** The data downloaded here has a mixture of different\n",
|
|
"# licenses and usage restrictions. As such, any custom models trained with this\n",
|
|
"# data should be considered as appropriate for **non-commercial** personal use only.\n",
|
|
"\n",
|
|
"import os\n",
|
|
"import scipy.io.wavfile\n",
|
|
"import numpy as np\n",
|
|
"import soundfile as sf\n",
|
|
"from pathlib import Path\n",
|
|
"from tqdm import tqdm\n",
|
|
"import requests\n",
|
|
"import tarfile\n",
|
|
"import zipfile\n",
|
|
"from datasets import load_dataset\n",
|
|
"\n",
|
|
"# Function to download and process RIR dataset\n",
|
|
"def download_rir_dataset(dataset_name, output_dir, split=\"train\"):\n",
|
|
" output_dir = Path(output_dir)\n",
|
|
" if not output_dir.exists():\n",
|
|
" output_dir.mkdir(parents=True, exist_ok=True)\n",
|
|
" try:\n",
|
|
" rir_dataset = load_dataset(dataset_name, split=split, streaming=True)\n",
|
|
" print(f\"Downloading {dataset_name} to {output_dir}...\")\n",
|
|
" for row in tqdm(rir_dataset):\n",
|
|
" name = Path(row['audio']['path']).name\n",
|
|
" scipy.io.wavfile.write(\n",
|
|
" output_dir / name,\n",
|
|
" 16000,\n",
|
|
" (row['audio']['array'] * 32767).astype(np.int16)\n",
|
|
" )\n",
|
|
" print(f\"Finished downloading {dataset_name} to {output_dir}.\\n\")\n",
|
|
" except Exception as e:\n",
|
|
" print(f\"Error downloading {dataset_name}: {e}\")\n",
|
|
" else:\n",
|
|
" print(f\"{output_dir} already exists. Skipping download.\\n\")\n",
|
|
"\n",
|
|
"# Download MIT RIRs\n",
|
|
"download_rir_dataset(\n",
|
|
" \"davidscripka/MIT_environmental_impulse_responses\",\n",
|
|
" \"./mit_rirs\"\n",
|
|
")\n",
|
|
"\n",
|
|
"# Function to download files\n",
|
|
"def download_file(url, output_path):\n",
|
|
" response = requests.get(url, stream=True)\n",
|
|
" total_size = int(response.headers.get('content-length', 0))\n",
|
|
" with open(output_path, \"wb\") as f, tqdm(\n",
|
|
" desc=f\"Downloading {output_path.name}\",\n",
|
|
" total=total_size,\n",
|
|
" unit=\"B\",\n",
|
|
" unit_scale=True,\n",
|
|
" unit_divisor=1024,\n",
|
|
" ) as bar:\n",
|
|
" for chunk in response.iter_content(chunk_size=1024):\n",
|
|
" f.write(chunk)\n",
|
|
" bar.update(len(chunk))\n",
|
|
" print(f\"Downloaded {output_path}\")\n",
|
|
"\n",
|
|
"# Function to extract .tar files\n",
|
|
"def extract_tar(file_path, extract_dir):\n",
|
|
" with tarfile.open(file_path, \"r\") as tar:\n",
|
|
" tar.extractall(path=extract_dir)\n",
|
|
" print(f\"Extracted {file_path} to {extract_dir}\")\n",
|
|
"\n",
|
|
"# Function to download and extract ZIP files\n",
|
|
"def download_and_extract_zip(url, extract_to):\n",
|
|
" file_name = url.split(\"/\")[-1]\n",
|
|
" local_path = Path(extract_to) / file_name\n",
|
|
" download_file(url, local_path)\n",
|
|
" with zipfile.ZipFile(local_path, 'r') as zip_ref:\n",
|
|
" zip_ref.extractall(extract_to)\n",
|
|
" print(f\"Extracted {file_name} to {extract_to}\")\n",
|
|
"\n",
|
|
"# Directories\n",
|
|
"raw_dir = Path(\"./audioset_raw\")\n",
|
|
"processed_dir = Path(\"./audioset_16k\")\n",
|
|
"raw_dir.mkdir(exist_ok=True)\n",
|
|
"processed_dir.mkdir(exist_ok=True)\n",
|
|
"\n",
|
|
"# Full-scale dataset download links\n",
|
|
"dataset_links = [\n",
|
|
" f\"https://huggingface.co/datasets/agkphysics/AudioSet/resolve/main/data/bal_train0{i}.tar\"\n",
|
|
" for i in range(10) # Adjust for additional parts\n",
|
|
"]\n",
|
|
"\n",
|
|
"# Step 1: Download all parts of the dataset\n",
|
|
"print(\"Downloading datasets...\")\n",
|
|
"for link in dataset_links:\n",
|
|
" file_name = link.split(\"/\")[-1]\n",
|
|
" output_path = raw_dir / file_name\n",
|
|
" if not output_path.exists():\n",
|
|
" download_file(link, output_path)\n",
|
|
"\n",
|
|
"# Step 2: Extract all .tar files\n",
|
|
"print(\"Extracting datasets...\")\n",
|
|
"for file_path in raw_dir.glob(\"*.tar\"):\n",
|
|
" extract_dir = raw_dir / file_path.stem\n",
|
|
" if not extract_dir.exists():\n",
|
|
" extract_tar(file_path, extract_dir)\n",
|
|
"\n",
|
|
"# Step 3: Convert audio files to 16kHz WAV\n",
|
|
"audio_files = list(Path(raw_dir).glob(\"**/*.flac\"))\n",
|
|
"print(f\"Number of FLAC files found: {len(audio_files)}\")\n",
|
|
"if not audio_files:\n",
|
|
" raise FileNotFoundError(\"No .flac files found in the raw directories. Check your dataset extraction.\")\n",
|
|
"\n",
|
|
"print(\"Converting audio files to 16kHz WAV...\")\n",
|
|
"corrupted_files = []\n",
|
|
"resampled_files = []\n",
|
|
"\n",
|
|
"for file_path in tqdm(audio_files, desc=\"Processing audio files\"):\n",
|
|
" try:\n",
|
|
" # Read the .flac file\n",
|
|
" data, samplerate = sf.read(file_path)\n",
|
|
"\n",
|
|
" # Check and resample if needed\n",
|
|
" if samplerate != 16000:\n",
|
|
" resampled_files.append(str(file_path))\n",
|
|
" data = np.interp(\n",
|
|
" np.linspace(0, len(data), int(len(data) * 16000 / samplerate)),\n",
|
|
" np.arange(len(data)),\n",
|
|
" data,\n",
|
|
" )\n",
|
|
"\n",
|
|
" # Convert and save as WAV\n",
|
|
" output_path = processed_dir / file_path.name.replace(\".flac\", \".wav\")\n",
|
|
" scipy.io.wavfile.write(\n",
|
|
" output_path,\n",
|
|
" 16000,\n",
|
|
" (data * 32767).astype(np.int16),\n",
|
|
" )\n",
|
|
" except Exception as e:\n",
|
|
" corrupted_files.append(str(file_path))\n",
|
|
"\n",
|
|
"# Log corrupted files\n",
|
|
"if corrupted_files:\n",
|
|
" with open(\"corrupted_files.log\", \"w\") as log_file:\n",
|
|
" log_file.writelines(f\"{file}\\n\" for file in corrupted_files)\n",
|
|
"\n",
|
|
"# Log resampled files\n",
|
|
"if resampled_files:\n",
|
|
" with open(\"resampled_files.log\", \"w\") as log_file:\n",
|
|
" log_file.writelines(f\"{file}\\n\" for file in resampled_files)\n",
|
|
"\n",
|
|
"print(f\"Audio conversion complete! {len(corrupted_files)} files corrupted and logged.\")\n",
|
|
"print(f\"{len(resampled_files)} files resampled and logged.\")\n",
|
|
"\n",
|
|
"# Process fma_xs dataset\n",
|
|
"fma_raw_dir = Path(\"./fma\")\n",
|
|
"fma_processed_dir = Path(\"./fma_16k\") # Separate directory for fma_xs processed files\n",
|
|
"fma_raw_dir.mkdir(exist_ok=True)\n",
|
|
"fma_processed_dir.mkdir(exist_ok=True)\n",
|
|
"\n",
|
|
"fma_link = \"https://huggingface.co/datasets/mchl914/fma_xsmall/resolve/main/fma_xs.zip\"\n",
|
|
"fma_zip_path = fma_raw_dir / \"fma_xs.zip\"\n",
|
|
"\n",
|
|
"if not fma_zip_path.exists():\n",
|
|
" print(\"Downloading fma_xs dataset...\")\n",
|
|
" download_file(fma_link, fma_zip_path)\n",
|
|
"\n",
|
|
"print(\"Extracting fma_xs dataset...\")\n",
|
|
"download_and_extract_zip(fma_link, fma_raw_dir)\n",
|
|
"\n",
|
|
"fma_audio_files = list(fma_raw_dir.glob(\"**/*.mp3\"))\n",
|
|
"print(f\"Number of MP3 files found: {len(fma_audio_files)}\")\n",
|
|
"if not fma_audio_files:\n",
|
|
" raise FileNotFoundError(\"No .mp3 files found in the fma directory. Check your dataset extraction.\")\n",
|
|
"\n",
|
|
"print(\"Converting fma_xs files to 16kHz WAV...\")\n",
|
|
"for file_path in tqdm(fma_audio_files, desc=\"Processing fma_xs files\"):\n",
|
|
" try:\n",
|
|
" # Read the .mp3 file\n",
|
|
" data, samplerate = sf.read(file_path)\n",
|
|
"\n",
|
|
" # Check and resample if needed\n",
|
|
" if samplerate != 16000:\n",
|
|
" data = np.interp(\n",
|
|
" np.linspace(0, len(data), int(len(data) * 16000 / samplerate)),\n",
|
|
" np.arange(len(data)),\n",
|
|
" data,\n",
|
|
" )\n",
|
|
"\n",
|
|
" # Convert and save as WAV\n",
|
|
" output_path = fma_processed_dir / file_path.name.replace(\".mp3\", \".wav\")\n",
|
|
" scipy.io.wavfile.write(\n",
|
|
" output_path,\n",
|
|
" 16000,\n",
|
|
" (data * 32767).astype(np.int16),\n",
|
|
" )\n",
|
|
" except Exception as e:\n",
|
|
" corrupted_files.append(str(file_path))\n",
|
|
"\n",
|
|
"# Log corrupted files from fma_xs\n",
|
|
"if corrupted_files:\n",
|
|
" with open(\"fma_corrupted_files.log\", \"w\") as log_file:\n",
|
|
" log_file.writelines(f\"{file}\\n\" for file in corrupted_files)\n",
|
|
"\n",
|
|
"print(\"fma_xs processing complete!\")\n",
|
|
"print(\"Full-scale dataset preparation complete!\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "XW3bmbI5-JAz"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Sets up the augmentations.\n",
|
|
"# To improve your model, experiment with these settings and use more sources of\n",
|
|
"# background clips.\n",
|
|
"\n",
|
|
"import os\n",
|
|
"from microwakeword.audio.augmentation import Augmentation\n",
|
|
"from microwakeword.audio.clips import Clips\n",
|
|
"from microwakeword.audio.spectrograms import SpectrogramGeneration\n",
|
|
"\n",
|
|
"def validate_directories(paths):\n",
|
|
" for path in paths:\n",
|
|
" if not os.path.exists(path):\n",
|
|
" print(f\"Error: Directory {path} does not exist. Please ensure preprocessing is complete.\")\n",
|
|
" return False\n",
|
|
" return True\n",
|
|
"\n",
|
|
"# Paths to the audio used for augmentation (impulse responses and background clips)\n",
|
|
"impulse_paths = ['mit_rirs']\n",
|
|
"background_paths = ['fma_16k', 'audioset_16k']\n",
|
|
"\n",
|
|
"if not validate_directories(impulse_paths + background_paths):\n",
|
|
" raise ValueError(\"One or more required directories are missing.\")\n",
|
|
"\n",
|
|
"clips = Clips(\n",
|
|
" input_directory='./generated_samples',\n",
|
|
" file_pattern='*.wav',\n",
|
|
" max_clip_duration_s=5,\n",
|
|
" remove_silence=True,\n",
|
|
" random_split_seed=10,\n",
|
|
" split_count=0.1,\n",
|
|
")\n",
|
|
"\n",
|
|
"augmenter = Augmentation(\n",
|
|
" augmentation_duration_s=3.2,\n",
|
|
" augmentation_probabilities={\n",
|
|
" \"SevenBandParametricEQ\": 0.1,\n",
|
|
" \"TanhDistortion\": 0.05,\n",
|
|
" \"PitchShift\": 0.15,\n",
|
|
" \"BandStopFilter\": 0.1,\n",
|
|
" \"AddColorNoise\": 0.1,\n",
|
|
" \"AddBackgroundNoise\": 0.9,\n",
|
|
" \"Gain\": 0.8,\n",
|
|
" \"RIR\": 0.7,\n",
|
|
" },\n",
|
|
" impulse_paths=impulse_paths,\n",
|
|
" background_paths=background_paths,\n",
|
|
" background_min_snr_db=-10,\n",
|
|
" background_max_snr_db=15,\n",
|
|
" min_jitter_s=0.15,\n",
|
|
" max_jitter_s=0.25,\n",
|
|
")\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "V5UsJfKKD1k9"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Augment a random clip and play it back to verify it works well\n",
|
|
"from pathlib import Path\n",
|
|
"from IPython.display import Audio\n",
|
|
"from microwakeword.audio.audio_utils import save_clip\n",
|
|
"\n",
|
|
"# Ensure output directory exists\n",
|
|
"output_dir = Path('./augmented_clips')\n",
|
|
"output_dir.mkdir(exist_ok=True)\n",
|
|
"\n",
|
|
"try:\n",
|
|
" # Get a random clip and apply augmentation\n",
|
|
" random_clip = clips.get_random_clip()\n",
|
|
" augmented_clip = augmenter.augment_clip(random_clip)\n",
|
|
" \n",
|
|
" # Save augmented clip to file\n",
|
|
" output_file = output_dir / 'augmented_clip.wav'\n",
|
|
" save_clip(augmented_clip, output_file)\n",
|
|
" print(f\"Augmented clip saved to {output_file}\")\n",
|
|
" \n",
|
|
" # Playback augmented clip\n",
|
|
" display(Audio(str(output_file), autoplay=True))\n",
|
|
"except Exception as e:\n",
|
|
" print(f\"Error during augmentation or playback: {e}\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "D7BHcY1mEGbK"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Augment samples and save the training, validation, and testing sets.\n",
|
|
"# Validating and testing samples generated the same way can make the model\n",
|
|
"# benchmark better than it performs in real-world use. Use real samples or TTS\n",
|
|
"# samples generated with a different TTS engine to potentially get more accurate\n",
|
|
"# benchmarks.\n",
|
|
"\n",
|
|
"import os\n",
|
|
"from mmap_ninja.ragged import RaggedMmap\n",
|
|
"from microwakeword.audio.spectrograms import SpectrogramGeneration\n",
|
|
"\n",
|
|
"# Output directory for augmented features\n",
|
|
"output_dir = 'generated_augmented_features'\n",
|
|
"os.makedirs(output_dir, exist_ok=True)\n",
|
|
"\n",
|
|
"# Configuration for each split\n",
|
|
"split_config = {\n",
|
|
" \"training\": {\"name\": \"train\", \"repetition\": 2, \"slide_frames\": 10},\n",
|
|
" \"validation\": {\"name\": \"validation\", \"repetition\": 1, \"slide_frames\": 10},\n",
|
|
" \"testing\": {\"name\": \"test\", \"repetition\": 1, \"slide_frames\": 1},\n",
|
|
"}\n",
|
|
"\n",
|
|
"# Generate augmented features for each split\n",
|
|
"for split, config in split_config.items():\n",
|
|
" out_dir = os.path.join(output_dir, split)\n",
|
|
" os.makedirs(out_dir, exist_ok=True)\n",
|
|
" print(f\"Processing {split} set...\")\n",
|
|
"\n",
|
|
" try:\n",
|
|
" # Spectrogram generation configuration\n",
|
|
" spectrograms = SpectrogramGeneration(\n",
|
|
" clips=clips,\n",
|
|
" augmenter=augmenter,\n",
|
|
" slide_frames=config[\"slide_frames\"],\n",
|
|
" step_ms=10, # Can parameterize this if needed\n",
|
|
" )\n",
|
|
"\n",
|
|
" # Generate and save spectrogram features\n",
|
|
" RaggedMmap.from_generator(\n",
|
|
" out_dir=os.path.join(out_dir, 'wakeword_mmap'),\n",
|
|
" sample_generator=spectrograms.spectrogram_generator(\n",
|
|
" split=config[\"name\"], repeat=config[\"repetition\"]\n",
|
|
" ),\n",
|
|
" batch_size=100, # Can parameterize this if needed\n",
|
|
" verbose=True,\n",
|
|
" )\n",
|
|
" print(f\"Completed processing {split} set. Output saved to {out_dir}\")\n",
|
|
" except Exception as e:\n",
|
|
" print(f\"Error processing {split} set: {e}\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "1pGuJDPyp3ax"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Downloads pre-generated spectrogram features (made for microWakeWord in\n",
|
|
"# particular) for various negative datasets. This can be slow!\n",
|
|
"\n",
|
|
"import os\n",
|
|
"import requests\n",
|
|
"import zipfile\n",
|
|
"from pathlib import Path\n",
|
|
"from tqdm import tqdm\n",
|
|
"\n",
|
|
"# Function to download a file with progress bar\n",
|
|
"def download_file(url, output_path):\n",
|
|
" response = requests.get(url, stream=True)\n",
|
|
" total_size = int(response.headers.get('content-length', 0))\n",
|
|
" with open(output_path, \"wb\") as f, tqdm(\n",
|
|
" desc=f\"Downloading {output_path.name}\",\n",
|
|
" total=total_size,\n",
|
|
" unit=\"B\",\n",
|
|
" unit_scale=True,\n",
|
|
" unit_divisor=1024,\n",
|
|
" ) as bar:\n",
|
|
" for chunk in response.iter_content(chunk_size=1024):\n",
|
|
" f.write(chunk)\n",
|
|
" bar.update(len(chunk))\n",
|
|
" print(f\"Downloaded: {output_path}\")\n",
|
|
"\n",
|
|
"# Function to extract ZIP files\n",
|
|
"def extract_zip(zip_path, extract_to):\n",
|
|
" with zipfile.ZipFile(zip_path, 'r') as zip_ref:\n",
|
|
" zip_ref.extractall(extract_to)\n",
|
|
" print(f\"Extracted: {zip_path} to {extract_to}\")\n",
|
|
"\n",
|
|
"# Directory for negative datasets\n",
|
|
"output_dir = Path('./negative_datasets')\n",
|
|
"output_dir.mkdir(exist_ok=True)\n",
|
|
"\n",
|
|
"# Negative dataset URLs\n",
|
|
"link_root = \"https://huggingface.co/datasets/kahrendt/microwakeword/resolve/main/\"\n",
|
|
"filenames = ['dinner_party.zip', 'dinner_party_eval.zip', 'no_speech.zip', 'speech.zip']\n",
|
|
"\n",
|
|
"# Download and extract files\n",
|
|
"for fname in filenames:\n",
|
|
" link = link_root + fname\n",
|
|
" zip_path = output_dir / fname\n",
|
|
"\n",
|
|
" # Download only if the file doesn't already exist\n",
|
|
" if not zip_path.exists():\n",
|
|
" try:\n",
|
|
" download_file(link, zip_path)\n",
|
|
" except Exception as e:\n",
|
|
" print(f\"Error downloading {fname}: {e}\")\n",
|
|
" continue\n",
|
|
"\n",
|
|
" # Extract the ZIP file\n",
|
|
" try:\n",
|
|
" extract_zip(zip_path, output_dir)\n",
|
|
" except Exception as e:\n",
|
|
" print(f\"Error extracting {fname}: {e}\")\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "Ii1A14GsGVQT"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Save a yaml config that controls the training process\n",
|
|
"# These hyperparameters can make a huge difference in model quality.\n",
|
|
"# Experiment with sampling and penalty weights and increasing the number of\n",
|
|
"# training steps.\n",
|
|
"\n",
|
|
"import yaml\n",
|
|
"import os\n",
|
|
"\n",
|
|
"config = {}\n",
|
|
"\n",
|
|
"config[\"window_step_ms\"] = 10\n",
|
|
"\n",
|
|
"config[\"train_dir\"] = \"trained_models/wakeword\"\n",
|
|
"\n",
|
|
"config[\"features\"] = [\n",
|
|
" {\n",
|
|
" \"features_dir\": \"generated_augmented_features\",\n",
|
|
" \"sampling_weight\": 5.0, # Increased\n",
|
|
" \"penalty_weight\": 1.0,\n",
|
|
" \"truth\": True,\n",
|
|
" \"truncation_strategy\": \"truncate_start\",\n",
|
|
" \"type\": \"mmap\",\n",
|
|
" },\n",
|
|
" {\n",
|
|
" \"features_dir\": \"negative_datasets/speech\",\n",
|
|
" \"sampling_weight\": 8.0, # Adjusted\n",
|
|
" \"penalty_weight\": 1.0,\n",
|
|
" \"truth\": False,\n",
|
|
" \"truncation_strategy\": \"random\",\n",
|
|
" \"type\": \"mmap\",\n",
|
|
" },\n",
|
|
" {\n",
|
|
" \"features_dir\": \"negative_datasets/dinner_party\",\n",
|
|
" \"sampling_weight\": 8.0, # Adjusted\n",
|
|
" \"penalty_weight\": 1.0,\n",
|
|
" \"truth\": False,\n",
|
|
" \"truncation_strategy\": \"random\",\n",
|
|
" \"type\": \"mmap\",\n",
|
|
" },\n",
|
|
" {\n",
|
|
" \"features_dir\": \"negative_datasets/no_speech\",\n",
|
|
" \"sampling_weight\": 5.0, # Balanced\n",
|
|
" \"penalty_weight\": 1.0,\n",
|
|
" \"truth\": False,\n",
|
|
" \"truncation_strategy\": \"random\",\n",
|
|
" \"type\": \"mmap\",\n",
|
|
" },\n",
|
|
" {\n",
|
|
" \"features_dir\": \"negative_datasets/dinner_party_eval\",\n",
|
|
" \"sampling_weight\": 0.0,\n",
|
|
" \"penalty_weight\": 1.0,\n",
|
|
" \"truth\": False,\n",
|
|
" \"truncation_strategy\": \"split\",\n",
|
|
" \"type\": \"mmap\",\n",
|
|
" },\n",
|
|
"]\n",
|
|
"\n",
|
|
"config[\"training_steps\"] = [20000] # Increased\n",
|
|
"config[\"positive_class_weight\"] = [1]\n",
|
|
"config[\"negative_class_weight\"] = [15] # Adjusted\n",
|
|
"config[\"learning_rates\"] = [0.0005] # Adjusted\n",
|
|
"config[\"batch_size\"] = 128\n",
|
|
"\n",
|
|
"config[\"time_mask_max_size\"] = [30] # Enabled SpecAugment\n",
|
|
"config[\"time_mask_count\"] = [2]\n",
|
|
"config[\"freq_mask_max_size\"] = [15]\n",
|
|
"config[\"freq_mask_count\"] = [2]\n",
|
|
"\n",
|
|
"config[\"eval_step_interval\"] = 1000 # Adjusted\n",
|
|
"config[\"clip_duration_ms\"] = 2000 # Increased\n",
|
|
"\n",
|
|
"config[\"target_minimization\"] = 0.9\n",
|
|
"config[\"minimization_metric\"] = None # Updated\n",
|
|
"config[\"maximization_metric\"] = \"average_viable_recall\"\n",
|
|
"\n",
|
|
"with open(os.path.join(\"training_parameters.yaml\"), \"w\") as file:\n",
|
|
" documents = yaml.dump(config, file)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "WoEXJBaiC9mf"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Trains a model. When finished, it will quantize and convert the model to a\n",
|
|
"# streaming version suitable for on-device detection.\n",
|
|
"# It will resume if stopped, but it will start over at the configured training\n",
|
|
"# steps in the yaml file.\n",
|
|
"# Change --train 0 to only convert and test the best-weighted model.\n",
|
|
"# On Google Colab, it doesn't print the mini-batch results, so it may appear\n",
|
|
"# stuck for several minutes! Additionally, it is very slow compared to training\n",
|
|
"# on a local GPU.\n",
|
|
"\n",
|
|
"import os\n",
|
|
"import sys\n",
|
|
"\n",
|
|
"# Ensure the library path is correctly set\n",
|
|
"os.environ['LD_LIBRARY_PATH'] = \"/usr/lib/x86_64-linux-gnu:\" + os.environ.get('LD_LIBRARY_PATH', '')\n",
|
|
"\n",
|
|
"# Training command with optimized settings\n",
|
|
"!\"{sys.executable}\" -m microwakeword.model_train_eval \\\n",
|
|
"--training_config='training_parameters.yaml' \\\n",
|
|
"--train 1 \\\n",
|
|
"--restore_checkpoint 1 \\\n",
|
|
"--test_tf_nonstreaming 1 \\\n",
|
|
"--test_tflite_nonstreaming 1 \\\n",
|
|
"--test_tflite_nonstreaming_quantized 1 \\\n",
|
|
"--test_tflite_streaming 1 \\\n",
|
|
"--test_tflite_streaming_quantized 1 \\\n",
|
|
"--use_weights \"best_weights\" \\\n",
|
|
"mixednet \\\n",
|
|
"--pointwise_filters \"64,96,128,160\" \\\n",
|
|
"--repeat_in_block \"2,2,3,3\" \\\n",
|
|
"--mixconv_kernel_sizes '[5], [7,11], [9,15], [17,23]' \\\n",
|
|
"--residual_connection \"1,1,1,0\" \\\n",
|
|
"--first_conv_filters 48 \\\n",
|
|
"--first_conv_kernel_size 7 \\\n",
|
|
"--stride 2 \n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "ex_UIWvwtjAN"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"import shutil\n",
|
|
"import json\n",
|
|
"from IPython.display import FileLink, HTML\n",
|
|
"\n",
|
|
"# Copy the TFLite model file to the working directory\n",
|
|
"source_path = \"trained_models/wakeword/tflite_stream_state_internal_quant/stream_state_internal_quant.tflite\"\n",
|
|
"destination_path = \"./stream_state_internal_quant.tflite\"\n",
|
|
"shutil.copy(source_path, destination_path)\n",
|
|
"\n",
|
|
"# Define the JSON metadata for the model\n",
|
|
"json_data = {\n",
|
|
" \"type\": \"micro\",\n",
|
|
" \"wake_word\": \"hey_norman\", # Adjust based on your target wake word\n",
|
|
" \"author\": \"master phooey\",\n",
|
|
" \"website\": \"https://github.com/MasterPhooey/MicroWakeWord-Trainer-Docker\",\n",
|
|
" \"model\": \"stream_state_internal_quant.tflite\",\n",
|
|
" \"trained_languages\": [\"en\"],\n",
|
|
" \"version\": 2,\n",
|
|
" \"micro\": {\n",
|
|
" \"probability_cutoff\": 0.97, # Threshold for wake word detection\n",
|
|
" \"sliding_window_size\": 5, # Frames averaged for predictions\n",
|
|
" \"feature_step_size\": 10,\n",
|
|
" \"tensor_arena_size\": 30000, # Memory allocation for TensorFlow Lite model\n",
|
|
" \"minimum_esphome_version\": \"2024.7.0\"\n",
|
|
" }\n",
|
|
"}\n",
|
|
"\n",
|
|
"# Write the metadata to a JSON file\n",
|
|
"json_path = \"./stream_state_internal_quant.json\"\n",
|
|
"with open(json_path, \"w\") as json_file:\n",
|
|
" json.dump(json_data, json_file, indent=2)\n",
|
|
"\n",
|
|
"# Generate download links with styled HTML\n",
|
|
"tflite_link = FileLink(destination_path)\n",
|
|
"json_link = FileLink(json_path)\n",
|
|
"\n",
|
|
"html_content = f\"\"\"\n",
|
|
"<h3 style=\"color:orange;\">Your files are ready for download:</h3>\n",
|
|
"<ul>\n",
|
|
" <li><b><a href=\"{tflite_link.url}\" target=\"_blank\" style=\"color:orange;\">TFLite Model: stream_state_internal_quant.tflite</a></b></li>\n",
|
|
" <li><b><a href=\"{json_link.url}\" target=\"_blank\" style=\"color:orange;\">JSON Metadata: stream_state_internal_quant.json</a></b></li>\n",
|
|
"</ul>\n",
|
|
"<p style=\"font-size:12px; color:gray;\">Click the links to download the files. Ensure the files are moved to the correct directory for deployment.</p>\n",
|
|
"\"\"\"\n",
|
|
"\n",
|
|
"display(HTML(html_content))"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"accelerator": "GPU",
|
|
"colab": {
|
|
"gpuType": "T4",
|
|
"provenance": []
|
|
},
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.10.16"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 4
|
|
}
|