Files
microWakeWord-Trainer-Nvidi…/cli/wake_word_sample_augmenter
George Joseph cb81f7f02d Train from the command line
The files in the `cli` directory allow you to train wake words
from the command line without needing to use the Jupyter notebook
or a web browser. The logic from the notebook has been moved into separate
shell scripts and Python files, which are wrapped by three high-level
scripts that do the following:

* setup_python_venv: Creates a Python virtual environment with all the
packages needed to train.  The venv is created in the container's /data
directory and is therefore stored on the host, not in the container's root
docker volume.

* setup_training_datasets: Downloads, extracts and converts the MIT RIR,
FMA, Audioset and negative training reference datasets. These are also stored in /data.

* train_wake_word: Generates the wake word samples, augments them with the
audio from the training datasets, and finally runs the microWakeWord training.
The resulting model's .tflite and .json files are placed in the /data/output
directory.

See the README.md file for much more information.
2025-12-28 12:48:51 -07:00

216 lines
7.3 KiB
Python
Executable File

#!/usr/bin/env python
import sys, os, gc, glob, random
import types, shutil, json
from datetime import datetime, timezone
from pathlib import Path
from argparse import ArgumentParser as ArgParser, ArgumentError
# Default data directory: the current directory when a ".mww-data-dir" file
# exists in it, otherwise /data.
default_data_dir = os.getcwd() if os.path.exists(".mww-data-dir") else "/data"
# exit_on_error=False makes argparse raise ArgumentError instead of exiting,
# so the failure can be reported together with the full help text below.
parser = ArgParser(exit_on_error=False)
parser.add_argument("--data-dir", type=str, help=f"Data directory. Default: {default_data_dir}", required=False, default=default_data_dir)
parser.add_argument("--input-dir", type=str, help="Sample input directory. Default: <data-dir>/work/wake_word_samples", required=False)
parser.add_argument("--output-dir", type=str, help="Sample output directory. Default: <input-dir>_augmented", required=False)
parser.add_argument("--mit-rirs-16k-dir", type=str, help="MIT RIR input directory. Default: <data-dir>/training_datasets/mit_rirs_16k", required=False)
parser.add_argument("--fma-16k-dir", type=str, help="FMA input directory. Default: <data-dir>/training_datasets/fma_16k", required=False)
parser.add_argument("--audioset-16k-dir", type=str, help="Audioset input directory. Default: <data-dir>/training_datasets/audioset_16k", required=False)
try:
    args = parser.parse_args()
except ArgumentError as err:
    # Say which argument was rejected before dumping the usage text;
    # otherwise the user only sees the help and has to guess what went wrong.
    print(f"Error: {err}", file=sys.stderr)
    parser.print_help()
    sys.exit(1)
# Normalise every directory option to an absolute, symlink-resolved path.
# Options that were not given (or given as an empty string) fall back to a
# default built from the already-resolved data directory.
args.data_dir = os.path.realpath(args.data_dir)
work_dir = args.data_dir + "/work"


def _dir_or_default(given, fallback):
    """Return ``realpath(given)`` when *given* is truthy, else *fallback* as-is."""
    return os.path.realpath(given) if given else fallback


args.input_dir = _dir_or_default(args.input_dir, work_dir + "/wake_word_samples")
# The output default hangs off the *resolved* input directory, so it must be
# computed after input_dir above.
args.output_dir = _dir_or_default(args.output_dir, args.input_dir + "_augmented")
args.mit_rirs_16k_dir = _dir_or_default(
    args.mit_rirs_16k_dir, args.data_dir + "/training_datasets/mit_rirs_16k"
)
args.fma_16k_dir = _dir_or_default(
    args.fma_16k_dir, args.data_dir + "/training_datasets/fma_16k"
)
args.audioset_16k_dir = _dir_or_default(
    args.audioset_16k_dir, args.data_dir + "/training_datasets/audioset_16k"
)
out_path = Path(args.output_dir)


