mirror of
https://github.com/TaterTotterson/microWakeWord-Trainer-Nvidia-Docker.git
synced 2026-06-12 20:10:19 -06:00
Delete advanced_training_notebook.ipynb
This commit is contained in:
@@ -1,715 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "r11cNiLqvWC6"
|
||||
},
|
||||
"source": [
|
||||
"<div align=\"center\">\n",
|
||||
" <img src=\"https://raw.githubusercontent.com/MasterPhooey/MicroWakeWord-Trainer-Docker/refs/heads/main/mmw.png\" alt=\"MicroWakeWord Trainer Logo\" width=\"100\" />\n",
|
||||
" <h1>MicroWakeWord Trainer Docker</h1>\n",
|
||||
"</div>\n",
|
||||
"\n",
|
||||
"This notebook steps you through training a robust microWakeWord model. It is intended as a **starting point** for users looking to create a high-performance wake word detection model. This notebook is optimized for Python 3.10.\n",
|
||||
"\n",
|
||||
"**The model generated from this notebook is designed for practical use, but achieving optimal performance will require experimentation with various settings and datasets. The provided scripts and configurations aim to give you a strong foundation to build upon.**\n",
|
||||
"\n",
|
||||
"Throughout the notebook, you will find comments suggesting specific settings to modify and experiment with to enhance your model's performance.\n",
|
||||
"\n",
|
||||
"By the end of this notebook, you will have:\n",
|
||||
"- A trained TensorFlow Lite model ready for deployment.\n",
|
||||
"- A JSON manifest file to integrate the model with ESPHome.\n",
|
||||
"\n",
|
||||
"To use the generated model in ESPHome, refer to the [ESPHome documentation](https://esphome.io/components/micro_wake_word) for integration details. You can also explore example configurations in the [model repository](https://github.com/esphome/micro-wake-word-models/tree/main/models/v2)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "BFf6511E65ff"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Installs microWakeWord. Be sure to restart the session after this is finished.\n",
|
||||
"import platform\n",
|
||||
"import sys\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"if platform.system() == \"Darwin\":\n",
|
||||
" # `pymicro-features` is installed from a fork to support building on macOS\n",
|
||||
" !\"{sys.executable}\" -m pip install 'git+https://github.com/puddly/pymicro-features@puddly/minimum-cpp-version' --root-user-action=ignore\n",
|
||||
"\n",
|
||||
"# `audio-metadata` is installed from a fork to unpin `attrs` from a version that breaks Jupyter\n",
|
||||
"!\"{sys.executable}\" -m pip install 'git+https://github.com/whatsnowplaying/audio-metadata@d4ebb238e6a401bb1a5aaaac60c9e2b3cb30929f' --root-user-action=ignore\n",
|
||||
"\n",
|
||||
"# Clone the microWakeWord repository\n",
|
||||
"repo_path = \"./microWakeWord\"\n",
|
||||
"if not os.path.exists(repo_path):\n",
|
||||
" print(\"Cloning microWakeWord repository...\")\n",
|
||||
" !git clone https://github.com/kahrendt/microWakeWord.git {repo_path}\n",
|
||||
"\n",
|
||||
"# Ensure the repository exists before attempting to install\n",
|
||||
"if os.path.exists(repo_path):\n",
|
||||
" print(\"Installing microWakeWord...\")\n",
|
||||
" !\"{sys.executable}\" -m pip install -e {repo_path} --root-user-action=ignore\n",
|
||||
"else:\n",
|
||||
" print(f\"Repository not found at {repo_path}. Cloning might have failed.\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "dEluu7nL7ywd"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Generates 1 sample of the target word for manual verification.\n",
|
||||
"\n",
|
||||
"target_word = 'hey_norman' # Phonetic spellings may produce better samples\n",
|
||||
"\n",
|
||||
"import os\n",
|
||||
"import sys\n",
|
||||
"import platform\n",
|
||||
"\n",
|
||||
"from IPython.display import Audio\n",
|
||||
"\n",
|
||||
"# Ensure the repository is cloned correctly\n",
|
||||
"if not os.path.exists(\"./piper-sample-generator\"):\n",
|
||||
" if platform.system() == \"Darwin\":\n",
|
||||
" !git clone -b mps-support https://github.com/kahrendt/piper-sample-generator\n",
|
||||
" else:\n",
|
||||
" !git clone https://github.com/rhasspy/piper-sample-generator\n",
|
||||
"\n",
|
||||
"# Download the required model\n",
|
||||
"if not os.path.exists(\"piper-sample-generator/models/en_US-libritts_r-medium.pt\"):\n",
|
||||
" !wget -O piper-sample-generator/models/en_US-libritts_r-medium.pt 'https://github.com/rhasspy/piper-sample-generator/releases/download/v2.0.0/en_US-libritts_r-medium.pt'\n",
|
||||
"\n",
|
||||
"# Install system dependencies\n",
|
||||
"!\"{sys.executable}\" -m pip install torch torchaudio piper-phonemize-cross==1.2.1\n",
|
||||
"\n",
|
||||
"# Ensure the repository path is in sys.path\n",
|
||||
"if \"piper-sample-generator/\" not in sys.path:\n",
|
||||
" sys.path.append(\"piper-sample-generator/\")\n",
|
||||
"\n",
|
||||
"# Generate sample\n",
|
||||
"!\"{sys.executable}\" piper-sample-generator/generate_samples.py \"{target_word}\" \\\n",
|
||||
"--max-samples 1 \\\n",
|
||||
"--batch-size 1 \\\n",
|
||||
"--output-dir generated_samples\n",
|
||||
"\n",
|
||||
"# Play the generated audio sample\n",
|
||||
"audio_path = \"generated_samples/0.wav\"\n",
|
||||
"if os.path.exists(audio_path):\n",
|
||||
" display(Audio(audio_path, autoplay=True))\n",
|
||||
"else:\n",
|
||||
" print(f\"Audio file not found at {audio_path}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "-SvGtCCM9akR"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Generates a larger amount of wake word samples.\n",
|
||||
"# Start here when trying to improve your model.\n",
|
||||
"# See https://github.com/rhasspy/-m piper-sample-generator for the full set of\n",
|
||||
"# parameters. In particular, experiment with noise-scales and noise-scale-ws,\n",
|
||||
"# generating negative samples similar to the wake word, and generating many more\n",
|
||||
"# wake word samples, possibly with different phonetic pronunciations.\n",
|
||||
"\n",
|
||||
"!\"{sys.executable}\" piper-sample-generator/generate_samples.py \"{target_word}\" \\\n",
|
||||
"--max-samples 50000 \\\n",
|
||||
"--batch-size 100 \\\n",
|
||||
"--output-dir generated_samples"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "YJRG4Qvo9nXG"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Downloads audio data for augmentation. This can be slow!\n",
|
||||
"# Borrowed from openWakeWord's automatic_model_training.ipynb, accessed March 4, 2024\n",
|
||||
"#\n",
|
||||
"# **Important note!** The data downloaded here has a mixture of difference\n",
|
||||
"# licenses and usage restrictions. As such, any custom models trained with this\n",
|
||||
"# data should be considered as appropriate for **non-commercial** personal use only.\n",
|
||||
"\n",
|
||||
"import os\n",
|
||||
"import scipy.io.wavfile\n",
|
||||
"import numpy as np\n",
|
||||
"from datasets import Dataset, Audio, load_dataset\n",
|
||||
"from pathlib import Path\n",
|
||||
"from tqdm import tqdm\n",
|
||||
"import soundfile as sf\n",
|
||||
"\n",
|
||||
"# -----------------------------\n",
|
||||
"# Download and Process MIT RIR\n",
|
||||
"# -----------------------------\n",
|
||||
"output_dir = \"./mit_rirs\"\n",
|
||||
"if not os.path.exists(output_dir):\n",
|
||||
" os.mkdir(output_dir)\n",
|
||||
" rir_dataset = load_dataset(\"davidscripka/MIT_environmental_impulse_responses\", split=\"train\", streaming=True)\n",
|
||||
" print(f\"Downloading MIT RIR dataset to {output_dir}...\")\n",
|
||||
" for row in tqdm(rir_dataset):\n",
|
||||
" name = row[\"audio\"][\"path\"].split(\"/\")[-1]\n",
|
||||
" scipy.io.wavfile.write(\n",
|
||||
" os.path.join(output_dir, name), \n",
|
||||
" 16000, \n",
|
||||
" (row[\"audio\"][\"array\"] * 32767).astype(np.int16)\n",
|
||||
" )\n",
|
||||
" print(f\"Finished downloading MIT RIR dataset to {output_dir}.\\n\")\n",
|
||||
"else:\n",
|
||||
" print(f\"{output_dir} already exists. Skipping download.\")\n",
|
||||
"\n",
|
||||
"# -----------------------------\n",
|
||||
"# Download and Process Audioset\n",
|
||||
"# -----------------------------\n",
|
||||
"\n",
|
||||
"# Directory setup\n",
|
||||
"audioset_dir = \"./audioset\"\n",
|
||||
"output_dir = \"./audioset_16k\"\n",
|
||||
"os.makedirs(audioset_dir, exist_ok=True)\n",
|
||||
"os.makedirs(output_dir, exist_ok=True)\n",
|
||||
"\n",
|
||||
"# Full-scale dataset download links\n",
|
||||
"dataset_links = [\n",
|
||||
" f\"https://huggingface.co/datasets/agkphysics/AudioSet/resolve/main/data/bal_train0{i}.tar\"\n",
|
||||
" for i in range(10)\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"# Download and extract each dataset part\n",
|
||||
"for link in dataset_links:\n",
|
||||
" file_name = link.split(\"/\")[-1]\n",
|
||||
" out_path = os.path.join(audioset_dir, file_name)\n",
|
||||
" if not os.path.exists(out_path):\n",
|
||||
" print(f\"Downloading {file_name}...\")\n",
|
||||
" os.system(f\"wget --quiet -O {out_path} {link}\")\n",
|
||||
" print(f\"Extracting {file_name}...\")\n",
|
||||
" os.system(f\"tar -xf {out_path} -C {audioset_dir}\")\n",
|
||||
"\n",
|
||||
"# Collect all FLAC files for processing\n",
|
||||
"audioset_files = list(Path(audioset_dir).glob(\"**/*.flac\"))\n",
|
||||
"print(f\"Number of FLAC files found: {len(audioset_files)}\")\n",
|
||||
"\n",
|
||||
"if audioset_files:\n",
|
||||
" corrupted_files = []\n",
|
||||
"\n",
|
||||
" print(\"Converting Audioset files to 16kHz WAV...\")\n",
|
||||
" for file_path in tqdm(audioset_files, desc=\"Processing Audioset files\"):\n",
|
||||
" try:\n",
|
||||
" # Attempt to load the file and handle any errors\n",
|
||||
" audio, sampling_rate = sf.read(file_path)\n",
|
||||
" \n",
|
||||
" if audio is None or len(audio) == 0:\n",
|
||||
" raise ValueError(f\"Empty or invalid audio data in file: {file_path}\")\n",
|
||||
"\n",
|
||||
" # Resample audio to 16kHz\n",
|
||||
" output_path = Path(output_dir) / (file_path.stem + \".wav\")\n",
|
||||
" scipy.io.wavfile.write(\n",
|
||||
" output_path,\n",
|
||||
" 16000,\n",
|
||||
" (audio * 32767).astype(np.int16),\n",
|
||||
" )\n",
|
||||
" except (sf.LibsndfileError, ValueError, Exception) as e:\n",
|
||||
" # Log the error and skip the file\n",
|
||||
" print(f\"Error converting {file_path}: {e}\")\n",
|
||||
" corrupted_files.append(str(file_path))\n",
|
||||
"\n",
|
||||
" # Log corrupted files\n",
|
||||
" if corrupted_files:\n",
|
||||
" log_path = Path(output_dir) / \"audioset_corrupted_files.log\"\n",
|
||||
" with open(log_path, \"w\") as log_file:\n",
|
||||
" log_file.writelines(f\"{file}\\n\" for file in corrupted_files)\n",
|
||||
" print(f\"Logged corrupted files to {log_path}\")\n",
|
||||
"else:\n",
|
||||
" print(\"No FLAC files found in Audioset.\")\n",
|
||||
"\n",
|
||||
"print(\"Audioset processing complete!\")\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# -----------------------------\n",
|
||||
"# Download and Process FMA\n",
|
||||
"# -----------------------------\n",
|
||||
"output_dir = \"./fma\"\n",
|
||||
"if not os.path.exists(output_dir):\n",
|
||||
" os.mkdir(output_dir)\n",
|
||||
" fname = \"fma_xs.zip\"\n",
|
||||
" link = \"https://huggingface.co/datasets/mchl914/fma_xsmall/resolve/main/\" + fname\n",
|
||||
" out_dir = os.path.join(output_dir, fname)\n",
|
||||
" os.system(f\"wget -q -O {out_dir} {link}\")\n",
|
||||
" os.system(f\"cd {output_dir} && unzip -q {fname}\")\n",
|
||||
"\n",
|
||||
"output_dir = \"./fma_16k\"\n",
|
||||
"if not os.path.exists(output_dir):\n",
|
||||
" os.mkdir(output_dir)\n",
|
||||
"\n",
|
||||
"# Save clips to 16-bit PCM wav files\n",
|
||||
"fma_files = list(Path(\"fma/fma_small\").glob(\"**/*.mp3\"))\n",
|
||||
"print(f\"Number of MP3 files found: {len(fma_files)}\")\n",
|
||||
"if fma_files:\n",
|
||||
" fma_dataset = Dataset.from_dict({\"audio\": [str(file) for file in fma_files]})\n",
|
||||
" fma_dataset = fma_dataset.cast_column(\"audio\", Audio(sampling_rate=16000))\n",
|
||||
"\n",
|
||||
" corrupted_files = []\n",
|
||||
" print(\"Converting FMA files to 16kHz WAV...\")\n",
|
||||
" for row in tqdm(fma_dataset):\n",
|
||||
" try:\n",
|
||||
" name = row[\"audio\"][\"path\"].split(\"/\")[-1].replace(\".mp3\", \".wav\")\n",
|
||||
" scipy.io.wavfile.write(\n",
|
||||
" os.path.join(output_dir, name), \n",
|
||||
" 16000, \n",
|
||||
" (row[\"audio\"][\"array\"] * 32767).astype(np.int16)\n",
|
||||
" )\n",
|
||||
" except Exception as e:\n",
|
||||
" print(f\"Error converting {row['audio']['path']}: {e}\")\n",
|
||||
" corrupted_files.append(row[\"audio\"][\"path\"])\n",
|
||||
"\n",
|
||||
" if corrupted_files:\n",
|
||||
" with open(\"fma_corrupted_files.log\", \"w\") as log_file:\n",
|
||||
" log_file.writelines(f\"{file}\\n\" for file in corrupted_files)\n",
|
||||
"else:\n",
|
||||
" print(\"No MP3 files found in FMA.\")\n",
|
||||
"\n",
|
||||
"print(\"Dataset preparation complete!\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "XW3bmbI5-JAz"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Sets up the augmentations.\n",
|
||||
"# To improve your model, experiment with these settings and use more sources of\n",
|
||||
"# background clips.\n",
|
||||
"\n",
|
||||
"import os\n",
|
||||
"from microwakeword.audio.augmentation import Augmentation\n",
|
||||
"from microwakeword.audio.clips import Clips\n",
|
||||
"from microwakeword.audio.spectrograms import SpectrogramGeneration\n",
|
||||
"\n",
|
||||
"def validate_directories(paths):\n",
|
||||
" for path in paths:\n",
|
||||
" if not os.path.exists(path):\n",
|
||||
" print(f\"Error: Directory {path} does not exist. Please ensure preprocessing is complete.\")\n",
|
||||
" return False\n",
|
||||
" return True\n",
|
||||
"\n",
|
||||
"# Paths to augmented data\n",
|
||||
"impulse_paths = ['mit_rirs']\n",
|
||||
"background_paths = ['fma_16k', 'audioset_16k']\n",
|
||||
"\n",
|
||||
"if not validate_directories(impulse_paths + background_paths):\n",
|
||||
" raise ValueError(\"One or more required directories are missing.\")\n",
|
||||
"\n",
|
||||
"clips = Clips(\n",
|
||||
" input_directory='./generated_samples',\n",
|
||||
" file_pattern='*.wav',\n",
|
||||
" max_clip_duration_s=5,\n",
|
||||
" remove_silence=True,\n",
|
||||
" random_split_seed=10,\n",
|
||||
" split_count=0.1,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"augmenter = Augmentation(\n",
|
||||
" augmentation_duration_s=3.2,\n",
|
||||
" augmentation_probabilities={\n",
|
||||
" \"SevenBandParametricEQ\": 0.1,\n",
|
||||
" \"TanhDistortion\": 0.05,\n",
|
||||
" \"PitchShift\": 0.15,\n",
|
||||
" \"BandStopFilter\": 0.1,\n",
|
||||
" \"AddColorNoise\": 0.1,\n",
|
||||
" \"AddBackgroundNoise\": 0.7,\n",
|
||||
" \"Gain\": 0.8,\n",
|
||||
" \"RIR\": 0.7,\n",
|
||||
" },\n",
|
||||
" impulse_paths=impulse_paths,\n",
|
||||
" background_paths=background_paths,\n",
|
||||
" background_min_snr_db=5,\n",
|
||||
" background_max_snr_db=10,\n",
|
||||
" min_jitter_s=0.2,\n",
|
||||
" max_jitter_s=0.3,\n",
|
||||
")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "V5UsJfKKD1k9"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Augment a random clip and play it back to verify it works well\n",
|
||||
"from pathlib import Path\n",
|
||||
"from IPython.display import Audio\n",
|
||||
"from microwakeword.audio.audio_utils import save_clip\n",
|
||||
"\n",
|
||||
"# Ensure output directory exists\n",
|
||||
"output_dir = Path('./augmented_clips')\n",
|
||||
"output_dir.mkdir(exist_ok=True)\n",
|
||||
"\n",
|
||||
"try:\n",
|
||||
" # Get a random clip and apply augmentation\n",
|
||||
" random_clip = clips.get_random_clip()\n",
|
||||
" augmented_clip = augmenter.augment_clip(random_clip)\n",
|
||||
" \n",
|
||||
" # Save augmented clip to file\n",
|
||||
" output_file = output_dir / 'augmented_clip.wav'\n",
|
||||
" save_clip(augmented_clip, output_file)\n",
|
||||
" print(f\"Augmented clip saved to {output_file}\")\n",
|
||||
" \n",
|
||||
" # Playback augmented clip\n",
|
||||
" display(Audio(str(output_file), autoplay=True))\n",
|
||||
"except Exception as e:\n",
|
||||
" print(f\"Error during augmentation or playback: {e}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "D7BHcY1mEGbK"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Augment samples and save the training, validation, and testing sets.\n",
|
||||
"# Validating and testing samples generated the same way can make the model\n",
|
||||
"# benchmark better than it performs in real-word use. Use real samples or TTS\n",
|
||||
"# samples generated with a different TTS engine to potentially get more accurate\n",
|
||||
"# benchmarks.\n",
|
||||
"\n",
|
||||
"import os\n",
|
||||
"from mmap_ninja.ragged import RaggedMmap\n",
|
||||
"from microwakeword.audio.spectrograms import SpectrogramGeneration\n",
|
||||
"\n",
|
||||
"# Output directory for augmented features\n",
|
||||
"output_dir = 'generated_augmented_features'\n",
|
||||
"os.makedirs(output_dir, exist_ok=True)\n",
|
||||
"\n",
|
||||
"# Configuration for each split\n",
|
||||
"split_config = {\n",
|
||||
" \"training\": {\"name\": \"train\", \"repetition\": 2, \"slide_frames\": 10},\n",
|
||||
" \"validation\": {\"name\": \"validation\", \"repetition\": 1, \"slide_frames\": 10},\n",
|
||||
" \"testing\": {\"name\": \"test\", \"repetition\": 1, \"slide_frames\": 1},\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"# Generate augmented features for each split\n",
|
||||
"for split, config in split_config.items():\n",
|
||||
" out_dir = os.path.join(output_dir, split)\n",
|
||||
" os.makedirs(out_dir, exist_ok=True)\n",
|
||||
" print(f\"Processing {split} set...\")\n",
|
||||
"\n",
|
||||
" try:\n",
|
||||
" # Spectrogram generation configuration\n",
|
||||
" spectrograms = SpectrogramGeneration(\n",
|
||||
" clips=clips,\n",
|
||||
" augmenter=augmenter,\n",
|
||||
" slide_frames=config[\"slide_frames\"],\n",
|
||||
" step_ms=10, # Can parameterize this if needed\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" # Generate and save spectrogram features\n",
|
||||
" RaggedMmap.from_generator(\n",
|
||||
" out_dir=os.path.join(out_dir, 'wakeword_mmap'),\n",
|
||||
" sample_generator=spectrograms.spectrogram_generator(\n",
|
||||
" split=config[\"name\"], repeat=config[\"repetition\"]\n",
|
||||
" ),\n",
|
||||
" batch_size=100, # Can parameterize this if needed\n",
|
||||
" verbose=True,\n",
|
||||
" )\n",
|
||||
" print(f\"Completed processing {split} set. Output saved to {out_dir}\")\n",
|
||||
" except Exception as e:\n",
|
||||
" print(f\"Error processing {split} set: {e}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "1pGuJDPyp3ax"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Downloads pre-generated spectrogram features (made for microWakeWord in\n",
|
||||
"# particular) for various negative datasets. This can be slow!\n",
|
||||
"\n",
|
||||
"import os\n",
|
||||
"import requests\n",
|
||||
"import zipfile\n",
|
||||
"from pathlib import Path\n",
|
||||
"from tqdm import tqdm\n",
|
||||
"\n",
|
||||
"# Function to download a file with progress bar\n",
|
||||
"def download_file(url, output_path):\n",
|
||||
" response = requests.get(url, stream=True)\n",
|
||||
" total_size = int(response.headers.get('content-length', 0))\n",
|
||||
" with open(output_path, \"wb\") as f, tqdm(\n",
|
||||
" desc=f\"Downloading {output_path.name}\",\n",
|
||||
" total=total_size,\n",
|
||||
" unit=\"B\",\n",
|
||||
" unit_scale=True,\n",
|
||||
" unit_divisor=1024,\n",
|
||||
" ) as bar:\n",
|
||||
" for chunk in response.iter_content(chunk_size=1024):\n",
|
||||
" f.write(chunk)\n",
|
||||
" bar.update(len(chunk))\n",
|
||||
" print(f\"Downloaded: {output_path}\")\n",
|
||||
"\n",
|
||||
"# Function to extract ZIP files\n",
|
||||
"def extract_zip(zip_path, extract_to):\n",
|
||||
" with zipfile.ZipFile(zip_path, 'r') as zip_ref:\n",
|
||||
" zip_ref.extractall(extract_to)\n",
|
||||
" print(f\"Extracted: {zip_path} to {extract_to}\")\n",
|
||||
"\n",
|
||||
"# Directory for negative datasets\n",
|
||||
"output_dir = Path('./negative_datasets')\n",
|
||||
"output_dir.mkdir(exist_ok=True)\n",
|
||||
"\n",
|
||||
"# Negative dataset URLs\n",
|
||||
"link_root = \"https://huggingface.co/datasets/kahrendt/microwakeword/resolve/main/\"\n",
|
||||
"filenames = ['dinner_party.zip', 'dinner_party_eval.zip', 'no_speech.zip', 'speech.zip']\n",
|
||||
"\n",
|
||||
"# Download and extract files\n",
|
||||
"for fname in filenames:\n",
|
||||
" link = link_root + fname\n",
|
||||
" zip_path = output_dir / fname\n",
|
||||
"\n",
|
||||
" # Download only if the file doesn't already exist\n",
|
||||
" if not zip_path.exists():\n",
|
||||
" try:\n",
|
||||
" download_file(link, zip_path)\n",
|
||||
" except Exception as e:\n",
|
||||
" print(f\"Error downloading {fname}: {e}\")\n",
|
||||
" continue\n",
|
||||
"\n",
|
||||
" # Extract the ZIP file\n",
|
||||
" try:\n",
|
||||
" extract_zip(zip_path, output_dir)\n",
|
||||
" except Exception as e:\n",
|
||||
" print(f\"Error extracting {fname}: {e}\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "Ii1A14GsGVQT"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Save a yaml config that controls the training process\n",
|
||||
"# These hyperparamters can make a huge different in model quality.\n",
|
||||
"# Experiment with sampling and penalty weights and increasing the number of\n",
|
||||
"# training steps.\n",
|
||||
"\n",
|
||||
"import yaml\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"config = {}\n",
|
||||
"\n",
|
||||
"config[\"window_step_ms\"] = 10\n",
|
||||
"\n",
|
||||
"config[\"train_dir\"] = \"trained_models/wakeword\"\n",
|
||||
"\n",
|
||||
"config[\"features\"] = [\n",
|
||||
" {\n",
|
||||
" \"features_dir\": \"generated_augmented_features\",\n",
|
||||
" \"sampling_weight\": 2.0, # Increased\n",
|
||||
" \"penalty_weight\": 1.0,\n",
|
||||
" \"truth\": True,\n",
|
||||
" \"truncation_strategy\": \"truncate_start\",\n",
|
||||
" \"type\": \"mmap\",\n",
|
||||
" },\n",
|
||||
" {\n",
|
||||
" \"features_dir\": \"negative_datasets/speech\",\n",
|
||||
" \"sampling_weight\": 12.0, # Adjusted\n",
|
||||
" \"penalty_weight\": 1.0,\n",
|
||||
" \"truth\": False,\n",
|
||||
" \"truncation_strategy\": \"random\",\n",
|
||||
" \"type\": \"mmap\",\n",
|
||||
" },\n",
|
||||
" {\n",
|
||||
" \"features_dir\": \"negative_datasets/dinner_party\",\n",
|
||||
" \"sampling_weight\": 12.0, # Adjusted\n",
|
||||
" \"penalty_weight\": 1.0,\n",
|
||||
" \"truth\": False,\n",
|
||||
" \"truncation_strategy\": \"random\",\n",
|
||||
" \"type\": \"mmap\",\n",
|
||||
" },\n",
|
||||
" {\n",
|
||||
" \"features_dir\": \"negative_datasets/no_speech\",\n",
|
||||
" \"sampling_weight\": 5.0, # Balanced\n",
|
||||
" \"penalty_weight\": 1.0,\n",
|
||||
" \"truth\": False,\n",
|
||||
" \"truncation_strategy\": \"random\",\n",
|
||||
" \"type\": \"mmap\",\n",
|
||||
" },\n",
|
||||
" {\n",
|
||||
" \"features_dir\": \"negative_datasets/dinner_party_eval\",\n",
|
||||
" \"sampling_weight\": 0.0,\n",
|
||||
" \"penalty_weight\": 1.0,\n",
|
||||
" \"truth\": False,\n",
|
||||
" \"truncation_strategy\": \"split\",\n",
|
||||
" \"type\": \"mmap\",\n",
|
||||
" },\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"config[\"training_steps\"] = [40000] # Increased\n",
|
||||
"config[\"positive_class_weight\"] = [1]\n",
|
||||
"config[\"negative_class_weight\"] = [20] # Adjusted\n",
|
||||
"config[\"learning_rates\"] = [0.001] # Adjusted\n",
|
||||
"config[\"batch_size\"] = 128\n",
|
||||
"\n",
|
||||
"config[\"time_mask_max_size\"] = [0] # Enabled SpecAugment\n",
|
||||
"config[\"time_mask_count\"] = [0]\n",
|
||||
"config[\"freq_mask_max_size\"] = [0]\n",
|
||||
"config[\"freq_mask_count\"] = [0]\n",
|
||||
"\n",
|
||||
"config[\"eval_step_interval\"] = 500 # Adjusted\n",
|
||||
"config[\"clip_duration_ms\"] = 1500 # Increased\n",
|
||||
"\n",
|
||||
"config[\"target_minimization\"] = 0.9\n",
|
||||
"config[\"minimization_metric\"] = None # Updated\n",
|
||||
"config[\"maximization_metric\"] = \"average_viable_recall\"\n",
|
||||
"\n",
|
||||
"with open(os.path.join(\"training_parameters.yaml\"), \"w\") as file:\n",
|
||||
" documents = yaml.dump(config, file)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "WoEXJBaiC9mf"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Trains a model. When finished, it will quantize and convert the model to a\n",
|
||||
"# streaming version suitable for on-device detection.\n",
|
||||
"# It will resume if stopped, but it will start over at the configured training\n",
|
||||
"# steps in the yaml file.\n",
|
||||
"# Change --train 0 to only convert and test the best-weighted model.\n",
|
||||
"# On Google colab, it doesn't print the mini-batch results, so it may appear\n",
|
||||
"# stuck for several minutes! Additionally, it is very slow compared to training\n",
|
||||
"# on a local GPU.\n",
|
||||
"\n",
|
||||
"import os\n",
|
||||
"import sys\n",
|
||||
"\n",
|
||||
"# Ensure the library path is correctly set\n",
|
||||
"os.environ['LD_LIBRARY_PATH'] = \"/usr/lib/x86_64-linux-gnu:\" + os.environ.get('LD_LIBRARY_PATH', '')\n",
|
||||
"\n",
|
||||
"# Training command with optimized settings\n",
|
||||
"!\"{sys.executable}\" -m microwakeword.model_train_eval \\\n",
|
||||
"--training_config='training_parameters.yaml' \\\n",
|
||||
"--train 1 \\\n",
|
||||
"--restore_checkpoint 1 \\\n",
|
||||
"--test_tf_nonstreaming 0 \\\n",
|
||||
"--test_tflite_nonstreaming 0 \\\n",
|
||||
"--test_tflite_nonstreaming_quantized 0 \\\n",
|
||||
"--test_tflite_streaming 0 \\\n",
|
||||
"--test_tflite_streaming_quantized 1 \\\n",
|
||||
"--use_weights \"best_weights\" \\\n",
|
||||
"mixednet \\\n",
|
||||
"--pointwise_filters \"64,64,64,64\" \\\n",
|
||||
"--repeat_in_block \"1,1,1,1\" \\\n",
|
||||
"--mixconv_kernel_sizes '[5], [7,11], [9,15], [23]' \\\n",
|
||||
"--residual_connection \"0,0,0,0\" \\\n",
|
||||
"--first_conv_filters 32 \\\n",
|
||||
"--first_conv_kernel_size 5 \\\n",
|
||||
"--stride 2\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "ex_UIWvwtjAN"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import shutil\n",
|
||||
"import json\n",
|
||||
"from IPython.display import FileLink\n",
|
||||
"\n",
|
||||
"# Define the source path and desired download location for the TFLite file\n",
|
||||
"source_path = \"trained_models/wakeword/tflite_stream_state_internal_quant/stream_state_internal_quant.tflite\"\n",
|
||||
"destination_path = \"./stream_state_internal_quant.tflite\"\n",
|
||||
"\n",
|
||||
"# Copy the TFLite file to the current working directory\n",
|
||||
"shutil.copy(source_path, destination_path)\n",
|
||||
"\n",
|
||||
"# Define the JSON file content\n",
|
||||
"json_data = {\n",
|
||||
" \"type\": \"micro\",\n",
|
||||
" \"wake_word\": \"hey_norman\", # Adjust this if the target_word changes dynamically\n",
|
||||
" \"author\": \"master phooey\",\n",
|
||||
" \"website\": \"https://github.com/MasterPhooey/MicroWakeWord-Trainer-Docker\",\n",
|
||||
" \"model\": \"stream_state_internal_quant.tflite\",\n",
|
||||
" \"trained_languages\": [\"en\"],\n",
|
||||
" \"version\": 2,\n",
|
||||
" \"micro\": {\n",
|
||||
" \"probability_cutoff\": 0.97,\n",
|
||||
" \"sliding_window_size\": 5,\n",
|
||||
" \"feature_step_size\": 10,\n",
|
||||
" \"tensor_arena_size\": 30000,\n",
|
||||
" \"minimum_esphome_version\": \"2024.7.0\"\n",
|
||||
" }\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"# Define the JSON file path\n",
|
||||
"json_path = \"./stream_state_internal_quant.json\"\n",
|
||||
"\n",
|
||||
"# Write the JSON file\n",
|
||||
"with open(json_path, \"w\") as json_file:\n",
|
||||
" json.dump(json_data, json_file, indent=2)\n",
|
||||
"\n",
|
||||
"# Generate download links for both files\n",
|
||||
"print(\"Download your files:\")\n",
|
||||
"print(\"TFLite Model:\")\n",
|
||||
"display(FileLink(destination_path))\n",
|
||||
"print(\"\\nJSON Metadata:\")\n",
|
||||
"display(FileLink(json_path))"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"accelerator": "GPU",
|
||||
"colab": {
|
||||
"gpuType": "T4",
|
||||
"provenance": []
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.16"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
Reference in New Issue
Block a user