Files
microWakeWord-Trainer-Nvidi…/cli/wake_word_sample_augmenter
2026-03-09 19:48:35 -05:00

274 lines
9.2 KiB
Python
Raw Blame History

This file contains invisible Unicode characters
This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python
import sys, os, gc, glob, random
import types
from datetime import datetime, timezone
from pathlib import Path
from argparse import ArgumentParser as ArgParser, ArgumentError
# Use the current directory as the data dir when it carries the
# ".mww-data-dir" marker file; otherwise fall back to /data.
default_data_dir = os.getcwd() if os.path.exists(".mww-data-dir") else "/data"
# exit_on_error=False lets argparse raise ArgumentError for bad option values
# so the failure can be reported together with the full help text below.
parser = ArgParser(exit_on_error=False)
parser.add_argument("--data-dir", type=str, help=f"Data directory. Default: {default_data_dir}", required=False, default=default_data_dir)
# Wake word (TTS/generated) inputs/outputs
parser.add_argument("--input-dir", type=str, help="Wake word input dir. Default: <data-dir>/work/wake_word_samples", required=False)
parser.add_argument("--output-dir", type=str, help="Wake word output dir. Default: <input-dir>_augmented", required=False)
# Personal inputs/outputs (NEW)
parser.add_argument("--personal-dir", type=str, help="Personal WAV dir. Default: <data-dir>/personal_samples", required=False)
parser.add_argument("--personal-output-dir", type=str, help="Personal features output dir. Default: <data-dir>/work/personal_augmented_features", required=False)
# Dataset dirs
parser.add_argument("--mit-rirs-16k-dir", type=str, help="MIT RIR input directory. Default: <data-dir>/training_datasets/mit_rirs_16k", required=False)
parser.add_argument("--fma-16k-dir", type=str, help="FMA input directory. Default: <data-dir>/training_datasets/fma_16k", required=False)
parser.add_argument("--audioset-16k-dir", type=str, help="Audioset input directory. Default: <data-dir>/training_datasets/audioset_16k", required=False)
parser.add_argument("--wham-16k-dir", type=str, help="WHAM input directory. Default: <data-dir>/training_datasets/wham_16k", required=False)
parser.add_argument("--chime-16k-dir", type=str, help="CHiME input directory. Default: <data-dir>/training_datasets/chime_16k", required=False)
try:
    args = parser.parse_args()
except ArgumentError as err:
    # Say what was wrong before dumping usage; previously the error text was
    # discarded and the user only saw the help output.
    print(f"Error: {err}", file=sys.stderr)
    parser.print_help()
    sys.exit(1)
def _resolve_dir(given, fallback):
    """Return the real path of *given* when it is set, else *fallback* unchanged."""
    if given:
        return os.path.realpath(given)
    return fallback


args.data_dir = os.path.realpath(args.data_dir)
work_dir = os.path.join(args.data_dir, "work")
_datasets_root = os.path.join(args.data_dir, "training_datasets")

# Wake word defaults (the output default is derived from the resolved input dir)
args.input_dir = _resolve_dir(args.input_dir, os.path.join(work_dir, "wake_word_samples"))
args.output_dir = _resolve_dir(args.output_dir, args.input_dir + "_augmented")

# Personal defaults (NEW)
args.personal_dir = _resolve_dir(args.personal_dir, os.path.join(args.data_dir, "personal_samples"))
args.personal_output_dir = _resolve_dir(
    args.personal_output_dir,
    os.path.join(work_dir, "personal_augmented_features"),
)

# Dataset defaults: each option falls back to <data-dir>/training_datasets/<name>
for _dataset in ("mit_rirs_16k", "fma_16k", "audioset_16k", "wham_16k", "chime_16k"):
    _option = f"{_dataset}_dir"
    setattr(
        args,
        _option,
        _resolve_dir(getattr(args, _option), os.path.join(_datasets_root, _dataset)),
    )
def validate_directories(paths):
    """Check that every entry in *paths* is an existing directory.

    Every failing path is reported on stdout (not only the first one), so a
    single run shows everything that still needs to be prepared. A path that
    exists but is a regular file is rejected as well.

    Args:
        paths: Iterable of filesystem paths to check.

    Returns:
        True when all paths are existing directories (also for an empty
        iterable), otherwise False.
    """
    all_ok = True
    for path in paths:
        if os.path.isdir(path):
            continue
        all_ok = False
        if os.path.exists(path):
            print(f"Error: {path} exists but is not a directory.")
        else:
            print(f"Error: Directory {path} does not exist. Please ensure preprocessing is complete.")
    return all_ok