def validate_directories(paths):
    """Check that every entry in *paths* is an existing directory.

    An error line is printed for *each* offending path, so the user can fix
    them all in one pass instead of re-running once per missing directory.

    Returns:
        True when every path is a directory, False otherwise.
    """
    # isdir (not exists): a stray regular file must not pass as a directory.
    missing = [path for path in paths if not os.path.isdir(path)]
    for path in missing:
        print(f"Error: Directory {path} does not exist or is not a directory. Please ensure preprocessing is complete.")
    return not missing


# Only the inputs must already exist. The output directory is created after
# validation, so a bad invocation does not leave a stray empty directory
# behind. A custom --output-dir with missing parents also gets created
# instead of raising FileNotFoundError.
paths = [ work_dir, args.input_dir, args.mit_rirs_16k_dir, args.fma_16k_dir, args.audioset_16k_dir ]
if not validate_directories(paths):
    parser.print_help()
    sys.exit(1)
out_path.mkdir(parents=True, exist_ok=True)
# Fail fast when there is nothing to augment. The count is reported in the
# banner below and in the final summary.
files = glob.glob(args.input_dir + "/*.wav")
if not files:
    # Name the directory that was actually searched: --input-dir may point
    # somewhere other than the default wake_word_samples directory.
    raise RuntimeError(f"❌ No WAVs in {args.input_dir}.")
max_samples = len(files)
print(f"\n===== Augmenting {max_samples} wake word samples =====")
print(" Initializing libraries")
# TensorFlow / CUDA / logging environment. These are set before the
# `import tensorflow` below so they are already in place when TF loads.
os.environ.update({
    "TF_CPP_MIN_LOG_LEVEL": "3",
    "TF_FORCE_GPU_ALLOW_GROWTH": "true",
    "TF_GPU_ALLOCATOR": "cuda_malloc_async",
    "TF_XLA_FLAGS": "--tf_xla_auto_jit=0",
    "NVIDIA_TF32_OVERRIDE": "1",
    "TF_CUDNN_WORKSPACE_LIMIT_IN_MB": "512",
    "GLOG_minloglevel": "9",
    "GRPC_VERBOSITY": "ERROR",
})
print(" Loading Tensorflow")
import tensorflow as tf
print(" GPU memory config")
# Query the device list once; it is used both for configuration and reporting.
gpus = tf.config.list_physical_devices("GPU")
# Per-device memory growth (belt + suspenders with TF_FORCE_GPU_ALLOW_GROWTH).
for gpu in gpus:
    try:
        tf.config.experimental.set_memory_growth(gpu, True)
    except Exception as err:
        # Still best effort: augmentation can proceed without it. But report
        # the failure instead of silently swallowing it.
        print(f" Could not enable memory growth for {gpu}: {err}")
print(f" GPUs: {gpus}")
gc.collect()
import numpy as np
import librosa
from mmap_ninja.ragged import RaggedMmap
from microwakeword.audio.augmentation import Augmentation
from microwakeword.audio.clips import Clips
from microwakeword.audio.spectrograms import SpectrogramGeneration
from microwakeword.audio.audio_utils import save_clip
# Wall-clock start of the augmentation work (UTC, whole seconds), used for
# the elapsed-time summary at the end of the script.
START_TIME = datetime.now(timezone.utc).replace(microsecond=0)

# Dataset directories handed to the augmenter: MIT room impulse responses,
# plus the FMA and Audioset background-audio sets.
impulse_paths = [args.mit_rirs_16k_dir]
background_paths = [args.fma_16k_dir, args.audioset_16k_dir]

# Per-transform values for Augmentation(augmentation_probabilities=...).
# Presumably the chance each transform is applied to a clip; see
# microwakeword's Augmentation to confirm.
augmentation_probabilities = {
    "SevenBandParametricEQ": 0.1,
    "TanhDistortion": 0.05,
    "PitchShift": 0.15,
    "BandStopFilter": 0.1,
    "AddColorNoise": 0.1,
    "AddBackgroundNoise": 0.7,
    "Gain": 0.8,
    "RIR": 0.7,
}

clips = Clips(
    input_directory=args.input_dir,
    file_pattern="*.wav",
    max_clip_duration_s=5,
    remove_silence=True,
    random_split_seed=10,
    split_count=0.1,
)

