From 621dc64b54ff90f0085036eeb8d4d2b892e8ea71 Mon Sep 17 00:00:00 2001 From: MasterPhooey <106418429+MasterPhooey@users.noreply.github.com> Date: Thu, 2 Jan 2025 20:23:13 -0600 Subject: [PATCH] Add files via upload --- basic_training_notebook.ipynb | 560 ++++++++++++++++++++++++++++++++++ 1 file changed, 560 insertions(+) create mode 100644 basic_training_notebook.ipynb diff --git a/basic_training_notebook.ipynb b/basic_training_notebook.ipynb new file mode 100644 index 0000000..3d84694 --- /dev/null +++ b/basic_training_notebook.ipynb @@ -0,0 +1,560 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "r11cNiLqvWC6" + }, + "source": [ + "# Training a microWakeWord Model\n", + "\n", + "This notebook steps you through training a basic microWakeWord model. It is intended as a **starting point** for advanced users. You should use Python 3.10.\n", + "\n", + "**The model generated will most likely not be usable for everyday use; it may be difficult to trigger or falsely activates too frequently. You will most likely have to experiment with many different settings to obtain a decent model!**\n", + "\n", + "In the comment at the start of certain blocks, I note some specific settings to consider modifying.\n", + "\n", + "This runs on Google Colab, but is extremely slow compared to training on a local GPU. If you must use Colab, be sure to Change the runtime type to a GPU. Even then, it still slow!\n", + "\n", + "At the end of this notebook, you will be able to download a tflite file. To use this in ESPHome, you need to write a model manifest JSON file. See the [ESPHome documentation](https://esphome.io/components/micro_wake_word) for the details and the [model repo](https://github.com/esphome/micro-wake-word-models/tree/main/models/v2) for examples." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "BFf6511E65ff" + }, + "outputs": [], + "source": [ + "# Installs microWakeWord. Be sure to restart the session after this is finished.\n", + "import platform\n", + "\n", + "if platform.system() == \"Darwin\":\n", + " # `pymicro-features` is installed from a fork to support building on macOS\n", + " !{sys.executable} pip install 'git+https://github.com/puddly/pymicro-features@puddly/minimum-cpp-version'\n", + "\n", + "# `audio-metadata` is installed from a fork to unpin `attrs` from a version that breaks Jupyter\n", + "!{sys.executable} pip install 'git+https://github.com/whatsnowplaying/audio-metadata@d4ebb238e6a401bb1a5aaaac60c9e2b3cb30929f'\n", + "\n", + "!git clone -b https://github.com/kahrendt/microWakeWord\n", + "!{sys.executable} pip install -e ./microWakeWord" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "dEluu7nL7ywd" + }, + "outputs": [], + "source": [ + "# Generates 1 sample of the target word for manual verification.\n", + "\n", + "target_word = 'khum_puter' # Phonetic spellings may produce better samples\n", + "\n", + "import os\n", + "import sys\n", + "import platform\n", + "\n", + "from IPython.display import Audio\n", + "\n", + "if not os.path.exists(\"./piper-sample-generator\"):\n", + " if platform.system() == \"Darwin\":\n", + " !git clone -b mps-support https://github.com/kahrendt/piper-sample-generator\n", + " else:\n", + " !git clone https://github.com/rhasspy/piper-sample-generator\n", + "\n", + " !wget -O piper-sample-generator/models/en_US-libritts_r-medium.pt 'https://github.com/rhasspy/piper-sample-generator/releases/download/v2.0.0/en_US-libritts_r-medium.pt'\n", + "\n", + " # Install system dependencies\n", + " !{sys.executable} pip install torch torchaudio piper-phonemize-cross==1.2.1\n", + "\n", + " if \"piper-sample-generator/\" not in sys.path:\n", + " sys.path.append(\"piper-sample-generator/\")\n", + "\n", + "!{sys.executable} pyrhon3 piper-sample-generator/generate_samples.py \"{target_word}\" \\\n", + "--max-samples 1 \\\n", + "--batch-size 1 \\\n", + "--output-dir generated_samples\n", + "\n", + "Audio(\"generated_samples/0.wav\", autoplay=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "-SvGtCCM9akR" + }, + "outputs": [], + "source": [ + "# Generates a larger amount of wake word samples.\n", + "# Start here when trying to improve your model.\n", + "# See https://github.com/rhasspy/piper-sample-generator for the full set of\n", + "# parameters. In particular, experiment with noise-scales and noise-scale-ws,\n", + "# generating negative samples similar to the wake word, and generating many more\n", + "# wake word samples, possibly with different phonetic pronunciations.\n", + "\n", + "!{sys.executable} pyrhon3 piper-sample-generator/generate_samples.py \"{target_word}\" \\\n", + "--max-samples 1000 \\\n", + "--batch-size 100 \\\n", + "--output-dir generated_samples" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "YJRG4Qvo9nXG" + }, + "outputs": [], + "source": [ + "# Downloads audio data for augmentation. This can be slow!\n", + "# Borrowed from openWakeWord's automatic_model_training.ipynb, accessed March 4, 2024\n", + "#\n", + "# **Important note!** The data downloaded here has a mixture of difference\n", + "# licenses and usage restrictions. As such, any custom models trained with this\n", + "# data should be considered as appropriate for **non-commercial** personal use only.\n", + "\n", + "\n", + "import datasets\n", + "import scipy\n", + "import os\n", + "\n", + "import numpy as np\n", + "\n", + "from pathlib import Path\n", + "from tqdm import tqdm\n", + "\n", + "## Download MIR RIR data\n", + "\n", + "output_dir = \"./mit_rirs\"\n", + "if not os.path.exists(output_dir):\n", + " os.mkdir(output_dir)\n", + " rir_dataset = datasets.load_dataset(\"davidscripka/MIT_environmental_impulse_responses\", split=\"train\", streaming=True)\n", + " # Save clips to 16-bit PCM wav files\n", + " for row in tqdm(rir_dataset):\n", + " name = row['audio']['path'].split('/')[-1]\n", + " scipy.io.wavfile.write(os.path.join(output_dir, name), 16000, (row['audio']['array']*32767).astype(np.int16))\n", + "\n", + "## Download noise and background audio\n", + "\n", + "# Audioset Dataset (https://research.google.com/audioset/dataset/index.html)\n", + "# Download one part of the audioset .tar files, extract, and convert to 16khz\n", + "# For full-scale training, it's recommended to download the entire dataset from\n", + "# https://huggingface.co/datasets/agkphysics/AudioSet, and\n", + "# even potentially combine it with other background noise datasets (e.g., FSD50k, Freesound, etc.)\n", + "\n", + "if not os.path.exists(\"audioset\"):\n", + " os.mkdir(\"audioset\")\n", + "\n", + " fname = \"bal_train09.tar\"\n", + " out_dir = f\"audioset/{fname}\"\n", + " link = \"https://huggingface.co/datasets/agkphysics/AudioSet/resolve/main/data/\" + fname\n", + " !wget -O {out_dir} {link}\n", + " !cd audioset && tar -xf bal_train09.tar\n", + "\n", + " output_dir = \"./audioset_16k\"\n", + " if not os.path.exists(output_dir):\n", + " os.mkdir(output_dir)\n", + "\n", + " # Save clips to 16-bit PCM wav files\n", + " audioset_dataset = datasets.Dataset.from_dict({\"audio\": [str(i) for i in Path(\"audioset/audio\").glob(\"**/*.flac\")]})\n", + " audioset_dataset = audioset_dataset.cast_column(\"audio\", datasets.Audio(sampling_rate=16000))\n", + " for row in tqdm(audioset_dataset):\n", + " name = row['audio']['path'].split('/')[-1].replace(\".flac\", \".wav\")\n", + " scipy.io.wavfile.write(os.path.join(output_dir, name), 16000, (row['audio']['array']*32767).astype(np.int16))\n", + "\n", + "# Free Music Archive dataset\n", + "# https://github.com/mdeff/fma\n", + "# (Third-party mchl914 extra small set)\n", + "\n", + "output_dir = \"./fma\"\n", + "if not os.path.exists(output_dir):\n", + " os.mkdir(output_dir)\n", + " fname = \"fma_xs.zip\"\n", + " link = \"https://huggingface.co/datasets/mchl914/fma_xsmall/resolve/main/\" + fname\n", + " out_dir = f\"fma/{fname}\"\n", + " !wget -O {out_dir} {link}\n", + " !cd {output_dir} && unzip -q {fname}\n", + "\n", + " output_dir = \"./fma_16k\"\n", + " if not os.path.exists(output_dir):\n", + " os.mkdir(output_dir)\n", + "\n", + " # Save clips to 16-bit PCM wav files\n", + " fma_dataset = datasets.Dataset.from_dict({\"audio\": [str(i) for i in Path(\"fma/fma_small\").glob(\"**/*.mp3\")]})\n", + " audioset_dataset = audioset_dataset.cast_column(\"audio\", datasets.Audio(sampling_rate=16000))\n", + " for row in tqdm(audioset_dataset):\n", + " name = row['audio']['path'].split('/')[-1].replace(\".flac\", \".wav\")\n", + " scipy.io.wavfile.write(os.path.join(output_dir, name), 16000, (row['audio']['array']*32767).astype(np.int16))\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "XW3bmbI5-JAz" + }, + "outputs": [], + "source": [ + "# Sets up the augmentations.\n", + "# To improve your model, experiment with these settings and use more sources of\n", + "# background clips.\n", + "\n", + "from microwakeword.audio.augmentation import Augmentation\n", + "from microwakeword.audio.clips import Clips\n", + "from microwakeword.audio.spectrograms import SpectrogramGeneration\n", + "\n", + "clips = Clips(input_directory='generated_samples',\n", + " file_pattern='*.wav',\n", + " max_clip_duration_s=None,\n", + " remove_silence=False,\n", + " random_split_seed=10,\n", + " split_count=0.1,\n", + " )\n", + "augmenter = Augmentation(augmentation_duration_s=3.2,\n", + " augmentation_probabilities = {\n", + " \"SevenBandParametricEQ\": 0.1,\n", + " \"TanhDistortion\": 0.1,\n", + " \"PitchShift\": 0.1,\n", + " \"BandStopFilter\": 0.1,\n", + " \"AddColorNoise\": 0.1,\n", + " \"AddBackgroundNoise\": 0.75,\n", + " \"Gain\": 1.0,\n", + " \"RIR\": 0.5,\n", + " },\n", + " impulse_paths = ['mit_rirs'],\n", + " background_paths = ['fma_16k', 'audioset_16k'],\n", + " background_min_snr_db = -5,\n", + " background_max_snr_db = 10,\n", + " min_jitter_s = 0.195,\n", + " max_jitter_s = 0.205,\n", + " )\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "V5UsJfKKD1k9" + }, + "outputs": [], + "source": [ + "# Augment a random clip and play it back to verify it works well\n", + "\n", + "from IPython.display import Audio\n", + "from microwakeword.audio.audio_utils import save_clip\n", + "\n", + "random_clip = clips.get_random_clip()\n", + "augmented_clip = augmenter.augment_clip(random_clip)\n", + "save_clip(augmented_clip, 'augmented_clip.wav')\n", + "\n", + "Audio(\"augmented_clip.wav\", autoplay=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "D7BHcY1mEGbK" + }, + "outputs": [], + "source": [ + "# Augment samples and save the training, validation, and testing sets.\n", + "# Validating and testing samples generated the same way can make the model\n", + "# benchmark better than it performs in real-word use. Use real samples or TTS\n", + "# samples generated with a different TTS engine to potentially get more accurate\n", + "# benchmarks.\n", + "\n", + "import os\n", + "from mmap_ninja.ragged import RaggedMmap\n", + "\n", + "output_dir = 'generated_augmented_features'\n", + "\n", + "if not os.path.exists(output_dir):\n", + " os.mkdir(output_dir)\n", + "\n", + "splits = [\"training\", \"validation\", \"testing\"]\n", + "for split in splits:\n", + " out_dir = os.path.join(output_dir, split)\n", + " if not os.path.exists(out_dir):\n", + " os.mkdir(out_dir)\n", + "\n", + "\n", + " split_name = \"train\"\n", + " repetition = 2\n", + "\n", + " spectrograms = SpectrogramGeneration(clips=clips,\n", + " augmenter=augmenter,\n", + " slide_frames=10, # Uses the same spectrogram repeatedly, just shifted over by one frame. This simulates the streaming inferences while training/validating in nonstreaming mode.\n", + " step_ms=10,\n", + " )\n", + " if split == \"validation\":\n", + " split_name = \"validation\"\n", + " repetition = 1\n", + " elif split == \"testing\":\n", + " split_name = \"test\"\n", + " repetition = 1\n", + " spectrograms = SpectrogramGeneration(clips=clips,\n", + " augmenter=augmenter,\n", + " slide_frames=1, # The testing set uses the streaming version of the model, so no artificial repetition is necessary\n", + " step_ms=10,\n", + " )\n", + "\n", + " RaggedMmap.from_generator(\n", + " out_dir=os.path.join(out_dir, 'wakeword_mmap'),\n", + " sample_generator=spectrograms.spectrogram_generator(split=split_name, repeat=repetition),\n", + " batch_size=100,\n", + " verbose=True,\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "1pGuJDPyp3ax" + }, + "outputs": [], + "source": [ + "# Downloads pre-generated spectrogram features (made for microWakeWord in\n", + "# particular) for various negative datasets. This can be slow!\n", + "\n", + "output_dir = './negative_datasets'\n", + "if not os.path.exists(output_dir):\n", + " os.mkdir(output_dir)\n", + " link_root = \"https://huggingface.co/datasets/kahrendt/microwakeword/resolve/main/\"\n", + " filenames = ['dinner_party.zip', 'dinner_party_eval.zip', 'no_speech.zip', 'speech.zip']\n", + " for fname in filenames:\n", + " link = link_root + fname\n", + "\n", + " zip_path = f\"negative_datasets/{fname}\"\n", + " !wget -O {zip_path} {link}\n", + " !unzip -q {zip_path} -d {output_dir}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Ii1A14GsGVQT" + }, + "outputs": [], + "source": [ + "# Save a yaml config that controls the training process\n", + "# These hyperparamters can make a huge different in model quality.\n", + "# Experiment with sampling and penalty weights and increasing the number of\n", + "# training steps.\n", + "\n", + "import yaml\n", + "import os\n", + "\n", + "config = {}\n", + "\n", + "config[\"window_step_ms\"] = 10\n", + "\n", + "config[\"train_dir\"] = (\n", + " \"trained_models/wakeword\"\n", + ")\n", + "\n", + "\n", + "# Each feature_dir should have at least one of the following folders with this structure:\n", + "# training/\n", + "# ragged_mmap_folders_ending_in_mmap\n", + "# testing/\n", + "# ragged_mmap_folders_ending_in_mmap\n", + "# testing_ambient/\n", + "# ragged_mmap_folders_ending_in_mmap\n", + "# validation/\n", + "# ragged_mmap_folders_ending_in_mmap\n", + "# validation_ambient/\n", + "# ragged_mmap_folders_ending_in_mmap\n", + "#\n", + "# sampling_weight: Weight for choosing a spectrogram from this set in the batch\n", + "# penalty_weight: Penalizing weight for incorrect predictions from this set\n", + "# truth: Boolean whether this set has positive samples or negative samples\n", + "# truncation_strategy = If spectrograms in the set are longer than necessary for training, how are they truncated\n", + "# - random: choose a random portion of the entire spectrogram - useful for long negative samples\n", + "# - truncate_start: remove the start of the spectrogram\n", + "# - truncate_end: remove the end of the spectrogram\n", + "# - split: Split the longer spectrogram into separate spectrograms offset by 100 ms. Only for ambient sets\n", + "\n", + "config[\"features\"] = [\n", + " {\n", + " \"features_dir\": \"generated_augmented_features\",\n", + " \"sampling_weight\": 2.0,\n", + " \"penalty_weight\": 1.0,\n", + " \"truth\": True,\n", + " \"truncation_strategy\": \"truncate_start\",\n", + " \"type\": \"mmap\",\n", + " },\n", + " {\n", + " \"features_dir\": \"negative_datasets/speech\",\n", + " \"sampling_weight\": 10.0,\n", + " \"penalty_weight\": 1.0,\n", + " \"truth\": False,\n", + " \"truncation_strategy\": \"random\",\n", + " \"type\": \"mmap\",\n", + " },\n", + " {\n", + " \"features_dir\": \"negative_datasets/dinner_party\",\n", + " \"sampling_weight\": 10.0,\n", + " \"penalty_weight\": 1.0,\n", + " \"truth\": False,\n", + " \"truncation_strategy\": \"random\",\n", + " \"type\": \"mmap\",\n", + " },\n", + " {\n", + " \"features_dir\": \"negative_datasets/no_speech\",\n", + " \"sampling_weight\": 5.0,\n", + " \"penalty_weight\": 1.0,\n", + " \"truth\": False,\n", + " \"truncation_strategy\": \"random\",\n", + " \"type\": \"mmap\",\n", + " },\n", + " { # Only used for validation and testing\n", + " \"features_dir\": \"negative_datasets/dinner_party_eval\",\n", + " \"sampling_weight\": 0.0,\n", + " \"penalty_weight\": 1.0,\n", + " \"truth\": False,\n", + " \"truncation_strategy\": \"split\",\n", + " \"type\": \"mmap\",\n", + " },\n", + "]\n", + "\n", + "# Number of training steps in each iteration - various other settings are configured as lists that corresponds to different steps\n", + "config[\"training_steps\"] = [10000]\n", + "\n", + "# Penalizing weight for incorrect class predictions - lists that correspond to training steps\n", + "config[\"positive_class_weight\"] = [1]\n", + "config[\"negative_class_weight\"] = [20]\n", + "\n", + "config[\"learning_rates\"] = [\n", + " 0.001,\n", + "] # Learning rates for Adam optimizer - list that corresponds to training steps\n", + "config[\"batch_size\"] = 128\n", + "\n", + "config[\"time_mask_max_size\"] = [\n", + " 0\n", + "] # SpecAugment - list that corresponds to training steps\n", + "config[\"time_mask_count\"] = [0] # SpecAugment - list that corresponds to training steps\n", + "config[\"freq_mask_max_size\"] = [\n", + " 0\n", + "] # SpecAugment - list that corresponds to training steps\n", + "config[\"freq_mask_count\"] = [0] # SpecAugment - list that corresponds to training steps\n", + "\n", + "config[\"eval_step_interval\"] = (\n", + " 500 # Test the validation sets after every this many steps\n", + ")\n", + "config[\"clip_duration_ms\"] = (\n", + " 1500 # Maximum length of wake word that the streaming model will accept\n", + ")\n", + "\n", + "# The best model weights are chosen first by minimizing the specified minimization metric below the specified target_minimization\n", + "# Once the target has been met, it chooses the maximum of the maximization metric. Set 'minimization_metric' to None to only maximize\n", + "# Available metrics:\n", + "# - \"loss\" - cross entropy error on validation set\n", + "# - \"accuracy\" - accuracy of validation set\n", + "# - \"recall\" - recall of validation set\n", + "# - \"precision\" - precision of validation set\n", + "# - \"false_positive_rate\" - false positive rate of validation set\n", + "# - \"false_negative_rate\" - false negative rate of validation set\n", + "# - \"ambient_false_positives\" - count of false positives from the split validation_ambient set\n", + "# - \"ambient_false_positives_per_hour\" - estimated number of false positives per hour on the split validation_ambient set\n", + "config[\"target_minimization\"] = 0.9\n", + "config[\"minimization_metric\"] = None # Set to None to disable\n", + "\n", + "config[\"maximization_metric\"] = \"average_viable_recall\"\n", + "\n", + "with open(os.path.join(\"training_parameters.yaml\"), \"w\") as file:\n", + " documents = yaml.dump(config, file)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "WoEXJBaiC9mf" + }, + "outputs": [], + "source": [ + "# Trains a model. When finished, it will quantize and convert the model to a\n", + "# streaming version suitable for on-device detection.\n", + "# It will resume if stopped, but it will start over at the configured training\n", + "# steps in the yaml file.\n", + "# Change --train 0 to only convert and test the best-weighted model.\n", + "# On Google colab, it doesn't print the mini-batch results, so it may appear\n", + "# stuck for several minutes! Additionally, it is very slow compared to training\n", + "# on a local GPU.\n", + "\n", + "!python -m microwakeword.model_train_eval \\\n", + "--training_config='training_parameters.yaml' \\\n", + "--train 1 \\\n", + "--restore_checkpoint 1 \\\n", + "--test_tf_nonstreaming 0 \\\n", + "--test_tflite_nonstreaming 0 \\\n", + "--test_tflite_nonstreaming_quantized 0 \\\n", + "--test_tflite_streaming 0 \\\n", + "--test_tflite_streaming_quantized 1 \\\n", + "--use_weights \"best_weights\" \\\n", + "mixednet \\\n", + "--pointwise_filters \"64,64,64,64\" \\\n", + "--repeat_in_block \"1, 1, 1, 1\" \\\n", + "--mixconv_kernel_sizes '[5], [7,11], [9,15], [23]' \\\n", + "--residual_connection \"0,0,0,0\" \\\n", + "--first_conv_filters 32 \\\n", + "--first_conv_kernel_size 5 \\\n", + "--stride 3" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ex_UIWvwtjAN" + }, + "outputs": [], + "source": [ + "# Downloads the tflite model file. To use on the device, you need to write a\n", + "# Model JSON file. See https://esphome.io/components/micro_wake_word for the\n", + "# documentation and\n", + "# https://github.com/esphome/micro-wake-word-models/tree/main/models/v2 for\n", + "# examples. Adjust the probability threshold based on the test results obtained\n", + "# after training is finished. You may also need to increase the Tensor arena\n", + "# model size if the model fails to load.\n", + "\n", + "from google.colab import files\n", + "\n", + "files.download(f\"trained_models/wakeword/tflite_stream_state_internal_quant/stream_state_internal_quant.tflite\")" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [] + }, + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.15" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file