# Directories that must already exist before augmentation can start:
# the work dir, the wake word input dir, then each dataset dir.
required = [work_dir, args.input_dir]
required.extend(
    getattr(args, f"{dataset}_16k_dir")
    for dataset in ("mit_rirs", "wham", "chime", "fma", "audioset")
)

directories_ok = validate_directories(required)
if not directories_ok:
    # Show usage so the user can see which options point at these directories.
    parser.print_help()
    sys.exit(1)
# -------------------- TF + libs --------------------
print(" Initializing libraries")
# Logging / allocator settings are placed in the environment before
# TensorFlow is imported so they are in place when the library initialises.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"
os.environ["TF_GPU_ALLOCATOR"] = "cuda_malloc_async"
os.environ["TF_XLA_FLAGS"] = "--tf_xla_auto_jit=0"
os.environ["NVIDIA_TF32_OVERRIDE"] = "1"
os.environ["TF_CUDNN_WORKSPACE_LIMIT_IN_MB"] = "512"
os.environ["GLOG_minloglevel"] = "9"
os.environ["GRPC_VERBOSITY"] = "ERROR"
print(" Loading Tensorflow")
import tensorflow as tf
print(" GPU memory config")
for g in tf.config.list_physical_devices("GPU"):
    try:
        tf.config.experimental.set_memory_growth(g, True)
    except Exception as exc:
        # Still best-effort (the run continues), but no longer silent: a
        # failure here can explain later out-of-memory behaviour.
        print(f" Warning: could not enable memory growth for {g}: {exc}")
print(f" GPUs: {tf.config.list_physical_devices('GPU')}")
gc.collect()
import numpy as np
import librosa
from mmap_ninja.ragged import RaggedMmap
from microwakeword.audio.augmentation import Augmentation
from microwakeword.audio.clips import Clips
from microwakeword.audio.spectrograms import SpectrogramGeneration
# Wall-clock start (UTC, whole seconds) used for the elapsed-time report.
START_TIME = datetime.now(tz=timezone.utc).replace(microsecond=0)

# Impulse-response and background-audio directories handed to the augmenter.
impulse_paths = [args.mit_rirs_16k_dir]
background_paths = [args.wham_16k_dir, args.chime_16k_dir, args.fma_16k_dir, args.audioset_16k_dir]

augmenter = Augmentation(
    augmentation_duration_s=3.2,
    # Per-augmentation probabilities, keyed by augmentation name.
    augmentation_probabilities=dict(
        SevenBandParametricEQ=0.1,
        TanhDistortion=0.05,
        PitchShift=0.15,
        BandStopFilter=0.1,
        AddColorNoise=0.1,
        AddBackgroundNoise=0.7,
        Gain=0.8,
        RIR=0.7,
    ),
    impulse_paths=impulse_paths,
    background_paths=background_paths,
    background_min_snr_db=5,
    background_max_snr_db=10,
    min_jitter_s=0.2,
    max_jitter_s=0.3,
)