augmenter = Augmentation(
    augmentation_duration_s=3.2,
    augmentation_probabilities=augmentation_probabilities,
    impulse_paths=impulse_paths,
    background_paths=background_paths,
    background_min_snr_db=5,
    background_max_snr_db=10,
    min_jitter_s=0.2,
    max_jitter_s=0.3,
)
# Replacement audio source for `clips`: reads the input WAVs directly and
# splits them deterministically into train / validation / test.
def audio_generator_from_wavs(self, split="train", repeat=1):
    """Yield every WAV of one split as a float32 array loaded at 16 kHz mono.

    Bound onto ``clips`` below as its ``audio_generator``. The bound instance
    (``self``) is not consulted: files come straight from
    ``args.input_dir/*.wav`` via ``librosa.load``.

    The sorted file list is shuffled with ``random.Random(10)``, the same
    seed passed to ``Clips(random_split_seed=10)``. Whether the result
    matches Clips' own split is unverified. Validation and test each get
    ``max(1, int(0.10 * n))`` files and train gets the remainder. With fewer
    than three files the train split is therefore empty, and with a single
    file the test split is empty too.

    Args:
        split: "train", "validation" or "test"; any other value yields nothing.
        repeat: number of passes over the split's files (at least one).

    Raises:
        RuntimeError: on first iteration, if no WAVs are found.
    """
    wav_paths = sorted(glob.glob(args.input_dir + "/*.wav"))
    if not wav_paths:
        raise RuntimeError("❌ No WAVs in wake_word_samples.")

    shuffled = list(wav_paths)
    random.Random(10).shuffle(shuffled)

    total = len(shuffled)
    holdout = max(1, int(0.10 * total))  # size of validation and of test
    train_end = max(0, total - holdout - holdout)
    val_end = train_end + holdout
    chosen = {
        "train": shuffled[:train_end],
        "validation": shuffled[train_end:val_end],
        "test": shuffled[val_end:],
    }.get(split, [])
    if not chosen:
        return  # empty or unknown split: nothing to yield

    passes = max(1, int(repeat))
    # Lazy: each file is loaded only when the consumer asks for it.
    yield from (
        librosa.load(wav_path, sr=16000, mono=True)[0].astype(np.float32, copy=False)
        for _ in range(passes)
        for wav_path in chosen
    )


# Patch only this `clips` instance so SpectrogramGeneration pulls audio from
# the loader above.
clips.audio_generator = types.MethodType(audio_generator_from_wavs, clips)
# Output sub-directory -> settings for that split:
#   name         split key understood by the patched clips.audio_generator
#   repetition   passed as `repeat` to spectrogram_generator
#   slide_frames passed to SpectrogramGeneration
split_cfg = {
    "training": {"name": "train", "repetition": 2, "slide_frames": 10},
    "validation": {"name": "validation", "repetition": 1, "slide_frames": 10},
    "testing": {"name": "test", "repetition": 1, "slide_frames": 1},
}

# Augment each split and write its spectrograms to
# <output-dir>/<split>/wakeword_mmap.
for split, cfg in split_cfg.items():
    split_dir = out_path / split
    split_dir.mkdir(parents=True, exist_ok=True)
    print(f" Augmenting {split}")
    print(" Generating spectrograms")
    # A fresh generator per split because slide_frames differs between splits.
    spectros = SpectrogramGeneration(
        clips=clips,
        augmenter=augmenter,
        slide_frames=cfg["slide_frames"],
        step_ms=10,
    )
    print(" Generating files")
    features = spectros.spectrogram_generator(
        split=cfg["name"], repeat=cfg["repetition"]
    )
    RaggedMmap.from_generator(
        out_dir=str(split_dir / "wakeword_mmap"),
        sample_generator=features,
        batch_size=100,
        verbose=False,
    )
    print(f" {split} augmentation complete")
# Final report: sample count and wall-clock time since START_TIME.
END_TIME = datetime.now(timezone.utc).replace(microsecond=0)
et = END_TIME - START_TIME
rule = "=" * 80
msg = f"Augmented {max_samples} wake word samples."
print(f"\n{rule}")
print(f"{msg:>50s} Elapsed time: {et!s}")
print(f"{rule}\n")