Files
microWakeWord-Trainer-Nvidi…/cli/wake_word_sample_augmenter
2026-01-17 01:23:51 -06:00

216 lines
7.3 KiB
Python

#!/usr/bin/env python
import sys, os, gc, glob, random
import types, shutil, json
from datetime import datetime, timezone
from pathlib import Path
from argparse import ArgumentParser as ArgParser, ArgumentError
# The data directory defaults to the current working directory when it holds a
# ".mww-data-dir" marker file, and to "/data" otherwise.
if os.path.exists(".mww-data-dir"):
    default_data_dir = os.getcwd()
else:
    default_data_dir = "/data"

parser = ArgParser(exit_on_error=False)
parser.add_argument(
    "--data-dir",
    type=str,
    required=False,
    default=default_data_dir,
    help=f"Data directory. Default: {default_data_dir}",
)
# Optional directory overrides. Their defaults are filled in after parsing,
# because they depend on other (resolved) directories.
_optional_dir_options = (
    ("--input-dir", "Sample input directory. Default: <data-dir>/work/wake_word_samples"),
    ("--output-dir", "Sample output directory. Default: <input-dir>_augmented"),
    ("--mit-rirs-16k-dir", "MIT RIR input directory. Default: <data-dir>/training_datasets/mit_rirs_16k"),
    ("--fma-16k-dir", "FMA input directory. Default: <data-dir>/training_datasets/fma_16k"),
    ("--audioset-16k-dir", "Audioset input directory. Default: <data-dir>/training_datasets/audioset_16k"),
)
for _flag, _help in _optional_dir_options:
    parser.add_argument(_flag, type=str, help=_help, required=False)
del _flag, _help

try:
    args = parser.parse_args()
except ArgumentError:
    # With exit_on_error=False some parse failures surface as ArgumentError;
    # show the usage text and exit non-zero rather than printing a traceback.
    parser.print_help()
    sys.exit(1)
def _resolve_dir(value, default):
    """Return the real path of a user-supplied directory, or *default* if unset.

    Args:
        value: The command-line value; falsy (None or "") means "not given".
        default: Path to use when *value* is not given. It is returned
            unchanged (not passed through realpath).

    Returns:
        ``os.path.realpath(value)`` when *value* is truthy, else *default*.
    """
    return os.path.realpath(value) if value else default


args.data_dir = os.path.realpath(args.data_dir)
work_dir = args.data_dir + "/work"
# Order matters: the output-dir default is derived from the resolved input dir.
args.input_dir = _resolve_dir(args.input_dir, work_dir + "/wake_word_samples")
args.output_dir = _resolve_dir(args.output_dir, args.input_dir + "_augmented")
args.mit_rirs_16k_dir = _resolve_dir(
    args.mit_rirs_16k_dir, args.data_dir + "/training_datasets/mit_rirs_16k"
)
args.fma_16k_dir = _resolve_dir(
    args.fma_16k_dir, args.data_dir + "/training_datasets/fma_16k"
)
args.audioset_16k_dir = _resolve_dir(
    args.audioset_16k_dir, args.data_dir + "/training_datasets/audioset_16k"
)

out_path = Path(args.output_dir)
# Only create the output directory when its parent already exists. An
# unconditional mkdir() raises FileNotFoundError (with a traceback) when the
# data/work directory or the parent of --output-dir is missing. That
# pre-empts the readable "Directory ... does not exist" error that
# validate_directories() reports further down the script.
if out_path.parent.is_dir():
    out_path.mkdir(exist_ok=True)
def validate_directories(paths):
    """Check that every entry of *paths* is an existing directory.

    Stops at the first offending path and prints an error message for it.

    Args:
        paths: Iterable of filesystem paths (str or path-like).

    Returns:
        True if every path is a directory (vacuously True for an empty
        iterable), otherwise False.
    """
    for path in paths:
        # isdir() rather than exists(): a regular file sitting at one of these
        # locations must not pass as a valid input/output directory.
        if not os.path.isdir(path):
            print(
                f"Error: Directory {path} does not exist or is not a directory. "
                "Please ensure preprocessing is complete."
            )
            return False
    return True