# Output sub-directory name -> generator split name, repeat count and
# spectrogram slide_frames used when producing that split.
split_cfg = {
    "training": dict(name="train", repetition=2, slide_frames=10),
    "validation": dict(name="validation", repetition=1, slide_frames=10),
    "testing": dict(name="test", repetition=1, slide_frames=1),
}
def bind_wav_generator(clips_obj: "Clips", wav_dir: str):
    """Replace ``clips_obj.audio_generator`` with one that reads WAVs from *wav_dir*.

    The replacement lists ``*.wav`` files in *wav_dir*, sorts them, shuffles
    them with a fixed seed (10) and partitions them: 10% (at least one file)
    for validation, 10% (at least one file) for test, and whatever remains
    for train. Each selected file is loaded with librosa as 16 kHz mono and
    yielded as float32 samples.

    Args:
        clips_obj: Object whose ``audio_generator`` attribute is overwritten
            with the bound replacement.
        wav_dir: Directory that is searched (non-recursively) for ``*.wav``.
    """
    wav_pattern = os.path.join(wav_dir, "*.wav")

    def _wav_stream(self, split="train", repeat=1):
        ordered = sorted(glob.glob(wav_pattern))
        if len(ordered) == 0:
            return

        # Fixed seed keeps the split identical across calls and runs.
        shuffled = list(ordered)
        random.Random(10).shuffle(shuffled)

        total = len(shuffled)
        val_count = max(1, int(0.10 * total))
        test_count = max(1, int(0.10 * total))
        train_count = max(0, total - val_count - test_count)
        val_end = train_count + val_count

        windows = {
            "train": slice(0, train_count),
            "validation": slice(train_count, val_end),
            "test": slice(val_end, None),
        }
        window = windows.get(split)
        chosen = shuffled[window] if window is not None else []
        if not chosen:
            # Unknown split name or an empty partition: yield nothing.
            return

        passes = max(1, int(repeat))
        for _pass in range(passes):
            for wav_path in chosen:
                samples, _rate = librosa.load(wav_path, sr=16000, mono=True)
                yield samples.astype(np.float32, copy=False)

    clips_obj.audio_generator = types.MethodType(_wav_stream, clips_obj)
def generate_feature_set(input_wav_dir: str, out_root_dir: str, label: str):
    """Augment the WAVs in *input_wav_dir* and write spectrogram features.

    For every entry in ``split_cfg`` a ``<out_root_dir>/<split>/wakeword_mmap``
    RaggedMmap is produced from the augmented spectrogram generator.

    Args:
        input_wav_dir: Directory searched (non-recursively) for ``*.wav``.
        out_root_dir: Root directory for the per-split feature outputs;
            created if missing.
        label: Short tag used only in progress messages.

    Returns:
        False when no WAV files were found (nothing is created), True after
        all splits have been written.
    """
    wav_count = len(glob.glob(os.path.join(input_wav_dir, "*.wav")))
    if wav_count == 0:
        print(f" No WAVs found for {label} in: {input_wav_dir} (skipping)")
        return False

    print(f"\n===== Augmenting {wav_count} wake word samples ({label}) =====")

    wav_clips = Clips(
        input_directory=input_wav_dir,
        file_pattern="*.wav",
        max_clip_duration_s=5,
        remove_silence=True,
        random_split_seed=10,
        split_count=0.1,
    )
    # Swap in the direct-from-WAV generator so the splits come from this dir.
    bind_wav_generator(wav_clips, input_wav_dir)

    features_root = Path(out_root_dir)
    features_root.mkdir(parents=True, exist_ok=True)

    for split_name, split_opts in split_cfg.items():
        split_dir = features_root / split_name
        split_dir.mkdir(parents=True, exist_ok=True)

        print(f" Augmenting {split_name} ({label})")
        print(" Sit tight this can take awhile ...")
        print()

        spectrogram_source = SpectrogramGeneration(
            clips=wav_clips,
            augmenter=augmenter,
            slide_frames=split_opts["slide_frames"],
            step_ms=10,
        )
        feature_stream = spectrogram_source.spectrogram_generator(
            split=split_opts["name"],
            repeat=split_opts["repetition"],
        )
        RaggedMmap.from_generator(
            out_dir=str(split_dir / "wakeword_mmap"),
            sample_generator=feature_stream,
            batch_size=100,
            verbose=False,
        )
        print(f" {split_name} augmentation complete ({label})")

    print(f"\n✅ Features ready: {out_root_dir}/*/wakeword_mmap\n")
    return True
# Wake word generated/TTS features (existing behavior), then personal
# features (NEW); each job is (input WAV dir, output features dir, label).
feature_jobs = (
    (args.input_dir, args.output_dir, "generated"),
    (args.personal_dir, args.personal_output_dir, "personal"),
)
for wav_dir, features_dir, tag in feature_jobs:
    generate_feature_set(wav_dir, features_dir, tag)

# Report total elapsed wall-clock time at whole-second resolution.
END_TIME = datetime.now(tz=timezone.utc).replace(microsecond=0)
et = END_TIME - START_TIME
rule = "=" * 80
print(f"\n{rule}")
print(f"{'Augmentation completed.':>50s} Elapsed time: {et!s}")
print(f"{rule}\n")