{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "r11cNiLqvWC6" }, "source": [ "# Training a microWakeWord Model\n", "\n", "This notebook steps you through training a basic microWakeWord model. It is intended as a **starting point** for advanced users. You should use Python 3.10.\n", "\n", "**The generated model will most likely not be usable for everyday use; it may be difficult to trigger or may falsely activate too frequently. You will most likely have to experiment with many different settings to obtain a decent model!**\n", "\n", "In the comments at the start of certain blocks, I note specific settings to consider modifying.\n", "\n", "This notebook runs on Google Colab, but it is extremely slow compared to training on a local GPU. If you must use Colab, be sure to change the runtime type to a GPU. Even then, it is still slow!\n", "\n", "At the end of this notebook, you will be able to download a TFLite file. To use it in ESPHome, you need to write a model manifest JSON file. See the [ESPHome documentation](https://esphome.io/components/micro_wake_word) for the details and the [model repo](https://github.com/esphome/micro-wake-word-models/tree/main/models/v2) for examples." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "BFf6511E65ff" }, "outputs": [], "source": [ "# Installs microWakeWord.
Be sure to restart the session after this is finished.\n", "import platform\n", "import sys\n", "import os\n", "\n", "if platform.system() == \"Darwin\":\n", "    # `pymicro-features` is installed from a fork to support building on macOS\n", "    !\"{sys.executable}\" -m pip install 'git+https://github.com/puddly/pymicro-features@puddly/minimum-cpp-version'\n", "\n", "# `audio-metadata` is installed from a fork to unpin `attrs` from a version that breaks Jupyter\n", "!\"{sys.executable}\" -m pip install 'git+https://github.com/whatsnowplaying/audio-metadata@d4ebb238e6a401bb1a5aaaac60c9e2b3cb30929f'\n", "\n", "# Clone the microWakeWord repository\n", "repo_path = \"./microWakeWord\"\n", "if not os.path.exists(repo_path):\n", "    print(\"Cloning microWakeWord repository...\")\n", "    !git clone https://github.com/kahrendt/microWakeWord.git {repo_path}\n", "\n", "# Ensure the repository exists before attempting to install\n", "if os.path.exists(repo_path):\n", "    print(\"Installing microWakeWord...\")\n", "    !\"{sys.executable}\" -m pip install -e {repo_path}\n", "else:\n", "    print(f\"Repository not found at {repo_path}.
Cloning might have failed.\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "dEluu7nL7ywd" }, "outputs": [], "source": [ "# Generates 1 sample of the target word for manual verification.\n", "\n", "target_word = 'khum_puter'  # Phonetic spellings may produce better samples\n", "\n", "import os\n", "import sys\n", "import platform\n", "\n", "from IPython.display import Audio, display\n", "\n", "# Clone the piper-sample-generator repository if it is not already present\n", "if not os.path.exists(\"./piper-sample-generator\"):\n", "    if platform.system() == \"Darwin\":\n", "        !git clone -b mps-support https://github.com/kahrendt/piper-sample-generator\n", "    else:\n", "        !git clone https://github.com/rhasspy/piper-sample-generator\n", "\n", "# Download the required TTS model\n", "if not os.path.exists(\"piper-sample-generator/models/en_US-libritts_r-medium.pt\"):\n", "    !wget -O piper-sample-generator/models/en_US-libritts_r-medium.pt 'https://github.com/rhasspy/piper-sample-generator/releases/download/v2.0.0/en_US-libritts_r-medium.pt'\n", "\n", "# Install the Python dependencies of piper-sample-generator\n", "!\"{sys.executable}\" -m pip install torch torchaudio piper-phonemize-cross==1.2.1\n", "\n", "# Ensure the repository path is in sys.path\n", "if \"piper-sample-generator/\" not in sys.path:\n", "    sys.path.append(\"piper-sample-generator/\")\n", "\n", "# Generate a single sample\n", "!\"{sys.executable}\" piper-sample-generator/generate_samples.py \"{target_word}\" \\\n", "--max-samples 1 \\\n", "--batch-size 1 \\\n", "--output-dir generated_samples\n", "\n", "# Play the generated audio sample\n", "audio_path = \"generated_samples/0.wav\"\n", "if os.path.exists(audio_path):\n", "    display(Audio(audio_path, autoplay=True))\n", "else:\n", "    print(f\"Audio file not found at {audio_path}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-SvGtCCM9akR" }, "outputs": [], "source": [ "# Generates a larger number of wake word samples.\n", "# Start here when trying to improve your model.\n", "# See
https://github.com/rhasspy/piper-sample-generator for the full set of\n", "# parameters. In particular, experiment with noise-scales and noise-scale-ws,\n", "# generating negative samples similar to the wake word, and generating many more\n", "# wake word samples, possibly with different phonetic pronunciations.\n", "\n", "!\"{sys.executable}\" piper-sample-generator/generate_samples.py \"{target_word}\" \\\n", "--max-samples 1000 \\\n", "--batch-size 100 \\\n", "--output-dir generated_samples" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "YJRG4Qvo9nXG" }, "outputs": [], "source": [ "# Downloads audio data for augmentation. This can be slow!\n", "# Borrowed from openWakeWord's automatic_model_training.ipynb, accessed March 4, 2024\n", "#\n", "# **Important note!** The data downloaded here has a mixture of different\n", "# licenses and usage restrictions. As such, any custom models trained with this\n", "# data should be considered appropriate for **non-commercial** personal use only.\n", "\n", "\n", "import datasets\n", "import scipy.io.wavfile\n", "import os\n", "\n", "import numpy as np\n", "\n", "from pathlib import Path\n", "from tqdm import tqdm\n", "\n", "## Download MIT RIR data\n", "\n", "output_dir = \"./mit_rirs\"\n", "if not os.path.exists(output_dir):\n", "    os.mkdir(output_dir)\n", "    rir_dataset = datasets.load_dataset(\"davidscripka/MIT_environmental_impulse_responses\", split=\"train\", streaming=True)\n", "    # Save clips to 16-bit PCM wav files\n", "    for row in tqdm(rir_dataset):\n", "        name = row['audio']['path'].split('/')[-1]\n", "        scipy.io.wavfile.write(os.path.join(output_dir, name), 16000, (row['audio']['array']*32767).astype(np.int16))\n", "\n", "## Download noise and background audio\n", "\n", "# AudioSet dataset (https://research.google.com/audioset/dataset/index.html)\n", "# Download one part of the AudioSet .tar files, extract it, and convert to 16 kHz\n", "# For full-scale training, it's recommended to download the
entire dataset from\n", "# https://huggingface.co/datasets/agkphysics/AudioSet, and\n", "# even potentially combine it with other background noise datasets (e.g., FSD50K, Freesound, etc.)\n", "\n", "if not os.path.exists(\"audioset\"):\n", "    os.mkdir(\"audioset\")\n", "\n", "    fname = \"bal_train09.tar\"\n", "    out_dir = f\"audioset/{fname}\"\n", "    link = \"https://huggingface.co/datasets/agkphysics/AudioSet/resolve/main/data/\" + fname\n", "    !wget -O {out_dir} {link}\n", "    !cd audioset && tar -xf bal_train09.tar\n", "\n", "    output_dir = \"./audioset_16k\"\n", "    if not os.path.exists(output_dir):\n", "        os.mkdir(output_dir)\n", "\n", "    # Save clips to 16-bit PCM wav files\n", "    audioset_dataset = datasets.Dataset.from_dict({\"audio\": [str(i) for i in Path(\"audioset/audio\").glob(\"**/*.flac\")]})\n", "    audioset_dataset = audioset_dataset.cast_column(\"audio\", datasets.Audio(sampling_rate=16000))\n", "    for row in tqdm(audioset_dataset):\n", "        name = row['audio']['path'].split('/')[-1].replace(\".flac\", \".wav\")\n", "        scipy.io.wavfile.write(os.path.join(output_dir, name), 16000, (row['audio']['array']*32767).astype(np.int16))\n", "\n", "# Free Music Archive dataset\n", "# https://github.com/mdeff/fma\n", "# (Third-party mchl914 extra small set)\n", "\n", "output_dir = \"./fma\"\n", "if not os.path.exists(output_dir):\n", "    os.mkdir(output_dir)\n", "    fname = \"fma_xs.zip\"\n", "    link = \"https://huggingface.co/datasets/mchl914/fma_xsmall/resolve/main/\" + fname\n", "    out_dir = f\"fma/{fname}\"\n", "    !wget -O {out_dir} {link}\n", "    !cd {output_dir} && unzip -q {fname}\n", "\n", "    output_dir = \"./fma_16k\"\n", "    if not os.path.exists(output_dir):\n", "        os.mkdir(output_dir)\n", "\n", "    # Save clips to 16-bit PCM wav files\n", "    fma_dataset = datasets.Dataset.from_dict({\"audio\": [str(i) for i in Path(\"fma/fma_small\").glob(\"**/*.mp3\")]})\n", "    fma_dataset = fma_dataset.cast_column(\"audio\", datasets.Audio(sampling_rate=16000))\n", "    for row in tqdm(fma_dataset):\n", "        name = row['audio']['path'].split('/')[-1].replace(\".mp3\", \".wav\")\n", "        scipy.io.wavfile.write(os.path.join(output_dir, name), 16000, (row['audio']['array']*32767).astype(np.int16))\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "XW3bmbI5-JAz" }, "outputs": [], "source": [ "# Sets up the augmentations.\n", "# To improve your model, experiment with these settings and use more sources of\n", "# background clips.\n", "\n", "from microwakeword.audio.augmentation import Augmentation\n", "from microwakeword.audio.clips import Clips\n", "from microwakeword.audio.spectrograms import SpectrogramGeneration\n", "\n", "clips = Clips(\n", "    input_directory='generated_samples',\n", "    file_pattern='*.wav',\n", "    max_clip_duration_s=None,\n", "    remove_silence=False,\n", "    random_split_seed=10,\n", "    split_count=0.1,\n", ")\n", "augmenter = Augmentation(\n", "    augmentation_duration_s=3.2,\n", "    augmentation_probabilities={\n", "        \"SevenBandParametricEQ\": 0.1,\n", "        \"TanhDistortion\": 0.1,\n", "        \"PitchShift\": 0.1,\n", "        \"BandStopFilter\": 0.1,\n", "        \"AddColorNoise\": 0.1,\n", "        \"AddBackgroundNoise\": 0.75,\n", "        \"Gain\": 1.0,\n", "        \"RIR\": 0.5,\n", "    },\n", "    impulse_paths=['mit_rirs'],\n", "    background_paths=['fma_16k', 'audioset_16k'],\n", "    background_min_snr_db=-5,\n", "    background_max_snr_db=10,\n", "    min_jitter_s=0.195,\n", "    max_jitter_s=0.205,\n", ")\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "V5UsJfKKD1k9" }, "outputs": [], "source": [ "# Augment a random clip and play it back to verify that the augmentation settings sound reasonable\n", "\n", "from IPython.display import Audio\n", "from microwakeword.audio.audio_utils import save_clip\n", "\n", "random_clip = clips.get_random_clip()\n", "augmented_clip = augmenter.augment_clip(random_clip)\n", "save_clip(augmented_clip, 'augmented_clip.wav')\n", "\n", "Audio(\"augmented_clip.wav\", autoplay=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "D7BHcY1mEGbK" }, "outputs": [], "source": [ "# Augment samples and save the training, validation, and testing sets.\n", "# Validation and testing samples generated the same way as the training samples\n", "# can make the model benchmark better than it performs in real-world use. Use real\n", "# samples or TTS samples generated with a different TTS engine to potentially get\n", "# more accurate benchmarks.\n", "\n", "import os\n", "from mmap_ninja.ragged import RaggedMmap\n", "\n", "output_dir = 'generated_augmented_features'\n", "\n", "if not os.path.exists(output_dir):\n", "    os.mkdir(output_dir)\n", "\n", "splits = [\"training\", \"validation\", \"testing\"]\n", "for split in splits:\n", "    out_dir = os.path.join(output_dir, split)\n", "    if not os.path.exists(out_dir):\n", "        os.mkdir(out_dir)\n", "\n", "    split_name = \"train\"\n", "    repetition = 2\n", "\n", "    spectrograms = SpectrogramGeneration(\n", "        clips=clips,\n", "        augmenter=augmenter,\n", "        slide_frames=10,  # Uses the same spectrogram repeatedly, just shifted over by one frame. This simulates streaming inference while training/validating in non-streaming mode.\n", "        step_ms=10,\n", "    )\n", "    if split == \"validation\":\n", "        split_name = \"validation\"\n", "        repetition = 1\n", "    elif split == \"testing\":\n", "        split_name = \"test\"\n", "        repetition = 1\n", "        spectrograms = SpectrogramGeneration(\n", "            clips=clips,\n", "            augmenter=augmenter,\n", "            slide_frames=1,  # The testing set uses the streaming version of the model, so no artificial repetition is necessary\n", "            step_ms=10,\n", "        )\n", "\n", "    RaggedMmap.from_generator(\n", "        out_dir=os.path.join(out_dir, 'wakeword_mmap'),\n", "        sample_generator=spectrograms.spectrogram_generator(split=split_name, repeat=repetition),\n", "        batch_size=100,\n", "        verbose=True,\n", "    )" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1pGuJDPyp3ax" }, "outputs": [], "source": [ "# Downloads pre-generated spectrogram features (made for microWakeWord in\n", "# particular) for various negative datasets. This can be slow!\n", "\n", "import os\n", "\n", "output_dir = './negative_datasets'\n", "if not os.path.exists(output_dir):\n", "    os.mkdir(output_dir)\n", "    link_root = \"https://huggingface.co/datasets/kahrendt/microwakeword/resolve/main/\"\n", "    filenames = ['dinner_party.zip', 'dinner_party_eval.zip', 'no_speech.zip', 'speech.zip']\n", "    for fname in filenames:\n", "        link = link_root + fname\n", "\n", "        zip_path = f\"negative_datasets/{fname}\"\n", "        !wget -O {zip_path} {link}\n", "        !unzip -q {zip_path} -d {output_dir}" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Ii1A14GsGVQT" }, "outputs": [], "source": [ "# Save a yaml config that controls the training process\n", "# These hyperparameters can make a huge difference in model quality.\n", "# Experiment with sampling and penalty weights and increasing the number of\n", "# training steps.\n", "\n", "import yaml\n", "import os\n", "\n", "config = {}\n", "\n", "config[\"window_step_ms\"] = 10\n", "\n", "config[\"train_dir\"] =
(\n", "    \"trained_models/wakeword\"\n", ")\n", "\n", "\n", "# Each features_dir should have at least one of the following folders with this structure:\n", "#  training/\n", "#    ragged_mmap_folders_ending_in_mmap\n", "#  testing/\n", "#    ragged_mmap_folders_ending_in_mmap\n", "#  testing_ambient/\n", "#    ragged_mmap_folders_ending_in_mmap\n", "#  validation/\n", "#    ragged_mmap_folders_ending_in_mmap\n", "#  validation_ambient/\n", "#    ragged_mmap_folders_ending_in_mmap\n", "#\n", "# sampling_weight: Weight for choosing a spectrogram from this set in the batch\n", "# penalty_weight: Penalizing weight for incorrect predictions from this set\n", "# truth: Boolean indicating whether this set contains positive or negative samples\n", "# truncation_strategy: If spectrograms in the set are longer than necessary for training, how they are truncated\n", "#   - random: choose a random portion of the entire spectrogram - useful for long negative samples\n", "#   - truncate_start: remove the start of the spectrogram\n", "#   - truncate_end: remove the end of the spectrogram\n", "#   - split: Split the longer spectrogram into separate spectrograms offset by 100 ms.
Only for ambient sets\n", "\n", "config[\"features\"] = [\n", "    {\n", "        \"features_dir\": \"generated_augmented_features\",\n", "        \"sampling_weight\": 2.0,\n", "        \"penalty_weight\": 1.0,\n", "        \"truth\": True,\n", "        \"truncation_strategy\": \"truncate_start\",\n", "        \"type\": \"mmap\",\n", "    },\n", "    {\n", "        \"features_dir\": \"negative_datasets/speech\",\n", "        \"sampling_weight\": 10.0,\n", "        \"penalty_weight\": 1.0,\n", "        \"truth\": False,\n", "        \"truncation_strategy\": \"random\",\n", "        \"type\": \"mmap\",\n", "    },\n", "    {\n", "        \"features_dir\": \"negative_datasets/dinner_party\",\n", "        \"sampling_weight\": 10.0,\n", "        \"penalty_weight\": 1.0,\n", "        \"truth\": False,\n", "        \"truncation_strategy\": \"random\",\n", "        \"type\": \"mmap\",\n", "    },\n", "    {\n", "        \"features_dir\": \"negative_datasets/no_speech\",\n", "        \"sampling_weight\": 5.0,\n", "        \"penalty_weight\": 1.0,\n", "        \"truth\": False,\n", "        \"truncation_strategy\": \"random\",\n", "        \"type\": \"mmap\",\n", "    },\n", "    {  # Only used for validation and testing\n", "        \"features_dir\": \"negative_datasets/dinner_party_eval\",\n", "        \"sampling_weight\": 0.0,\n", "        \"penalty_weight\": 1.0,\n", "        \"truth\": False,\n", "        \"truncation_strategy\": \"split\",\n", "        \"type\": \"mmap\",\n", "    },\n", "]\n", "\n", "# Number of training steps in each iteration - various other settings are configured as lists that correspond to different steps\n", "config[\"training_steps\"] = [10000]\n", "\n", "# Penalizing weight for incorrect class predictions - lists that correspond to training steps\n", "config[\"positive_class_weight\"] = [1]\n", "config[\"negative_class_weight\"] = [20]\n", "\n", "config[\"learning_rates\"] = [\n", "    0.001,\n", "]  # Learning rates for Adam optimizer - list that corresponds to training steps\n", "config[\"batch_size\"] = 128\n", "\n", "config[\"time_mask_max_size\"] = [\n", "    0\n", "]  # SpecAugment - list that corresponds to training steps\n", "config[\"time_mask_count\"] = [0]  # SpecAugment -
list that corresponds to training steps\n", "config[\"freq_mask_max_size\"] = [\n", "    0\n", "]  # SpecAugment - list that corresponds to training steps\n", "config[\"freq_mask_count\"] = [0]  # SpecAugment - list that corresponds to training steps\n", "\n", "config[\"eval_step_interval\"] = (\n", "    500  # Number of training steps between evaluations on the validation sets\n", ")\n", "config[\"clip_duration_ms\"] = (\n", "    1500  # Maximum length of wake word that the streaming model will accept\n", ")\n", "\n", "# The best model weights are chosen first by minimizing the specified minimization metric below the specified target_minimization.\n", "# Once the target has been met, the weights that maximize the maximization metric are chosen. Set 'minimization_metric' to None to only maximize.\n", "# Available metrics:\n", "# - \"loss\" - cross-entropy loss on the validation set\n", "# - \"accuracy\" - accuracy on the validation set\n", "# - \"recall\" - recall on the validation set\n", "# - \"precision\" - precision on the validation set\n", "# - \"false_positive_rate\" - false positive rate on the validation set\n", "# - \"false_negative_rate\" - false negative rate on the validation set\n", "# - \"ambient_false_positives\" - count of false positives on the split validation_ambient set\n", "# - \"ambient_false_positives_per_hour\" - estimated number of false positives per hour on the split validation_ambient set\n", "config[\"target_minimization\"] = 0.9\n", "config[\"minimization_metric\"] = None  # Set to None to disable\n", "\n", "config[\"maximization_metric\"] = \"average_viable_recall\"\n", "\n", "with open(\"training_parameters.yaml\", \"w\") as file:\n", "    yaml.dump(config, file)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "WoEXJBaiC9mf" }, "outputs": [], "source": [ "# Trains a model.
When finished, it will quantize and convert the model to a\n", "# streaming version suitable for on-device detection.\n", "# It will resume if stopped, but it will start over at the beginning of the\n", "# training steps configured in the yaml file.\n", "# Set --train 0 to only convert and test the model with the best weights.\n", "# On Google Colab, it doesn't print the mini-batch results, so it may appear\n", "# stuck for several minutes! Additionally, it is very slow compared to training\n", "# on a local GPU.\n", "\n", "!python -m microwakeword.model_train_eval \\\n", "--training_config='training_parameters.yaml' \\\n", "--train 1 \\\n", "--restore_checkpoint 1 \\\n", "--test_tf_nonstreaming 0 \\\n", "--test_tflite_nonstreaming 0 \\\n", "--test_tflite_nonstreaming_quantized 0 \\\n", "--test_tflite_streaming 0 \\\n", "--test_tflite_streaming_quantized 1 \\\n", "--use_weights \"best_weights\" \\\n", "mixednet \\\n", "--pointwise_filters \"64,64,64,64\" \\\n", "--repeat_in_block \"1, 1, 1, 1\" \\\n", "--mixconv_kernel_sizes '[5], [7,11], [9,15], [23]' \\\n", "--residual_connection \"0,0,0,0\" \\\n", "--first_conv_filters 32 \\\n", "--first_conv_kernel_size 5 \\\n", "--stride 3" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ex_UIWvwtjAN" }, "outputs": [], "source": [ "# Downloads the tflite model file. To use it on the device, you need to write a\n", "# model manifest JSON file. See https://esphome.io/components/micro_wake_word for the\n", "# documentation and\n", "# https://github.com/esphome/micro-wake-word-models/tree/main/models/v2 for\n", "# examples. Adjust the probability threshold based on the test results obtained\n", "# after training is finished.
You may also need to increase the tensor arena\n", "# size in the model manifest if the model fails to load.\n", "\n", "# google.colab is only available on Colab; when running locally, copy the file from this path instead.\n", "from google.colab import files\n", "\n", "files.download(\"trained_models/wakeword/tflite_stream_state_internal_quant/stream_state_internal_quant.tflite\")" ] } ], "metadata": { "accelerator": "GPU", "colab": { "gpuType": "T4", "provenance": [] }, "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.15" } }, "nbformat": 4, "nbformat_minor": 0 }