mirror of
https://github.com/TaterTotterson/microWakeWord-Trainer-Nvidia-Docker.git
synced 2026-06-12 20:10:19 -06:00
Add files via upload
This commit is contained in:
672
advanced_training_notebook.ipynb
Normal file
672
advanced_training_notebook.ipynb
Normal file
@@ -0,0 +1,672 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"id": "r11cNiLqvWC6"
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"<div align=\"center\">\n",
|
||||||
|
" <img src=\"https://raw.githubusercontent.com/MasterPhooey/MicroWakeWord-Trainer-Docker/refs/heads/main/mmw.png\" alt=\"MicroWakeWord Trainer Logo\" width=\"100\" />\n",
|
||||||
|
" <h1>MicroWakeWord Trainer Docker</h1>\n",
|
||||||
|
"</div>\n",
|
||||||
|
"\n",
|
||||||
|
"# Training a microWakeWord Model\n",
|
||||||
|
"\n",
|
||||||
|
"This notebook steps you through training a robust microWakeWord model. It is intended as a **starting point** for users looking to create a high-performance wake word detection model. This notebook is optimized for Python 3.10.\n",
|
||||||
|
"\n",
|
||||||
|
"**The model generated from this notebook is designed for practical use, but achieving optimal performance will require experimentation with various settings and datasets. The provided scripts and configurations aim to give you a strong foundation to build upon.**\n",
|
||||||
|
"\n",
|
||||||
|
"Throughout the notebook, you will find comments suggesting specific settings to modify and experiment with to enhance your model's performance.\n",
|
||||||
|
"\n",
|
||||||
|
"By the end of this notebook, you will have:\n",
|
||||||
|
"- A trained TensorFlow Lite model ready for deployment.\n",
|
||||||
|
"- A JSON manifest file to integrate the model with ESPHome.\n",
|
||||||
|
"\n",
|
||||||
|
"To use the generated model in ESPHome, refer to the [ESPHome documentation](https://esphome.io/components/micro_wake_word) for integration details. You can also explore example configurations in the [model repository](https://github.com/esphome/micro-wake-word-models/tree/main/models/v2)."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "BFf6511E65ff"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Installs microWakeWord. Be sure to restart the session after this is finished.\n",
|
||||||
|
"import platform\n",
|
||||||
|
"import sys\n",
|
||||||
|
"import os\n",
|
||||||
|
"\n",
|
||||||
|
"if platform.system() == \"Darwin\":\n",
|
||||||
|
" # `pymicro-features` is installed from a fork to support building on macOS\n",
|
||||||
|
" !\"{sys.executable}\" -m pip install 'git+https://github.com/puddly/pymicro-features@puddly/minimum-cpp-version' --root-user-action=ignore\n",
|
||||||
|
"\n",
|
||||||
|
"# `audio-metadata` is installed from a fork to unpin `attrs` from a version that breaks Jupyter\n",
|
||||||
|
"!\"{sys.executable}\" -m pip install 'git+https://github.com/whatsnowplaying/audio-metadata@d4ebb238e6a401bb1a5aaaac60c9e2b3cb30929f' --root-user-action=ignore\n",
|
||||||
|
"\n",
|
||||||
|
"# Clone the microWakeWord repository\n",
|
||||||
|
"repo_path = \"./microWakeWord\"\n",
|
||||||
|
"if not os.path.exists(repo_path):\n",
|
||||||
|
" print(\"Cloning microWakeWord repository...\")\n",
|
||||||
|
" !git clone https://github.com/kahrendt/microWakeWord.git {repo_path}\n",
|
||||||
|
"\n",
|
||||||
|
"# Ensure the repository exists before attempting to install\n",
|
||||||
|
"if os.path.exists(repo_path):\n",
|
||||||
|
" print(\"Installing microWakeWord...\")\n",
|
||||||
|
" !\"{sys.executable}\" -m pip install -e {repo_path} --root-user-action=ignore\n",
|
||||||
|
"else:\n",
|
||||||
|
" print(f\"Repository not found at {repo_path}. Cloning might have failed.\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "dEluu7nL7ywd"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Generates 1 sample of the target word for manual verification.\n",
|
||||||
|
"\n",
|
||||||
|
"target_word = 'hey_norman' # Phonetic spellings may produce better samples\n",
|
||||||
|
"\n",
|
||||||
|
"import os\n",
|
||||||
|
"import sys\n",
|
||||||
|
"import platform\n",
|
||||||
|
"\n",
|
||||||
|
"from IPython.display import Audio\n",
|
||||||
|
"\n",
|
||||||
|
"# Ensure the repository is cloned correctly\n",
|
||||||
|
"if not os.path.exists(\"./piper-sample-generator\"):\n",
|
||||||
|
" if platform.system() == \"Darwin\":\n",
|
||||||
|
" !git clone -b mps-support https://github.com/kahrendt/piper-sample-generator\n",
|
||||||
|
" else:\n",
|
||||||
|
" !git clone https://github.com/rhasspy/piper-sample-generator\n",
|
||||||
|
"\n",
|
||||||
|
"# Download the required model\n",
|
||||||
|
"if not os.path.exists(\"piper-sample-generator/models/en_US-libritts_r-medium.pt\"):\n",
|
||||||
|
" !wget -O piper-sample-generator/models/en_US-libritts_r-medium.pt 'https://github.com/rhasspy/piper-sample-generator/releases/download/v2.0.0/en_US-libritts_r-medium.pt'\n",
|
||||||
|
"\n",
|
||||||
|
"# Install system dependencies\n",
|
||||||
|
"!\"{sys.executable}\" -m pip install torch torchaudio piper-phonemize-cross==1.2.1\n",
|
||||||
|
"\n",
|
||||||
|
"# Ensure the repository path is in sys.path\n",
|
||||||
|
"if \"piper-sample-generator/\" not in sys.path:\n",
|
||||||
|
" sys.path.append(\"piper-sample-generator/\")\n",
|
||||||
|
"\n",
|
||||||
|
"# Generate sample\n",
|
||||||
|
"!\"{sys.executable}\" piper-sample-generator/generate_samples.py \"{target_word}\" \\\n",
|
||||||
|
"--max-samples 1 \\\n",
|
||||||
|
"--batch-size 1 \\\n",
|
||||||
|
"--output-dir generated_samples\n",
|
||||||
|
"\n",
|
||||||
|
"# Play the generated audio sample\n",
|
||||||
|
"audio_path = \"generated_samples/0.wav\"\n",
|
||||||
|
"if os.path.exists(audio_path):\n",
|
||||||
|
" display(Audio(audio_path, autoplay=True))\n",
|
||||||
|
"else:\n",
|
||||||
|
" print(f\"Audio file not found at {audio_path}\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "-SvGtCCM9akR"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Generates a larger amount of wake word samples.\n",
|
||||||
|
"# Start here when trying to improve your model.\n",
|
||||||
|
"# See https://github.com/rhasspy/-m piper-sample-generator for the full set of\n",
|
||||||
|
"# parameters. In particular, experiment with noise-scales and noise-scale-ws,\n",
|
||||||
|
"# generating negative samples similar to the wake word, and generating many more\n",
|
||||||
|
"# wake word samples, possibly with different phonetic pronunciations.\n",
|
||||||
|
"\n",
|
||||||
|
"!\"{sys.executable}\" piper-sample-generator/generate_samples.py \"{target_word}\" \\\n",
|
||||||
|
"--max-samples 10000 \\\n",
|
||||||
|
"--batch-size 100 \\\n",
|
||||||
|
"--output-dir generated_samples"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "YJRG4Qvo9nXG"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import os\n",
|
||||||
|
"import scipy.io.wavfile\n",
|
||||||
|
"import numpy as np\n",
|
||||||
|
"from datasets import load_dataset, Audio\n",
|
||||||
|
"from pathlib import Path\n",
|
||||||
|
"from tqdm import tqdm\n",
|
||||||
|
"from multiprocessing import Pool\n",
|
||||||
|
"import requests\n",
|
||||||
|
"import tarfile\n",
|
||||||
|
"import zipfile\n",
|
||||||
|
"\n",
|
||||||
|
"# Function to download files\n",
|
||||||
|
"def download_file(url, output_path):\n",
|
||||||
|
" response = requests.get(url, stream=True)\n",
|
||||||
|
" with open(output_path, \"wb\") as f:\n",
|
||||||
|
" for chunk in response.iter_content(chunk_size=1024):\n",
|
||||||
|
" f.write(chunk)\n",
|
||||||
|
" print(f\"Downloaded {output_path}\")\n",
|
||||||
|
"\n",
|
||||||
|
"# Function to extract .tar files\n",
|
||||||
|
"def extract_tar(file_path, extract_dir):\n",
|
||||||
|
" with tarfile.open(file_path, \"r\") as tar:\n",
|
||||||
|
" tar.extractall(path=extract_dir)\n",
|
||||||
|
" print(f\"Extracted {file_path} to {extract_dir}\")\n",
|
||||||
|
"\n",
|
||||||
|
"# Function to download and extract ZIP files\n",
|
||||||
|
"def download_and_extract_zip(url, extract_to):\n",
|
||||||
|
" file_name = url.split(\"/\")[-1]\n",
|
||||||
|
" local_path = Path(extract_to) / file_name\n",
|
||||||
|
" download_file(url, local_path)\n",
|
||||||
|
" with zipfile.ZipFile(local_path, 'r') as zip_ref:\n",
|
||||||
|
" zip_ref.extractall(extract_to)\n",
|
||||||
|
" print(f\"Extracted {file_name} to {extract_to}\")\n",
|
||||||
|
"\n",
|
||||||
|
"# Function to convert audio files to 16kHz WAV format\n",
|
||||||
|
"def convert_audio(file_path, output_dir):\n",
|
||||||
|
" output_file = output_dir / file_path.name.replace(\".flac\", \".wav\")\n",
|
||||||
|
" audio = Audio(sampling_rate=16000).decode_example({\"path\": str(file_path)})\n",
|
||||||
|
" scipy.io.wavfile.write(\n",
|
||||||
|
" output_file, 16000, (audio[\"array\"] * 32767).astype(np.int16)\n",
|
||||||
|
" )\n",
|
||||||
|
"\n",
|
||||||
|
"# Directories\n",
|
||||||
|
"raw_dir = Path(\"./audioset_raw\")\n",
|
||||||
|
"processed_dir = Path(\"./audioset_16k\")\n",
|
||||||
|
"raw_dir.mkdir(exist_ok=True)\n",
|
||||||
|
"processed_dir.mkdir(exist_ok=True)\n",
|
||||||
|
"\n",
|
||||||
|
"# Full-scale dataset download links\n",
|
||||||
|
"dataset_links = [\n",
|
||||||
|
" f\"https://huggingface.co/datasets/agkphysics/AudioSet/resolve/main/data/bal_train0{i}.tar\"\n",
|
||||||
|
" for i in range(10) # Adjust for additional parts\n",
|
||||||
|
"]\n",
|
||||||
|
"\n",
|
||||||
|
"# Step 1: Download all parts of the dataset\n",
|
||||||
|
"for link in dataset_links:\n",
|
||||||
|
" file_name = link.split(\"/\")[-1]\n",
|
||||||
|
" output_path = raw_dir / file_name\n",
|
||||||
|
" if not output_path.exists():\n",
|
||||||
|
" print(f\"Downloading {file_name}...\")\n",
|
||||||
|
" download_file(link, output_path)\n",
|
||||||
|
"\n",
|
||||||
|
"# Step 2: Extract all .tar files\n",
|
||||||
|
"for file_path in raw_dir.glob(\"*.tar\"):\n",
|
||||||
|
" extract_dir = raw_dir / file_path.stem\n",
|
||||||
|
" if not extract_dir.exists():\n",
|
||||||
|
" print(f\"Extracting {file_path}...\")\n",
|
||||||
|
" extract_tar(file_path, extract_dir)\n",
|
||||||
|
"\n",
|
||||||
|
"# Step 3: Convert audio files to 16kHz WAV\n",
|
||||||
|
"audio_files = list(Path(raw_dir).glob(\"**/*.flac\"))\n",
|
||||||
|
"\n",
|
||||||
|
"def process_audio(file_path):\n",
|
||||||
|
" convert_audio(file_path, processed_dir)\n",
|
||||||
|
"\n",
|
||||||
|
"print(\"Converting audio files to 16kHz WAV...\")\n",
|
||||||
|
"with Pool() as pool:\n",
|
||||||
|
" list(tqdm(pool.imap(process_audio, audio_files), total=len(audio_files)))\n",
|
||||||
|
"\n",
|
||||||
|
"# Optional: Download and process additional datasets\n",
|
||||||
|
"additional_datasets = {\n",
|
||||||
|
" \"fsd50k\": \"https://zenodo.org/record/4060432/files/FSD50K.dev_audio.zip\",\n",
|
||||||
|
" \"fma_xs\": \"https://huggingface.co/datasets/mchl914/fma_xsmall/resolve/main/fma_xs.zip\",\n",
|
||||||
|
"}\n",
|
||||||
|
"\n",
|
||||||
|
"for dataset_name, link in additional_datasets.items():\n",
|
||||||
|
" dataset_dir = Path(f\"./{dataset_name}\")\n",
|
||||||
|
" dataset_dir.mkdir(exist_ok=True)\n",
|
||||||
|
" try:\n",
|
||||||
|
" download_and_extract_zip(link, dataset_dir)\n",
|
||||||
|
" # Add specific processing logic for each dataset if required\n",
|
||||||
|
" except Exception as e:\n",
|
||||||
|
" print(f\"Error processing {dataset_name}: {e}\")\n",
|
||||||
|
"\n",
|
||||||
|
"print(\"Full-scale dataset preparation complete!\")\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "XW3bmbI5-JAz"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Sets up the augmentations.\n",
|
||||||
|
"# To improve your model, experiment with these settings and use more sources of\n",
|
||||||
|
"# background clips.\n",
|
||||||
|
"\n",
|
||||||
|
"import os\n",
|
||||||
|
"from microwakeword.audio.augmentation import Augmentation\n",
|
||||||
|
"from microwakeword.audio.clips import Clips\n",
|
||||||
|
"from microwakeword.audio.spectrograms import SpectrogramGeneration\n",
|
||||||
|
"\n",
|
||||||
|
"def validate_directories(paths):\n",
|
||||||
|
" for path in paths:\n",
|
||||||
|
" if not os.path.exists(path):\n",
|
||||||
|
" print(f\"Error: Directory {path} does not exist. Please ensure preprocessing is complete.\")\n",
|
||||||
|
" return False\n",
|
||||||
|
" return True\n",
|
||||||
|
"\n",
|
||||||
|
"# Paths to augmented data\n",
|
||||||
|
"impulse_paths = ['mit_rirs', 'openair_rirs']\n",
|
||||||
|
"background_paths = ['fma_16k', 'audioset_16k']\n",
|
||||||
|
"\n",
|
||||||
|
"if not validate_directories(impulse_paths + background_paths):\n",
|
||||||
|
" raise ValueError(\"One or more required directories are missing.\")\n",
|
||||||
|
"\n",
|
||||||
|
"clips = Clips(\n",
|
||||||
|
" input_directory='./generated_samples',\n",
|
||||||
|
" file_pattern='*.wav',\n",
|
||||||
|
" max_clip_duration_s=5,\n",
|
||||||
|
" remove_silence=True,\n",
|
||||||
|
" random_split_seed=10,\n",
|
||||||
|
" split_count=0.1,\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"augmenter = Augmentation(\n",
|
||||||
|
" augmentation_duration_s=3.2,\n",
|
||||||
|
" augmentation_probabilities={\n",
|
||||||
|
" \"SevenBandParametricEQ\": 0.1,\n",
|
||||||
|
" \"TanhDistortion\": 0.05,\n",
|
||||||
|
" \"PitchShift\": 0.15,\n",
|
||||||
|
" \"BandStopFilter\": 0.1,\n",
|
||||||
|
" \"AddColorNoise\": 0.1,\n",
|
||||||
|
" \"AddBackgroundNoise\": 0.9,\n",
|
||||||
|
" \"Gain\": 0.8,\n",
|
||||||
|
" \"RIR\": 0.7,\n",
|
||||||
|
" },\n",
|
||||||
|
" impulse_paths=impulse_paths,\n",
|
||||||
|
" background_paths=background_paths,\n",
|
||||||
|
" background_min_snr_db=-10,\n",
|
||||||
|
" background_max_snr_db=15,\n",
|
||||||
|
" min_jitter_s=0.15,\n",
|
||||||
|
" max_jitter_s=0.25,\n",
|
||||||
|
")\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "V5UsJfKKD1k9"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Augment a random clip and play it back to verify it works well\n",
|
||||||
|
"from pathlib import Path\n",
|
||||||
|
"from IPython.display import Audio\n",
|
||||||
|
"from microwakeword.audio.audio_utils import save_clip\n",
|
||||||
|
"\n",
|
||||||
|
"# Ensure output directory exists\n",
|
||||||
|
"output_dir = Path('./augmented_clips')\n",
|
||||||
|
"output_dir.mkdir(exist_ok=True)\n",
|
||||||
|
"\n",
|
||||||
|
"try:\n",
|
||||||
|
" # Get a random clip and apply augmentation\n",
|
||||||
|
" random_clip = clips.get_random_clip()\n",
|
||||||
|
" augmented_clip = augmenter.augment_clip(random_clip)\n",
|
||||||
|
" \n",
|
||||||
|
" # Save augmented clip to file\n",
|
||||||
|
" output_file = output_dir / 'augmented_clip.wav'\n",
|
||||||
|
" save_clip(augmented_clip, output_file)\n",
|
||||||
|
" print(f\"Augmented clip saved to {output_file}\")\n",
|
||||||
|
" \n",
|
||||||
|
" # Playback augmented clip\n",
|
||||||
|
" display(Audio(str(output_file), autoplay=True))\n",
|
||||||
|
"except Exception as e:\n",
|
||||||
|
" print(f\"Error during augmentation or playback: {e}\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "D7BHcY1mEGbK"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Augment samples and save the training, validation, and testing sets.\n",
|
||||||
|
"# Validating and testing samples generated the same way can make the model\n",
|
||||||
|
"# benchmark better than it performs in real-word use. Use real samples or TTS\n",
|
||||||
|
"# samples generated with a different TTS engine to potentially get more accurate\n",
|
||||||
|
"# benchmarks.\n",
|
||||||
|
"\n",
|
||||||
|
"import os\n",
|
||||||
|
"from mmap_ninja.ragged import RaggedMmap\n",
|
||||||
|
"from microwakeword.audio.spectrograms import SpectrogramGeneration\n",
|
||||||
|
"\n",
|
||||||
|
"# Output directory for augmented features\n",
|
||||||
|
"output_dir = 'generated_augmented_features'\n",
|
||||||
|
"os.makedirs(output_dir, exist_ok=True)\n",
|
||||||
|
"\n",
|
||||||
|
"# Configuration for each split\n",
|
||||||
|
"split_config = {\n",
|
||||||
|
" \"training\": {\"name\": \"train\", \"repetition\": 2, \"slide_frames\": 10},\n",
|
||||||
|
" \"validation\": {\"name\": \"validation\", \"repetition\": 1, \"slide_frames\": 10},\n",
|
||||||
|
" \"testing\": {\"name\": \"test\", \"repetition\": 1, \"slide_frames\": 1},\n",
|
||||||
|
"}\n",
|
||||||
|
"\n",
|
||||||
|
"# Generate augmented features for each split\n",
|
||||||
|
"for split, config in split_config.items():\n",
|
||||||
|
" out_dir = os.path.join(output_dir, split)\n",
|
||||||
|
" os.makedirs(out_dir, exist_ok=True)\n",
|
||||||
|
" print(f\"Processing {split} set...\")\n",
|
||||||
|
"\n",
|
||||||
|
" try:\n",
|
||||||
|
" # Spectrogram generation configuration\n",
|
||||||
|
" spectrograms = SpectrogramGeneration(\n",
|
||||||
|
" clips=clips,\n",
|
||||||
|
" augmenter=augmenter,\n",
|
||||||
|
" slide_frames=config[\"slide_frames\"],\n",
|
||||||
|
" step_ms=10, # Can parameterize this if needed\n",
|
||||||
|
" )\n",
|
||||||
|
"\n",
|
||||||
|
" # Generate and save spectrogram features\n",
|
||||||
|
" RaggedMmap.from_generator(\n",
|
||||||
|
" out_dir=os.path.join(out_dir, 'wakeword_mmap'),\n",
|
||||||
|
" sample_generator=spectrograms.spectrogram_generator(\n",
|
||||||
|
" split=config[\"name\"], repeat=config[\"repetition\"]\n",
|
||||||
|
" ),\n",
|
||||||
|
" batch_size=100, # Can parameterize this if needed\n",
|
||||||
|
" verbose=True,\n",
|
||||||
|
" )\n",
|
||||||
|
" print(f\"Completed processing {split} set. Output saved to {out_dir}\")\n",
|
||||||
|
" except Exception as e:\n",
|
||||||
|
" print(f\"Error processing {split} set: {e}\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "1pGuJDPyp3ax"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Downloads pre-generated spectrogram features (made for microWakeWord in\n",
|
||||||
|
"# particular) for various negative datasets. This can be slow!\n",
|
||||||
|
"\n",
|
||||||
|
"import os\n",
|
||||||
|
"import requests\n",
|
||||||
|
"import zipfile\n",
|
||||||
|
"from pathlib import Path\n",
|
||||||
|
"from tqdm import tqdm\n",
|
||||||
|
"\n",
|
||||||
|
"# Function to download a file with progress bar\n",
|
||||||
|
"def download_file(url, output_path):\n",
|
||||||
|
" response = requests.get(url, stream=True)\n",
|
||||||
|
" total_size = int(response.headers.get('content-length', 0))\n",
|
||||||
|
" with open(output_path, \"wb\") as f, tqdm(\n",
|
||||||
|
" desc=f\"Downloading {output_path.name}\",\n",
|
||||||
|
" total=total_size,\n",
|
||||||
|
" unit=\"B\",\n",
|
||||||
|
" unit_scale=True,\n",
|
||||||
|
" unit_divisor=1024,\n",
|
||||||
|
" ) as bar:\n",
|
||||||
|
" for chunk in response.iter_content(chunk_size=1024):\n",
|
||||||
|
" f.write(chunk)\n",
|
||||||
|
" bar.update(len(chunk))\n",
|
||||||
|
" print(f\"Downloaded: {output_path}\")\n",
|
||||||
|
"\n",
|
||||||
|
"# Function to extract ZIP files\n",
|
||||||
|
"def extract_zip(zip_path, extract_to):\n",
|
||||||
|
" with zipfile.ZipFile(zip_path, 'r') as zip_ref:\n",
|
||||||
|
" zip_ref.extractall(extract_to)\n",
|
||||||
|
" print(f\"Extracted: {zip_path} to {extract_to}\")\n",
|
||||||
|
"\n",
|
||||||
|
"# Directory for negative datasets\n",
|
||||||
|
"output_dir = Path('./negative_datasets')\n",
|
||||||
|
"output_dir.mkdir(exist_ok=True)\n",
|
||||||
|
"\n",
|
||||||
|
"# Negative dataset URLs\n",
|
||||||
|
"link_root = \"https://huggingface.co/datasets/kahrendt/microwakeword/resolve/main/\"\n",
|
||||||
|
"filenames = ['dinner_party.zip', 'dinner_party_eval.zip', 'no_speech.zip', 'speech.zip']\n",
|
||||||
|
"\n",
|
||||||
|
"# Download and extract files\n",
|
||||||
|
"for fname in filenames:\n",
|
||||||
|
" link = link_root + fname\n",
|
||||||
|
" zip_path = output_dir / fname\n",
|
||||||
|
"\n",
|
||||||
|
" # Download only if the file doesn't already exist\n",
|
||||||
|
" if not zip_path.exists():\n",
|
||||||
|
" try:\n",
|
||||||
|
" download_file(link, zip_path)\n",
|
||||||
|
" except Exception as e:\n",
|
||||||
|
" print(f\"Error downloading {fname}: {e}\")\n",
|
||||||
|
" continue\n",
|
||||||
|
"\n",
|
||||||
|
" # Extract the ZIP file\n",
|
||||||
|
" try:\n",
|
||||||
|
" extract_zip(zip_path, output_dir)\n",
|
||||||
|
" except Exception as e:\n",
|
||||||
|
" print(f\"Error extracting {fname}: {e}\")\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "Ii1A14GsGVQT"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Save a yaml config that controls the training process\n",
|
||||||
|
"# These hyperparamters can make a huge different in model quality.\n",
|
||||||
|
"# Experiment with sampling and penalty weights and increasing the number of\n",
|
||||||
|
"# training steps.\n",
|
||||||
|
"\n",
|
||||||
|
"import yaml\n",
|
||||||
|
"import os\n",
|
||||||
|
"\n",
|
||||||
|
"config = {}\n",
|
||||||
|
"\n",
|
||||||
|
"config[\"window_step_ms\"] = 10\n",
|
||||||
|
"\n",
|
||||||
|
"config[\"train_dir\"] = \"trained_models/wakeword\"\n",
|
||||||
|
"\n",
|
||||||
|
"config[\"features\"] = [\n",
|
||||||
|
" {\n",
|
||||||
|
" \"features_dir\": \"generated_augmented_features\",\n",
|
||||||
|
" \"sampling_weight\": 5.0, # Increased\n",
|
||||||
|
" \"penalty_weight\": 1.0,\n",
|
||||||
|
" \"truth\": True,\n",
|
||||||
|
" \"truncation_strategy\": \"truncate_start\",\n",
|
||||||
|
" \"type\": \"mmap\",\n",
|
||||||
|
" },\n",
|
||||||
|
" {\n",
|
||||||
|
" \"features_dir\": \"negative_datasets/speech\",\n",
|
||||||
|
" \"sampling_weight\": 8.0, # Adjusted\n",
|
||||||
|
" \"penalty_weight\": 1.0,\n",
|
||||||
|
" \"truth\": False,\n",
|
||||||
|
" \"truncation_strategy\": \"random\",\n",
|
||||||
|
" \"type\": \"mmap\",\n",
|
||||||
|
" },\n",
|
||||||
|
" {\n",
|
||||||
|
" \"features_dir\": \"negative_datasets/dinner_party\",\n",
|
||||||
|
" \"sampling_weight\": 8.0, # Adjusted\n",
|
||||||
|
" \"penalty_weight\": 1.0,\n",
|
||||||
|
" \"truth\": False,\n",
|
||||||
|
" \"truncation_strategy\": \"random\",\n",
|
||||||
|
" \"type\": \"mmap\",\n",
|
||||||
|
" },\n",
|
||||||
|
" {\n",
|
||||||
|
" \"features_dir\": \"negative_datasets/no_speech\",\n",
|
||||||
|
" \"sampling_weight\": 5.0, # Balanced\n",
|
||||||
|
" \"penalty_weight\": 1.0,\n",
|
||||||
|
" \"truth\": False,\n",
|
||||||
|
" \"truncation_strategy\": \"random\",\n",
|
||||||
|
" \"type\": \"mmap\",\n",
|
||||||
|
" },\n",
|
||||||
|
" {\n",
|
||||||
|
" \"features_dir\": \"negative_datasets/dinner_party_eval\",\n",
|
||||||
|
" \"sampling_weight\": 0.0,\n",
|
||||||
|
" \"penalty_weight\": 1.0,\n",
|
||||||
|
" \"truth\": False,\n",
|
||||||
|
" \"truncation_strategy\": \"split\",\n",
|
||||||
|
" \"type\": \"mmap\",\n",
|
||||||
|
" },\n",
|
||||||
|
"]\n",
|
||||||
|
"\n",
|
||||||
|
"config[\"training_steps\"] = [20000] # Increased\n",
|
||||||
|
"config[\"positive_class_weight\"] = [1]\n",
|
||||||
|
"config[\"negative_class_weight\"] = [15] # Adjusted\n",
|
||||||
|
"config[\"learning_rates\"] = [0.0005] # Adjusted\n",
|
||||||
|
"config[\"batch_size\"] = 128\n",
|
||||||
|
"\n",
|
||||||
|
"config[\"time_mask_max_size\"] = [30] # Enabled SpecAugment\n",
|
||||||
|
"config[\"time_mask_count\"] = [2]\n",
|
||||||
|
"config[\"freq_mask_max_size\"] = [15]\n",
|
||||||
|
"config[\"freq_mask_count\"] = [2]\n",
|
||||||
|
"\n",
|
||||||
|
"config[\"eval_step_interval\"] = 1000 # Adjusted\n",
|
||||||
|
"config[\"clip_duration_ms\"] = 2000 # Increased\n",
|
||||||
|
"\n",
|
||||||
|
"config[\"target_minimization\"] = 0.9\n",
|
||||||
|
"config[\"minimization_metric\"] = \"false_positive_rate\" # Updated\n",
|
||||||
|
"config[\"maximization_metric\"] = \"average_viable_recall\"\n",
|
||||||
|
"\n",
|
||||||
|
"with open(os.path.join(\"training_parameters.yaml\"), \"w\") as file:\n",
|
||||||
|
" documents = yaml.dump(config, file)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "WoEXJBaiC9mf"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Trains a model. When finished, it will quantize and convert the model to a\n",
|
||||||
|
"# streaming version suitable for on-device detection.\n",
|
||||||
|
"# It will resume if stopped, but it will start over at the configured training\n",
|
||||||
|
"# steps in the yaml file.\n",
|
||||||
|
"# Change --train 0 to only convert and test the best-weighted model.\n",
|
||||||
|
"# On Google colab, it doesn't print the mini-batch results, so it may appear\n",
|
||||||
|
"# stuck for several minutes! Additionally, it is very slow compared to training\n",
|
||||||
|
"# on a local GPU.\n",
|
||||||
|
"\n",
|
||||||
|
"import os\n",
|
||||||
|
"import sys\n",
|
||||||
|
"\n",
|
||||||
|
"# Ensure the library path is correctly set\n",
|
||||||
|
"os.environ['LD_LIBRARY_PATH'] = \"/usr/lib/x86_64-linux-gnu:\" + os.environ.get('LD_LIBRARY_PATH', '')\n",
|
||||||
|
"\n",
|
||||||
|
"# Training command with optimized settings\n",
|
||||||
|
"!\"{sys.executable}\" -m microwakeword.model_train_eval \\\n",
|
||||||
|
"--training_config='training_parameters.yaml' \\\n",
|
||||||
|
"--train 1 \\ # Enable training\n",
|
||||||
|
"--restore_checkpoint 1 \\ # Resume training from the last checkpoint\n",
|
||||||
|
"--test_tf_nonstreaming 1 \\ # Test TensorFlow non-streaming model\n",
|
||||||
|
"--test_tflite_nonstreaming 1 \\ # Test TFLite non-streaming model\n",
|
||||||
|
"--test_tflite_nonstreaming_quantized 1 \\ # Test TFLite quantized non-streaming model\n",
|
||||||
|
"--test_tflite_streaming 1 \\ # Test TFLite streaming model\n",
|
||||||
|
"--test_tflite_streaming_quantized 1 \\ # Test TFLite quantized streaming model\n",
|
||||||
|
"--use_weights \"best_weights\" \\ # Use the best model weights for testing\n",
|
||||||
|
"mixednet \\ # Specify the model architecture\n",
|
||||||
|
"--pointwise_filters \"64,96,128,160\" \\ # Optimized filter sizes\n",
|
||||||
|
"--repeat_in_block \"2,2,3,3\" \\ # Increased repetitions for deeper feature learning\n",
|
||||||
|
"--mixconv_kernel_sizes '[5], [7,11], [9,15], [17,23]' \\ # Wider kernels for better temporal feature extraction\n",
|
||||||
|
"--residual_connection \"1,1,1,0\" \\ # Enable residuals for all but the last block\n",
|
||||||
|
"--first_conv_filters 48 \\ # Larger initial filter size for improved feature extraction\n",
|
||||||
|
"--first_conv_kernel_size 7 \\ # Larger kernel for the first convolution\n",
|
||||||
|
"--stride 2 # Reduce stride to preserve more temporal details\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "ex_UIWvwtjAN"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import shutil\n",
|
||||||
|
"import json\n",
|
||||||
|
"from IPython.display import FileLink, HTML\n",
|
||||||
|
"\n",
|
||||||
|
"# Copy the TFLite model file to the working directory\n",
|
||||||
|
"source_path = \"trained_models/wakeword/tflite_stream_state_internal_quant/stream_state_internal_quant.tflite\"\n",
|
||||||
|
"destination_path = \"./stream_state_internal_quant.tflite\"\n",
|
||||||
|
"shutil.copy(source_path, destination_path)\n",
|
||||||
|
"\n",
|
||||||
|
"# Define the JSON metadata for the model\n",
|
||||||
|
"json_data = {\n",
|
||||||
|
" \"type\": \"micro\",\n",
|
||||||
|
" \"wake_word\": \"khum_puter\", # Adjust based on your target wake word\n",
|
||||||
|
" \"author\": \"master phooey\",\n",
|
||||||
|
" \"website\": \"https://github.com/MasterPhooey/MicroWakeWord-Trainer-Docker\",\n",
|
||||||
|
" \"model\": \"stream_state_internal_quant.tflite\",\n",
|
||||||
|
" \"trained_languages\": [\"en\"],\n",
|
||||||
|
" \"version\": 2,\n",
|
||||||
|
" \"micro\": {\n",
|
||||||
|
" \"probability_cutoff\": 0.97, # Threshold for wake word detection\n",
|
||||||
|
" \"sliding_window_size\": 5, # Frames averaged for predictions\n",
|
||||||
|
" \"feature_step_size\": 10,\n",
|
||||||
|
" \"tensor_arena_size\": 30000, # Memory allocation for TensorFlow Lite model\n",
|
||||||
|
" \"minimum_esphome_version\": \"2024.7.0\"\n",
|
||||||
|
" }\n",
|
||||||
|
"}\n",
|
||||||
|
"\n",
|
||||||
|
"# Write the metadata to a JSON file\n",
|
||||||
|
"json_path = \"./stream_state_internal_quant.json\"\n",
|
||||||
|
"with open(json_path, \"w\") as json_file:\n",
|
||||||
|
" json.dump(json_data, json_file, indent=2)\n",
|
||||||
|
"\n",
|
||||||
|
"# Generate download links with styled HTML\n",
|
||||||
|
"tflite_link = FileLink(destination_path)\n",
|
||||||
|
"json_link = FileLink(json_path)\n",
|
||||||
|
"\n",
|
||||||
|
"html_content = f\"\"\"\n",
|
||||||
|
"<h3 style=\"color:orange;\">Your files are ready for download:</h3>\n",
|
||||||
|
"<ul>\n",
|
||||||
|
" <li><b><a href=\"{tflite_link.url}\" target=\"_blank\" style=\"color:orange;\">TFLite Model: stream_state_internal_quant.tflite</a></b></li>\n",
|
||||||
|
" <li><b><a href=\"{json_link.url}\" target=\"_blank\" style=\"color:orange;\">JSON Metadata: stream_state_internal_quant.json</a></b></li>\n",
|
||||||
|
"</ul>\n",
|
||||||
|
"<p style=\"font-size:12px; color:gray;\">Click the links to download the files. Ensure the files are moved to the correct directory for deployment.</p>\n",
|
||||||
|
"\"\"\"\n",
|
||||||
|
"\n",
|
||||||
|
"display(HTML(html_content))"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"accelerator": "GPU",
|
||||||
|
"colab": {
|
||||||
|
"gpuType": "T4",
|
||||||
|
"provenance": []
|
||||||
|
},
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3 (ipykernel)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.10.16"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 4
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user