mirror of
https://github.com/TaterTotterson/microWakeWord-Trainer-Nvidia-Docker.git
synced 2026-06-12 20:10:19 -06:00
The files in the `cli` directory let you train wake words from the command line without using the Jupyter notebook or a web browser. The logic from the notebook has been split into separate shell scripts and Python files, wrapped by three high-level scripts that do the following:

* `setup_python_venv`: Creates a Python virtual environment with all the packages needed for training. The venv is created in the container's `/data` directory, so it is stored on the host rather than in the container's root Docker volume.
* `setup_training_datasets`: Downloads, extracts, and converts the MIT RIR, FMA, AudioSet, and negative training reference datasets. These are also stored in `/data`.
* `train_wake_word`: Generates the wake word samples, augments them with audio from the training datasets, and finally runs the microWakeWord training. The resulting model `.tflite` and `.json` files are placed in the `/data/output` directory.

See the README.md file for much more information.
216 lines
7.3 KiB
Python
Executable File
216 lines
7.3 KiB
Python
Executable File
#!/usr/bin/env python
|
|
|
|
import sys, os, gc, glob, random
|
|
import types, shutil, json
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
from argparse import ArgumentParser as ArgParser, ArgumentError
|
|
|
|
# Default data directory: the current working directory when it contains a
# ".mww-data-dir" marker file, otherwise "/data".
default_data_dir = os.getcwd() if os.path.exists(".mww-data-dir") else "/data"

# exit_on_error=False makes argparse raise ArgumentError (e.g. "expected one
# argument" for an option given without a value) instead of exiting, so we can
# show the error together with the full help text.
# NOTE(review): depending on the Python version, some errors (such as
# unrecognized arguments) may still exit directly via parser.error().
parser = ArgParser(exit_on_error=False)
parser.add_argument("--data-dir", type=str, help=f"Data directory. Default: {default_data_dir}", required=False, default=default_data_dir)
parser.add_argument("--input-dir", type=str, help="Sample input directory. Default: <data-dir>/work/wake_word_samples", required=False)
parser.add_argument("--output-dir", type=str, help="Sample output directory. Default: <input-dir>_augmented", required=False)
parser.add_argument("--mit-rirs-16k-dir", type=str, help="MIT RIR input directory. Default: <data-dir>/training_datasets/mit_rirs_16k", required=False)
parser.add_argument("--fma-16k-dir", type=str, help="FMA input directory. Default: <data-dir>/training_datasets/fma_16k", required=False)
parser.add_argument("--audioset-16k-dir", type=str, help="Audioset input directory. Default: <data-dir>/training_datasets/audioset_16k", required=False)

try:
    args = parser.parse_args()
except ArgumentError as err:
    # Tell the user which argument was rejected before dumping the help text;
    # printing only the help hides the actual problem.
    print(f"Error: {err}\n")
    parser.print_help()
    sys.exit(1)
|
|
|
|
def _resolve_dir(value, default):
    """Return the realpath of a user-supplied directory, or *default* when unset.

    An empty or missing *value* falls back to *default*, which is returned
    unchanged (the defaults are built from the already-realpath'd data dir).
    """
    return os.path.realpath(value) if value else default


# Canonicalize the data directory; every default below is derived from it.
args.data_dir = os.path.realpath(args.data_dir)
work_dir = args.data_dir + "/work"

args.input_dir = _resolve_dir(args.input_dir, work_dir + "/wake_word_samples")
# The output default is derived from the *resolved* input directory.
args.output_dir = _resolve_dir(args.output_dir, args.input_dir + "_augmented")
args.mit_rirs_16k_dir = _resolve_dir(
    args.mit_rirs_16k_dir, args.data_dir + "/training_datasets/mit_rirs_16k"
)
args.fma_16k_dir = _resolve_dir(
    args.fma_16k_dir, args.data_dir + "/training_datasets/fma_16k"
)
args.audioset_16k_dir = _resolve_dir(
    args.audioset_16k_dir, args.data_dir + "/training_datasets/audioset_16k"
)


def validate_directories(paths):
    """Check that every path in *paths* is an existing directory.

    Prints one error line for each missing directory (not only the first), so
    all problems are reported in a single run.

    Args:
        paths: iterable of directory paths to check.

    Returns:
        True if all paths are existing directories, False otherwise.
    """
    missing = [path for path in paths if not os.path.isdir(path)]
    for path in missing:
        print(f"Error: Directory {path} does not exist. Please ensure preprocessing is complete.")
    return not missing