# Every directory the script reads from or writes to must exist up front.
paths = [
    work_dir,
    args.input_dir,
    args.output_dir,
    args.mit_rirs_16k_dir,
    args.fma_16k_dir,
    args.audioset_16k_dir,
]
if not validate_directories(paths):
    parser.print_help()
    sys.exit(1)

# glob.escape() makes metacharacters in the directory name ("[", "*", "?")
# match literally; only the "*.wav" part is treated as a pattern.
files = glob.glob(os.path.join(glob.escape(args.input_dir), "*.wav"))
if not files:
    # Name the directory actually searched (it may come from --input-dir).
    raise RuntimeError(f"❌ No WAVs in {args.input_dir}.")
max_samples = len(files)
print(f"\n===== Augmenting {max_samples} wake word samples =====")
print(" Initializing libraries")
# These are set before `import tensorflow` below so they are already in the
# environment when TensorFlow initializes.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"
os.environ["TF_GPU_ALLOCATOR"] = "cuda_malloc_async"
os.environ["TF_XLA_FLAGS"] = "--tf_xla_auto_jit=0"
os.environ["NVIDIA_TF32_OVERRIDE"] = "1"
os.environ["TF_CUDNN_WORKSPACE_LIMIT_IN_MB"] = "512"
os.environ["GLOG_minloglevel"] = "9"
os.environ["GRPC_VERBOSITY"] = "ERROR"

print(" Loading Tensorflow")
import tensorflow as tf

print(" GPU memory config")
gpus = tf.config.list_physical_devices("GPU")
# Per-device memory growth (belt + suspenders on top of
# TF_FORCE_GPU_ALLOW_GROWTH above).
for g in gpus:
    try:
        tf.config.experimental.set_memory_growth(g, True)
    except (RuntimeError, ValueError) as err:
        # Best effort: set_memory_growth raises RuntimeError when the runtime
        # is already initialized and ValueError for an invalid device.
        # Report it rather than silently ignoring it, and keep going.
        print(f" Could not enable memory growth on {g.name}: {err}")
print(f" GPUs: {gpus}")
gc.collect()
import numpy as np
import librosa
from mmap_ninja.ragged import RaggedMmap
from microwakeword.audio.augmentation import Augmentation
from microwakeword.audio.clips import Clips
from microwakeword.audio.spectrograms import SpectrogramGeneration
from microwakeword.audio.audio_utils import save_clip
START_TIME = datetime.now(timezone.utc).replace(microsecond=0)

# Augmentation sources: impulse-response clips (MIT RIRs) and
# background-noise clips (FMA and AudioSet).
impulse_paths = [args.mit_rirs_16k_dir]
background_paths = [args.fma_16k_dir, args.audioset_16k_dir]

clip_options = dict(
    input_directory=args.input_dir,
    file_pattern="*.wav",
    max_clip_duration_s=5,
    remove_silence=True,
    random_split_seed=10,
    split_count=0.1,
)
clips = Clips(**clip_options)

