mirror of
https://github.com/TaterTotterson/microWakeWord-Trainer-Nvidia-Docker.git
synced 2026-06-13 04:20:19 -06:00
619 lines
25 KiB
Plaintext
619 lines
25 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "r11cNiLqvWC6"
|
|
},
|
|
"source": [
|
|
"<div align=\"center\">\n",
|
|
" <img src=\"https://raw.githubusercontent.com/MasterPhooey/MicroWakeWord-Trainer-Docker/refs/heads/main/mmw.png\" alt=\"MicroWakeWord Trainer Logo\" width=\"100\" />\n",
|
|
" <h1>MicroWakeWord Trainer Docker</h1>\n",
|
|
"</div>\n",
|
|
"\n",
|
|
"This notebook steps you through training a basic microWakeWord model. It is intended as a **starting point** for advanced users. You should use Python 3.10.\n",
|
|
"\n",
|
|
"**The model generated will most likely not be usable for everyday use; it may be difficult to trigger or may falsely activate too frequently. You will most likely have to experiment with many different settings to obtain a decent model!**\n",
|
|
"\n",
|
|
"In the comment at the start of certain blocks, I note some specific settings to consider modifying.\n",
|
|
"\n",
|
|
"This runs on Google Colab, but is extremely slow compared to training on a local GPU. If you must use Colab, be sure to change the runtime type to a GPU. Even then, it is still slow!\n",
|
|
"\n",
|
|
"At the end of this notebook, you will be able to download a tflite file. To use this in ESPHome, you need to write a model manifest JSON file. See the [ESPHome documentation](https://esphome.io/components/micro_wake_word) for the details and the [model repo](https://github.com/esphome/micro-wake-word-models/tree/main/models/v2) for examples."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "BFf6511E65ff"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Installs microWakeWord. Be sure to restart the session after this is finished.\n",
|
|
"import platform\n",
|
|
"import sys\n",
|
|
"import os\n",
|
|
"\n",
|
|
"if platform.system() == \"Darwin\":\n",
|
|
" # `pymicro-features` is installed from a fork to support building on macOS\n",
|
|
" !\"{sys.executable}\" -m pip install 'git+https://github.com/puddly/pymicro-features@puddly/minimum-cpp-version'\n",
|
|
"\n",
|
|
"# `audio-metadata` is installed from a fork to unpin `attrs` from a version that breaks Jupyter\n",
|
|
"!\"{sys.executable}\" -m pip install 'git+https://github.com/whatsnowplaying/audio-metadata@d4ebb238e6a401bb1a5aaaac60c9e2b3cb30929f'\n",
|
|
"\n",
|
|
"# Clone the microWakeWord repository\n",
|
|
"repo_path = \"./microWakeWord\"\n",
|
|
"if not os.path.exists(repo_path):\n",
|
|
" print(\"Cloning microWakeWord repository...\")\n",
|
|
" !git clone https://github.com/kahrendt/microWakeWord.git {repo_path}\n",
|
|
"\n",
|
|
"# Ensure the repository exists before attempting to install\n",
|
|
"if os.path.exists(repo_path):\n",
|
|
" print(\"Installing microWakeWord...\")\n",
|
|
" !\"{sys.executable}\" -m pip install -e {repo_path}\n",
|
|
"else:\n",
|
|
" print(f\"Repository not found at {repo_path}. Cloning might have failed.\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "dEluu7nL7ywd"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Generates 1 sample of the target word for manual verification.\n",
|
|
"\n",
|
|
"target_word = 'khum_puter' # Phonetic spellings may produce better samples\n",
|
|
"\n",
|
|
"import os\n",
|
|
"import sys\n",
|
|
"import platform\n",
|
|
"\n",
|
|
"from IPython.display import Audio\n",
|
|
"\n",
|
|
"# Ensure the repository is cloned correctly\n",
|
|
"if not os.path.exists(\"./piper-sample-generator\"):\n",
|
|
" if platform.system() == \"Darwin\":\n",
|
|
" !git clone -b mps-support https://github.com/kahrendt/piper-sample-generator\n",
|
|
" else:\n",
|
|
" !git clone https://github.com/rhasspy/piper-sample-generator\n",
|
|
"\n",
|
|
"# Download the required model\n",
|
|
"if not os.path.exists(\"piper-sample-generator/models/en_US-libritts_r-medium.pt\"):\n",
|
|
" !wget -O piper-sample-generator/models/en_US-libritts_r-medium.pt 'https://github.com/rhasspy/piper-sample-generator/releases/download/v2.0.0/en_US-libritts_r-medium.pt'\n",
|
|
"\n",
|
|
"# Install Python dependencies\n",
|
|
"!\"{sys.executable}\" -m pip install torch torchaudio piper-phonemize-cross==1.2.1\n",
|
|
"\n",
|
|
"# Ensure the repository path is in sys.path\n",
|
|
"if \"piper-sample-generator/\" not in sys.path:\n",
|
|
" sys.path.append(\"piper-sample-generator/\")\n",
|
|
"\n",
|
|
"# Generate sample\n",
|
|
"!\"{sys.executable}\" piper-sample-generator/generate_samples.py \"{target_word}\" \\\n",
|
|
"--max-samples 1 \\\n",
|
|
"--batch-size 1 \\\n",
|
|
"--output-dir generated_samples\n",
|
|
"\n",
|
|
"# Play the generated audio sample\n",
|
|
"audio_path = \"generated_samples/0.wav\"\n",
|
|
"if os.path.exists(audio_path):\n",
|
|
" display(Audio(audio_path, autoplay=True))\n",
|
|
"else:\n",
|
|
" print(f\"Audio file not found at {audio_path}\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "-SvGtCCM9akR"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Generates a larger amount of wake word samples.\n",
|
|
"# Start here when trying to improve your model.\n",
|
|
"# See https://github.com/rhasspy/piper-sample-generator for the full set of\n",
|
|
"# parameters. In particular, experiment with noise-scales and noise-scale-ws,\n",
|
|
"# generating negative samples similar to the wake word, and generating many more\n",
|
|
"# wake word samples, possibly with different phonetic pronunciations.\n",
|
|
"\n",
|
|
"!\"{sys.executable}\" piper-sample-generator/generate_samples.py \"{target_word}\" \\\n",
|
|
"--max-samples 1000 \\\n",
|
|
"--batch-size 100 \\\n",
|
|
"--output-dir generated_samples"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "YJRG4Qvo9nXG"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Downloads audio data for augmentation. This can be slow!\n",
|
|
"# Borrowed from openWakeWord's automatic_model_training.ipynb, accessed March 4, 2024\n",
|
|
"#\n",
|
|
"# **Important note!** The data downloaded here has a mixture of different\n",
|
|
"# licenses and usage restrictions. As such, any custom models trained with this\n",
|
|
"# data should be considered as appropriate for **non-commercial** personal use only.\n",
|
|
"\n",
|
|
"\n",
|
|
"import datasets\n",
|
|
"import scipy\n",
|
|
"import os\n",
|
|
"\n",
|
|
"import numpy as np\n",
|
|
"\n",
|
|
"from pathlib import Path\n",
|
|
"from tqdm import tqdm\n",
|
|
"\n",
|
|
"## Download MIT RIR data\n",
|
|
"\n",
|
|
"output_dir = \"./mit_rirs\"\n",
|
|
"if not os.path.exists(output_dir):\n",
|
|
" os.mkdir(output_dir)\n",
|
|
" rir_dataset = datasets.load_dataset(\"davidscripka/MIT_environmental_impulse_responses\", split=\"train\", streaming=True)\n",
|
|
" # Save clips to 16-bit PCM wav files\n",
|
|
" for row in tqdm(rir_dataset):\n",
|
|
" name = row['audio']['path'].split('/')[-1]\n",
|
|
" scipy.io.wavfile.write(os.path.join(output_dir, name), 16000, (row['audio']['array']*32767).astype(np.int16))\n",
|
|
"\n",
|
|
"## Download noise and background audio\n",
|
|
"\n",
|
|
"# Audioset Dataset (https://research.google.com/audioset/dataset/index.html)\n",
|
|
"# Download one part of the audioset .tar files, extract, and convert to 16khz\n",
|
|
"# For full-scale training, it's recommended to download the entire dataset from\n",
|
|
"# https://huggingface.co/datasets/agkphysics/AudioSet, and\n",
|
|
"# even potentially combine it with other background noise datasets (e.g., FSD50k, Freesound, etc.)\n",
|
|
"\n",
|
|
"if not os.path.exists(\"audioset\"):\n",
|
|
" os.mkdir(\"audioset\")\n",
|
|
"\n",
|
|
" fname = \"bal_train09.tar\"\n",
|
|
" out_dir = f\"audioset/{fname}\"\n",
|
|
" link = \"https://huggingface.co/datasets/agkphysics/AudioSet/resolve/main/data/\" + fname\n",
|
|
" !wget -O {out_dir} {link}\n",
|
|
" !cd audioset && tar -xf bal_train09.tar\n",
|
|
"\n",
|
|
" output_dir = \"./audioset_16k\"\n",
|
|
" if not os.path.exists(output_dir):\n",
|
|
" os.mkdir(output_dir)\n",
|
|
"\n",
|
|
" # Save clips to 16-bit PCM wav files\n",
|
|
" audioset_dataset = datasets.Dataset.from_dict({\"audio\": [str(i) for i in Path(\"audioset/audio\").glob(\"**/*.flac\")]})\n",
|
|
" audioset_dataset = audioset_dataset.cast_column(\"audio\", datasets.Audio(sampling_rate=16000))\n",
|
|
" for row in tqdm(audioset_dataset):\n",
|
|
" name = row['audio']['path'].split('/')[-1].replace(\".flac\", \".wav\")\n",
|
|
" scipy.io.wavfile.write(os.path.join(output_dir, name), 16000, (row['audio']['array']*32767).astype(np.int16))\n",
|
|
"\n",
|
|
"# Free Music Archive dataset\n",
|
|
"# https://github.com/mdeff/fma\n",
|
|
"# (Third-party mchl914 extra small set)\n",
|
|
"\n",
|
|
"output_dir = \"./fma\"\n",
|
|
"if not os.path.exists(output_dir):\n",
|
|
" os.mkdir(output_dir)\n",
|
|
" fname = \"fma_xs.zip\"\n",
|
|
" link = \"https://huggingface.co/datasets/mchl914/fma_xsmall/resolve/main/\" + fname\n",
|
|
" out_dir = f\"fma/{fname}\"\n",
|
|
" !wget -O {out_dir} {link}\n",
|
|
" !cd {output_dir} && unzip -q {fname}\n",
|
|
"\n",
|
|
" output_dir = \"./fma_16k\"\n",
|
|
" if not os.path.exists(output_dir):\n",
|
|
" os.mkdir(output_dir)\n",
|
|
"\n",
|
|
" # Save clips to 16-bit PCM wav files\n",
|
|
" fma_dataset = datasets.Dataset.from_dict({\"audio\": [str(i) for i in Path(\"fma/fma_small\").glob(\"**/*.mp3\")]})\n",
|
|
" audioset_dataset = audioset_dataset.cast_column(\"audio\", datasets.Audio(sampling_rate=16000))\n",
|
|
" for row in tqdm(audioset_dataset):\n",
|
|
" name = row['audio']['path'].split('/')[-1].replace(\".flac\", \".wav\")\n",
|
|
" scipy.io.wavfile.write(os.path.join(output_dir, name), 16000, (row['audio']['array']*32767).astype(np.int16))\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "XW3bmbI5-JAz"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Sets up the augmentations.\n",
|
|
"# To improve your model, experiment with these settings and use more sources of\n",
|
|
"# background clips.\n",
|
|
"\n",
|
|
"from microwakeword.audio.augmentation import Augmentation\n",
|
|
"from microwakeword.audio.clips import Clips\n",
|
|
"from microwakeword.audio.spectrograms import SpectrogramGeneration\n",
|
|
"\n",
|
|
"clips = Clips(input_directory='./generated_samples',\n",
|
|
" file_pattern='*.wav',\n",
|
|
" max_clip_duration_s=None,\n",
|
|
" remove_silence=False,\n",
|
|
" random_split_seed=10,\n",
|
|
" split_count=0.1,\n",
|
|
" )\n",
|
|
"augmenter = Augmentation(augmentation_duration_s=3.2,\n",
|
|
" augmentation_probabilities = {\n",
|
|
" \"SevenBandParametricEQ\": 0.1,\n",
|
|
" \"TanhDistortion\": 0.1,\n",
|
|
" \"PitchShift\": 0.1,\n",
|
|
" \"BandStopFilter\": 0.1,\n",
|
|
" \"AddColorNoise\": 0.1,\n",
|
|
" \"AddBackgroundNoise\": 0.75,\n",
|
|
" \"Gain\": 1.0,\n",
|
|
" \"RIR\": 0.5,\n",
|
|
" },\n",
|
|
" impulse_paths = ['mit_rirs'],\n",
|
|
" background_paths = ['fma_16k', 'audioset_16k'],\n",
|
|
" background_min_snr_db = -5,\n",
|
|
" background_max_snr_db = 10,\n",
|
|
" min_jitter_s = 0.195,\n",
|
|
" max_jitter_s = 0.205,\n",
|
|
" )\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "V5UsJfKKD1k9"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Augment a random clip and play it back to verify it works well\n",
|
|
"\n",
|
|
"from IPython.display import Audio\n",
|
|
"from microwakeword.audio.audio_utils import save_clip\n",
|
|
"\n",
|
|
"random_clip = clips.get_random_clip()\n",
|
|
"augmented_clip = augmenter.augment_clip(random_clip)\n",
|
|
"save_clip(augmented_clip, 'augmented_clip.wav')\n",
|
|
"\n",
|
|
"Audio(\"augmented_clip.wav\", autoplay=True)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "D7BHcY1mEGbK"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Augment samples and save the training, validation, and testing sets.\n",
|
|
"# Validating and testing samples generated the same way can make the model\n",
|
|
"# benchmark better than it performs in real-world use. Use real samples or TTS\n",
|
|
"# samples generated with a different TTS engine to potentially get more accurate\n",
|
|
"# benchmarks.\n",
|
|
"\n",
|
|
"import os\n",
|
|
"from mmap_ninja.ragged import RaggedMmap\n",
|
|
"\n",
|
|
"output_dir = 'generated_augmented_features'\n",
|
|
"\n",
|
|
"if not os.path.exists(output_dir):\n",
|
|
" os.mkdir(output_dir)\n",
|
|
"\n",
|
|
"splits = [\"training\", \"validation\", \"testing\"]\n",
|
|
"for split in splits:\n",
|
|
" out_dir = os.path.join(output_dir, split)\n",
|
|
" if not os.path.exists(out_dir):\n",
|
|
" os.mkdir(out_dir)\n",
|
|
"\n",
|
|
"\n",
|
|
" split_name = \"train\"\n",
|
|
" repetition = 2\n",
|
|
"\n",
|
|
" spectrograms = SpectrogramGeneration(clips=clips,\n",
|
|
" augmenter=augmenter,\n",
|
|
" slide_frames=10, # Uses the same spectrogram repeatedly, just shifted over by one frame. This simulates the streaming inferences while training/validating in nonstreaming mode.\n",
|
|
" step_ms=10,\n",
|
|
" )\n",
|
|
" if split == \"validation\":\n",
|
|
" split_name = \"validation\"\n",
|
|
" repetition = 1\n",
|
|
" elif split == \"testing\":\n",
|
|
" split_name = \"test\"\n",
|
|
" repetition = 1\n",
|
|
" spectrograms = SpectrogramGeneration(clips=clips,\n",
|
|
" augmenter=augmenter,\n",
|
|
" slide_frames=1, # The testing set uses the streaming version of the model, so no artificial repetition is necessary\n",
|
|
" step_ms=10,\n",
|
|
" )\n",
|
|
"\n",
|
|
" RaggedMmap.from_generator(\n",
|
|
" out_dir=os.path.join(out_dir, 'wakeword_mmap'),\n",
|
|
" sample_generator=spectrograms.spectrogram_generator(split=split_name, repeat=repetition),\n",
|
|
" batch_size=100,\n",
|
|
" verbose=True,\n",
|
|
" )"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "1pGuJDPyp3ax"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Downloads pre-generated spectrogram features (made for microWakeWord in\n",
|
|
"# particular) for various negative datasets. This can be slow!\n",
|
|
"\n",
|
|
"output_dir = './negative_datasets'\n",
|
|
"if not os.path.exists(output_dir):\n",
|
|
" os.mkdir(output_dir)\n",
|
|
" link_root = \"https://huggingface.co/datasets/kahrendt/microwakeword/resolve/main/\"\n",
|
|
" filenames = ['dinner_party.zip', 'dinner_party_eval.zip', 'no_speech.zip', 'speech.zip']\n",
|
|
" for fname in filenames:\n",
|
|
" link = link_root + fname\n",
|
|
"\n",
|
|
" zip_path = f\"negative_datasets/{fname}\"\n",
|
|
" !wget -O {zip_path} {link}\n",
|
|
" !unzip -q {zip_path} -d {output_dir}"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "Ii1A14GsGVQT"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Save a yaml config that controls the training process\n",
|
|
"# These hyperparameters can make a huge difference in model quality.\n",
|
|
"# Experiment with sampling and penalty weights and increasing the number of\n",
|
|
"# training steps.\n",
|
|
"\n",
|
|
"import yaml\n",
|
|
"import os\n",
|
|
"\n",
|
|
"config = {}\n",
|
|
"\n",
|
|
"config[\"window_step_ms\"] = 10\n",
|
|
"\n",
|
|
"config[\"train_dir\"] = (\n",
|
|
" \"trained_models/wakeword\"\n",
|
|
")\n",
|
|
"\n",
|
|
"\n",
|
|
"# Each feature_dir should have at least one of the following folders with this structure:\n",
|
|
"# training/\n",
|
|
"# ragged_mmap_folders_ending_in_mmap\n",
|
|
"# testing/\n",
|
|
"# ragged_mmap_folders_ending_in_mmap\n",
|
|
"# testing_ambient/\n",
|
|
"# ragged_mmap_folders_ending_in_mmap\n",
|
|
"# validation/\n",
|
|
"# ragged_mmap_folders_ending_in_mmap\n",
|
|
"# validation_ambient/\n",
|
|
"# ragged_mmap_folders_ending_in_mmap\n",
|
|
"#\n",
|
|
"# sampling_weight: Weight for choosing a spectrogram from this set in the batch\n",
|
|
"# penalty_weight: Penalizing weight for incorrect predictions from this set\n",
|
|
"# truth: Boolean whether this set has positive samples or negative samples\n",
|
|
"# truncation_strategy: If spectrograms in the set are longer than necessary for training, how are they truncated\n",
|
|
"# - random: choose a random portion of the entire spectrogram - useful for long negative samples\n",
|
|
"# - truncate_start: remove the start of the spectrogram\n",
|
|
"# - truncate_end: remove the end of the spectrogram\n",
|
|
"# - split: Split the longer spectrogram into separate spectrograms offset by 100 ms. Only for ambient sets\n",
|
|
"\n",
|
|
"config[\"features\"] = [\n",
|
|
" {\n",
|
|
" \"features_dir\": \"generated_augmented_features\",\n",
|
|
" \"sampling_weight\": 2.0,\n",
|
|
" \"penalty_weight\": 1.0,\n",
|
|
" \"truth\": True,\n",
|
|
" \"truncation_strategy\": \"truncate_start\",\n",
|
|
" \"type\": \"mmap\",\n",
|
|
" },\n",
|
|
" {\n",
|
|
" \"features_dir\": \"negative_datasets/speech\",\n",
|
|
" \"sampling_weight\": 10.0,\n",
|
|
" \"penalty_weight\": 1.0,\n",
|
|
" \"truth\": False,\n",
|
|
" \"truncation_strategy\": \"random\",\n",
|
|
" \"type\": \"mmap\",\n",
|
|
" },\n",
|
|
" {\n",
|
|
" \"features_dir\": \"negative_datasets/dinner_party\",\n",
|
|
" \"sampling_weight\": 10.0,\n",
|
|
" \"penalty_weight\": 1.0,\n",
|
|
" \"truth\": False,\n",
|
|
" \"truncation_strategy\": \"random\",\n",
|
|
" \"type\": \"mmap\",\n",
|
|
" },\n",
|
|
" {\n",
|
|
" \"features_dir\": \"negative_datasets/no_speech\",\n",
|
|
" \"sampling_weight\": 5.0,\n",
|
|
" \"penalty_weight\": 1.0,\n",
|
|
" \"truth\": False,\n",
|
|
" \"truncation_strategy\": \"random\",\n",
|
|
" \"type\": \"mmap\",\n",
|
|
" },\n",
|
|
" { # Only used for validation and testing\n",
|
|
" \"features_dir\": \"negative_datasets/dinner_party_eval\",\n",
|
|
" \"sampling_weight\": 0.0,\n",
|
|
" \"penalty_weight\": 1.0,\n",
|
|
" \"truth\": False,\n",
|
|
" \"truncation_strategy\": \"split\",\n",
|
|
" \"type\": \"mmap\",\n",
|
|
" },\n",
|
|
"]\n",
|
|
"\n",
|
|
"# Number of training steps in each iteration - various other settings are configured as lists that corresponds to different steps\n",
|
|
"config[\"training_steps\"] = [10000]\n",
|
|
"\n",
|
|
"# Penalizing weight for incorrect class predictions - lists that correspond to training steps\n",
|
|
"config[\"positive_class_weight\"] = [1]\n",
|
|
"config[\"negative_class_weight\"] = [20]\n",
|
|
"\n",
|
|
"config[\"learning_rates\"] = [\n",
|
|
" 0.001,\n",
|
|
"] # Learning rates for Adam optimizer - list that corresponds to training steps\n",
|
|
"config[\"batch_size\"] = 128\n",
|
|
"\n",
|
|
"config[\"time_mask_max_size\"] = [\n",
|
|
" 0\n",
|
|
"] # SpecAugment - list that corresponds to training steps\n",
|
|
"config[\"time_mask_count\"] = [0] # SpecAugment - list that corresponds to training steps\n",
|
|
"config[\"freq_mask_max_size\"] = [\n",
|
|
" 0\n",
|
|
"] # SpecAugment - list that corresponds to training steps\n",
|
|
"config[\"freq_mask_count\"] = [0] # SpecAugment - list that corresponds to training steps\n",
|
|
"\n",
|
|
"config[\"eval_step_interval\"] = (\n",
|
|
" 500 # Test the validation sets after every this many steps\n",
|
|
")\n",
|
|
"config[\"clip_duration_ms\"] = (\n",
|
|
" 1500 # Maximum length of wake word that the streaming model will accept\n",
|
|
")\n",
|
|
"\n",
|
|
"# The best model weights are chosen first by minimizing the specified minimization metric below the specified target_minimization\n",
|
|
"# Once the target has been met, it chooses the maximum of the maximization metric. Set 'minimization_metric' to None to only maximize\n",
|
|
"# Available metrics:\n",
|
|
"# - \"loss\" - cross entropy error on validation set\n",
|
|
"# - \"accuracy\" - accuracy of validation set\n",
|
|
"# - \"recall\" - recall of validation set\n",
|
|
"# - \"precision\" - precision of validation set\n",
|
|
"# - \"false_positive_rate\" - false positive rate of validation set\n",
|
|
"# - \"false_negative_rate\" - false negative rate of validation set\n",
|
|
"# - \"ambient_false_positives\" - count of false positives from the split validation_ambient set\n",
|
|
"# - \"ambient_false_positives_per_hour\" - estimated number of false positives per hour on the split validation_ambient set\n",
|
|
"config[\"target_minimization\"] = 0.9\n",
|
|
"config[\"minimization_metric\"] = None # Set to None to disable\n",
|
|
"\n",
|
|
"config[\"maximization_metric\"] = \"average_viable_recall\"\n",
|
|
"\n",
|
|
"with open(os.path.join(\"training_parameters.yaml\"), \"w\") as file:\n",
|
|
" documents = yaml.dump(config, file)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "WoEXJBaiC9mf"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Trains a model. When finished, it will quantize and convert the model to a\n",
|
|
"# streaming version suitable for on-device detection.\n",
|
|
"# It will resume if stopped, but it will start over at the configured training\n",
|
|
"# steps in the yaml file.\n",
|
|
"# Change --train 0 to only convert and test the best-weighted model.\n",
|
|
"# On Google colab, it doesn't print the mini-batch results, so it may appear\n",
|
|
"# stuck for several minutes! Additionally, it is very slow compared to training\n",
|
|
"# on a local GPU.\n",
|
|
"import os\n",
|
|
"os.environ['LD_LIBRARY_PATH'] = \"/usr/lib/x86_64-linux-gnu:\" + os.environ.get('LD_LIBRARY_PATH', '')\n",
|
|
"\n",
|
|
"!\"{sys.executable}\" -m microwakeword.model_train_eval \\\n",
|
|
"--training_config='training_parameters.yaml' \\\n",
|
|
"--train 1 \\\n",
|
|
"--restore_checkpoint 1 \\\n",
|
|
"--test_tf_nonstreaming 0 \\\n",
|
|
"--test_tflite_nonstreaming 0 \\\n",
|
|
"--test_tflite_nonstreaming_quantized 0 \\\n",
|
|
"--test_tflite_streaming 0 \\\n",
|
|
"--test_tflite_streaming_quantized 1 \\\n",
|
|
"--use_weights \"best_weights\" \\\n",
|
|
"mixednet \\\n",
|
|
"--pointwise_filters \"64,64,64,64\" \\\n",
|
|
"--repeat_in_block \"1, 1, 1, 1\" \\\n",
|
|
"--mixconv_kernel_sizes '[5], [7,11], [9,15], [23]' \\\n",
|
|
"--residual_connection \"0,0,0,0\" \\\n",
|
|
"--first_conv_filters 32 \\\n",
|
|
"--first_conv_kernel_size 5 \\\n",
|
|
"--stride 3"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "ex_UIWvwtjAN"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"import shutil\n",
|
|
"import json\n",
|
|
"from IPython.display import FileLink\n",
|
|
"\n",
|
|
"# Define the source path and desired download location for the TFLite file\n",
|
|
"source_path = \"trained_models/wakeword/tflite_stream_state_internal_quant/stream_state_internal_quant.tflite\"\n",
|
|
"destination_path = \"./stream_state_internal_quant.tflite\"\n",
|
|
"\n",
|
|
"# Copy the TFLite file to the current working directory\n",
|
|
"shutil.copy(source_path, destination_path)\n",
|
|
"\n",
|
|
"# Define the JSON file content\n",
|
|
"json_data = {\n",
|
|
" \"type\": \"micro\",\n",
|
|
" \"wake_word\": \"khum_puter\", # Adjust this if the target_word changes dynamically\n",
|
|
" \"author\": \"master phooey\",\n",
|
|
" \"website\": \"https://github.com/MasterPhooey/MicroWakeWord-Trainer-Docker\",\n",
|
|
" \"model\": \"stream_state_internal_quant.tflite\",\n",
|
|
" \"trained_languages\": [\"en\"],\n",
|
|
" \"version\": 2,\n",
|
|
" \"micro\": {\n",
|
|
" \"probability_cutoff\": 0.97,\n",
|
|
" \"sliding_window_size\": 5,\n",
|
|
" \"feature_step_size\": 10,\n",
|
|
" \"tensor_arena_size\": 30000,\n",
|
|
" \"minimum_esphome_version\": \"2024.7.0\"\n",
|
|
" }\n",
|
|
"}\n",
|
|
"\n",
|
|
"# Define the JSON file path\n",
|
|
"json_path = \"./stream_state_internal_quant.json\"\n",
|
|
"\n",
|
|
"# Write the JSON file\n",
|
|
"with open(json_path, \"w\") as json_file:\n",
|
|
" json.dump(json_data, json_file, indent=2)\n",
|
|
"\n",
|
|
"# Generate download links for both files\n",
|
|
"print(\"Download your files:\")\n",
|
|
"print(\"TFLite Model:\")\n",
|
|
"display(FileLink(destination_path))\n",
|
|
"print(\"\\nJSON Metadata:\")\n",
|
|
"display(FileLink(json_path))"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"accelerator": "GPU",
|
|
"colab": {
|
|
"gpuType": "T4",
|
|
"provenance": []
|
|
},
|
|
"kernelspec": {
|
|
"display_name": ".venv",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.10.15"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 0
|
|
}
|