# Validate the input directories *before* creating the output directory, so a
# wrong --data-dir (e.g. missing <data-dir>/work) produces the message above
# plus help instead of a FileNotFoundError from mkdir, and no stray output
# directory is left behind.
paths = [work_dir, args.input_dir, args.mit_rirs_16k_dir, args.fma_16k_dir, args.audioset_16k_dir]
if not validate_directories(paths):
    parser.print_help()
    sys.exit(1)

out_path = Path(args.output_dir)
out_path.mkdir(parents=True, exist_ok=True)

# glob.escape keeps directory names containing [, ], * or ? literal.
files = glob.glob(os.path.join(glob.escape(args.input_dir), "*.wav"))
if not files:
    raise RuntimeError(f"❌ No WAVs in {args.input_dir}.")
max_samples = len(files)

print(f"\n===== Augmenting {max_samples} wake word samples =====")
|
|
|
|
print(" Initializing libraries")

# These are set before `import tensorflow` below so they are in place when
# TensorFlow initializes (logging verbosity, GPU allocator, XLA, TF32, cuDNN).
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"
os.environ["TF_GPU_ALLOCATOR"] = "cuda_malloc_async"
os.environ["TF_XLA_FLAGS"] = "--tf_xla_auto_jit=0"
os.environ["NVIDIA_TF32_OVERRIDE"] = "1"
os.environ["TF_CUDNN_WORKSPACE_LIMIT_IN_MB"] = "512"
os.environ["GLOG_minloglevel"] = "9"
os.environ["GRPC_VERBOSITY"] = "ERROR"

print(" Loading Tensorflow")
import tensorflow as tf

print(" GPU memory config")
gpus = tf.config.list_physical_devices("GPU")
# Per-device memory growth (belt + suspenders alongside TF_FORCE_GPU_ALLOW_GROWTH).
for gpu in gpus:
    try:
        tf.config.experimental.set_memory_growth(gpu, True)
    except (RuntimeError, ValueError) as err:
        # Best effort: TensorFlow raises RuntimeError once the runtime is
        # initialized and ValueError for an invalid device. Warn, don't abort.
        print(f" Warning: could not enable memory growth on {gpu.name}: {err}")
print(f" GPUs: {gpus}")
gc.collect()
|
|
|
|
import numpy as np
|
|
import librosa
|
|
from mmap_ninja.ragged import RaggedMmap
|
|
from microwakeword.audio.augmentation import Augmentation
|
|
from microwakeword.audio.clips import Clips
|
|
from microwakeword.audio.spectrograms import SpectrogramGeneration
|
|
from microwakeword.audio.audio_utils import save_clip
|
|
|
|
# UTC start time, truncated to whole seconds, used for the final elapsed-time report.
START_TIME = datetime.now(timezone.utc).replace(microsecond=0)

# Dataset directories fed to the augmenter: room impulse responses, and the
# background-audio corpora.
impulse_paths = [args.mit_rirs_16k_dir]
background_paths = [args.fma_16k_dir, args.audioset_16k_dir]

# NOTE(review): `clips.audio_generator` is replaced later in this script by a
# plain WAV loader, so remove_silence / max_clip_duration_s may not affect the
# audio the spectrogram generator sees — confirm against microwakeword's Clips.
clip_options = dict(
    input_directory=args.input_dir,
    file_pattern="*.wav",
    max_clip_duration_s=5,
    remove_silence=True,
    random_split_seed=10,
    split_count=0.1,
)
clips = Clips(**clip_options)

# Presumably the per-clip probability of applying each transform (keys are
# microwakeword augmentation names).
aug_probabilities = {
    "SevenBandParametricEQ": 0.1,
    "TanhDistortion": 0.05,
    "PitchShift": 0.15,
    "BandStopFilter": 0.1,
    "AddColorNoise": 0.1,
    "AddBackgroundNoise": 0.7,
    "Gain": 0.8,
    "RIR": 0.7,
}