# Presumably the chance that each named transform is applied to a clip.
# NOTE(review): confirm against microwakeword.audio.augmentation.
augmentation_probabilities = {
    "SevenBandParametricEQ": 0.1,
    "TanhDistortion": 0.05,
    "PitchShift": 0.15,
    "BandStopFilter": 0.1,
    "AddColorNoise": 0.1,
    "AddBackgroundNoise": 0.7,
    "Gain": 0.8,
    "RIR": 0.7,
}
augmenter = Augmentation(
    augmentation_duration_s=3.2,
    augmentation_probabilities=augmentation_probabilities,
    impulse_paths=impulse_paths,
    background_paths=background_paths,
    background_min_snr_db=5,
    background_max_snr_db=10,
    min_jitter_s=0.2,
    max_jitter_s=0.3,
)
# Augment samples and save the training, validation, and testing sets.
def audio_generator_from_wavs(self, split="train", repeat=1):
    """Yield 16 kHz mono float32 waveforms for one split of the input WAVs.

    Replacement for ``Clips.audio_generator``; it is bound onto ``clips``
    right after this definition. WAV files in ``args.input_dir`` are
    sorted, shuffled with ``random.Random(10)`` and partitioned
    deterministically:

    * validation: ``max(1, int(0.10 * n))`` files
    * test: up to ``max(1, int(0.10 * n))`` files (empty when n == 1)
    * train: the remainder

    Because the seed is fixed, every call sees the same partition.

    NOTE(review): files are loaded directly with librosa, so the
    ``remove_silence`` / ``max_clip_duration_s`` options given to ``Clips``
    are not applied here. Whether this split matches Clips' own split is
    not established. Confirm both are intended.

    Args:
        self: The ``Clips`` instance this is bound to (unused).
        split: "train", "validation" or "test"; any other value yields nothing.
        repeat: Number of passes over the split; values below 1 act as 1.

    Yields:
        1-D float32 numpy arrays, one per file per pass.

    Raises:
        RuntimeError: If the input directory contains no WAV files.
    """
    # glob.escape() so metacharacters in the directory name match literally.
    wav_files = sorted(glob.glob(os.path.join(glob.escape(args.input_dir), "*.wav")))
    if not wav_files:
        raise RuntimeError(f"❌ No WAVs in {args.input_dir}.")

    # Sorting before a fixed-seed shuffle makes the order reproducible, so
    # separate calls for different splits never overlap.
    rng = random.Random(10)
    shuffled = list(wav_files)
    rng.shuffle(shuffled)

    n = len(shuffled)
    n_val = max(1, int(0.10 * n))
    n_test = max(1, int(0.10 * n))
    n_train = max(0, n - n_val - n_test)
    splits = {
        "train": shuffled[:n_train],
        "validation": shuffled[n_train:n_train + n_val],
        "test": shuffled[n_train + n_val:],
    }

    file_list = splits.get(split, [])
    if not file_list:
        return  # nothing to yield

    for _ in range(max(1, int(repeat))):
        for path in file_list:
            audio, _sr = librosa.load(path, sr=16000, mono=True)
            yield audio.astype(np.float32, copy=False)


# Bind the patched generator to the existing `clips` instance.
clips.audio_generator = types.MethodType(audio_generator_from_wavs, clips)
# ---- Split config ----
# Output split -> generator split name, number of passes over that split,
# and the slide_frames value handed to SpectrogramGeneration.
split_cfg = {
    "training": {"name": "train", "repetition": 2, "slide_frames": 10},
    "validation": {"name": "validation", "repetition": 1, "slide_frames": 10},
    "testing": {"name": "test", "repetition": 1, "slide_frames": 1},
}

# ---- Generate features ----
# Each split is written to <output-dir>/<split>/wakeword_mmap.
for split_label, settings in split_cfg.items():
    split_dir = out_path / split_label
    split_dir.mkdir(parents=True, exist_ok=True)
    print(f" Augmenting {split_label}")

    print(" Generating spectrograms")
    generator = SpectrogramGeneration(
        clips=clips,  # backed by the patched WAV loader
        augmenter=augmenter,
        slide_frames=settings["slide_frames"],
        step_ms=10,
    )

    print(" Generating files")
    features = generator.spectrogram_generator(
        split=settings["name"], repeat=settings["repetition"]
    )
    RaggedMmap.from_generator(
        out_dir=str(split_dir / "wakeword_mmap"),
        sample_generator=features,
        batch_size=100,
        verbose=False,
    )
    print(f" {split_label} augmentation complete")
# Summary banner: sample count right-aligned to 50 columns, then the
# wall-clock time since START_TIME (whole seconds).
END_TIME = datetime.now(timezone.utc).replace(microsecond=0)
elapsed = END_TIME - START_TIME
banner = "=" * 80
summary = f"Augmented {max_samples} wake word samples."
print("\n" + banner)
print(f"{summary:>50s} Elapsed time: {str(elapsed)}")
print(banner + "\n")