augmenter = Augmentation(
    augmentation_duration_s=3.2,
    augmentation_probabilities=aug_probabilities,
    impulse_paths=impulse_paths,
    background_paths=background_paths,
    background_min_snr_db=5,
    background_max_snr_db=10,
    min_jitter_s=0.2,
    max_jitter_s=0.3,
)
|
|
|
|
# Augment samples and save the training, validation, and testing sets.
|
|
|
|
def _split_wav_files(paths, seed=10, holdout_fraction=0.10):
    """Deterministically shuffle *paths* and split them into train/validation/test.

    Validation and test each receive ``max(1, int(holdout_fraction * n))``
    files; train receives the rest. With fewer than three files the train
    split is empty, and with a single file only validation receives it.

    Returns:
        dict mapping "train", "validation" and "test" to lists of paths.
    """
    shuffled = list(paths)
    random.Random(seed).shuffle(shuffled)

    n = len(shuffled)
    n_val = max(1, int(holdout_fraction * n))
    n_test = max(1, int(holdout_fraction * n))
    n_train = max(0, n - n_val - n_test)
    return {
        "train": shuffled[:n_train],
        "validation": shuffled[n_train:n_train + n_val],
        "test": shuffled[n_train + n_val:],
    }


def audio_generator_from_wavs(self, split="train", repeat=1):
    """Yield 1-D float32 arrays loaded via librosa from input_dir/*.wav.

    Replacement for ``Clips.audio_generator`` (bound onto ``clips`` below).
    Files are sorted, then shuffled with seed 10 (like
    ``Clips(random_split_seed=10)``), giving a deterministic ~80/10/10
    train/validation/test split.

    Args:
        self: the bound Clips instance (unused).
        split: "train", "validation" or "test"; any other value yields nothing.
        repeat: number of passes over the split; values below 1 count as 1.

    Raises:
        RuntimeError: if input_dir contains no .wav files.
    """
    # glob.escape keeps directory names containing [, ], * or ? literal.
    files = sorted(glob.glob(os.path.join(glob.escape(args.input_dir), "*.wav")))
    if not files:
        raise RuntimeError(f"❌ No WAVs in {args.input_dir}.")

    file_list = _split_wav_files(files).get(split, [])
    if not file_list:
        return  # nothing to yield

    for _ in range(max(1, int(repeat))):
        for path in file_list:
            # Load as 16 kHz mono; the returned sample rate is not needed.
            audio, _sr = librosa.load(path, sr=16000, mono=True)
            yield audio.astype(np.float32, copy=False)


# Bind the patched generator to the existing `clips` instance.
clips.audio_generator = types.MethodType(audio_generator_from_wavs, clips)
|
|
|
|
# Output sub-directory name -> clip split to read, number of passes over it,
# and the slide_frames value handed to SpectrogramGeneration.
split_cfg = {
    "training": {"name": "train", "repetition": 2, "slide_frames": 10},
    "validation": {"name": "validation", "repetition": 1, "slide_frames": 10},
    "testing": {"name": "test", "repetition": 1, "slide_frames": 1},
}


def _augment_split(split, name, repetition, slide_frames):
    """Augment one split and write its spectrograms to <output>/<split>/wakeword_mmap."""
    split_dir = out_path / split
    split_dir.mkdir(parents=True, exist_ok=True)
    print(f" Augmenting {split}")

    print(" Generating spectrograms")
    spectrograms = SpectrogramGeneration(
        clips=clips,          # backed by the patched WAV loader
        augmenter=augmenter,
        slide_frames=slide_frames,
        step_ms=10,
    )

    print(" Generating files")
    samples = spectrograms.spectrogram_generator(split=name, repeat=repetition)
    RaggedMmap.from_generator(
        out_dir=str(split_dir / "wakeword_mmap"),
        sample_generator=samples,
        batch_size=100,
        verbose=False,
    )
    print(f" {split} augmentation complete")


# ---- Generate features ----
for split, cfg in split_cfg.items():
    _augment_split(split, cfg["name"], cfg["repetition"], cfg["slide_frames"])
|
|
|
|
# Summary banner: sample count, right-aligned, followed by whole-second elapsed time.
END_TIME = datetime.now(timezone.utc).replace(microsecond=0)
elapsed = END_TIME - START_TIME
rule = "=" * 80
summary = f"Augmented {max_samples} wake word samples."

print("\n" + rule)
print(f"{summary:>50s} Elapsed time: {str(elapsed)}")
print(rule + "\n")
|