mirror of
https://github.com/TaterTotterson/microWakeWord-Trainer-Nvidia-Docker.git
synced 2026-06-12 20:10:19 -06:00
216 lines
7.3 KiB
Python
216 lines
7.3 KiB
Python
#!/usr/bin/env python
|
|
|
|
import sys, os, gc, glob, random
|
|
import types, shutil, json
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
from argparse import ArgumentParser as ArgParser, ArgumentError
|
|
|
|
# Default data directory: the current working directory when it contains a
# ".mww-data-dir" entry, otherwise "/data".
default_data_dir = os.getcwd() if os.path.exists(".mww-data-dir") else "/data"

# exit_on_error=False makes parse_args raise ArgumentError for bad values so we
# can report the problem and print the full help text ourselves.
parser = ArgParser(exit_on_error=False)
parser.add_argument("--data-dir", type=str, help=f"Data directory. Default: {default_data_dir}", required=False, default=default_data_dir)
parser.add_argument("--input-dir", type=str, help="Sample input directory. Default: <data-dir>/work/wake_word_samples", required=False)
parser.add_argument("--output-dir", type=str, help="Sample output directory. Default: <input-dir>_augmented", required=False)
parser.add_argument("--mit-rirs-16k-dir", type=str, help="MIT RIR input directory. Default: <data-dir>/training_datasets/mit_rirs_16k", required=False)
parser.add_argument("--fma-16k-dir", type=str, help="FMA input directory. Default: <data-dir>/training_datasets/fma_16k", required=False)
parser.add_argument("--audioset-16k-dir", type=str, help="Audioset input directory. Default: <data-dir>/training_datasets/audioset_16k", required=False)

try:
    args = parser.parse_args()
except ArgumentError as err:
    # Show what was wrong before the usage text; previously the message was dropped.
    print(f"Error: {err}", file=sys.stderr)
    parser.print_help()
    sys.exit(1)
|
|
|
|
def _resolve_dir(value, fallback):
    """Return the real path of an explicitly supplied directory, else *fallback*.

    An empty or missing value counts as "not supplied". The fallback is returned
    unchanged; it is expected to be built from already-resolved paths.
    """
    return os.path.realpath(value) if value else fallback


# Normalise every directory option. Explicit values become absolute real paths;
# omitted ones are derived from the (resolved) data directory.
args.data_dir = os.path.realpath(args.data_dir)
work_dir = args.data_dir + "/work"
datasets_dir = args.data_dir + "/training_datasets"

# input_dir must be resolved before output_dir, whose default is derived from it.
args.input_dir = _resolve_dir(args.input_dir, work_dir + "/wake_word_samples")
args.output_dir = _resolve_dir(args.output_dir, args.input_dir + "_augmented")
args.mit_rirs_16k_dir = _resolve_dir(args.mit_rirs_16k_dir, datasets_dir + "/mit_rirs_16k")
args.fma_16k_dir = _resolve_dir(args.fma_16k_dir, datasets_dir + "/fma_16k")
args.audioset_16k_dir = _resolve_dir(args.audioset_16k_dir, datasets_dir + "/audioset_16k")
|
|
|
|
# Output directory for the augmented feature sets.
# Only create it when its parent already exists. If --data-dir / --input-dir /
# --output-dir point somewhere wrong, the directory check that follows reports
# it cleanly (with usage help) instead of mkdir failing with a raw
# FileNotFoundError traceback.
out_path = Path(args.output_dir)
if out_path.parent.is_dir():
    out_path.mkdir(exist_ok=True)
|
|
|
|
def validate_directories(paths):
    """Check that every entry in *paths* is an existing directory.

    One error line is printed per offending path, so every problem is shown in
    a single run rather than one at a time.

    Args:
        paths: Iterable of directory paths (str or path-like).

    Returns:
        True if all paths are existing directories, False otherwise.
    """
    ok = True
    for path in paths:
        if os.path.isdir(path):
            continue
        ok = False
        if os.path.exists(path):
            # Something is there, but it is not a directory (e.g. a stray file).
            print(f"Error: {path} exists but is not a directory.")
        else:
            print(f"Error: Directory {path} does not exist. Please ensure preprocessing is complete.")
    return ok
|
|
|
|
# Abort with usage help if any required directory is missing.
paths = [work_dir, args.input_dir, args.output_dir, args.mit_rirs_16k_dir, args.fma_16k_dir, args.audioset_16k_dir]
if not validate_directories(paths):
    parser.print_help()
    sys.exit(1)

# Count the input samples up front; max_samples feeds the banner below and the
# final summary. glob.escape() stops characters such as "[" or "*" in the
# directory path from being treated as wildcards.
files = glob.glob(os.path.join(glob.escape(args.input_dir), "*.wav"))
if not files:
    raise RuntimeError(f"❌ No WAVs in {args.input_dir}.")
max_samples = len(files)

print(f"\n===== Augmenting {max_samples} wake word samples =====")
|
|
|
|
print(" Initializing libraries")

# TensorFlow / CUDA logging and memory settings. These are set before
# `import tensorflow` below so they are in place when TensorFlow initialises.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"
os.environ["TF_GPU_ALLOCATOR"] = "cuda_malloc_async"
os.environ["TF_XLA_FLAGS"] = "--tf_xla_auto_jit=0"
os.environ["NVIDIA_TF32_OVERRIDE"] = "1"
os.environ["TF_CUDNN_WORKSPACE_LIMIT_IN_MB"] = "512"
os.environ["GLOG_minloglevel"] = "9"
os.environ["GRPC_VERBOSITY"] = "ERROR"

print(" Loading Tensorflow")
import tensorflow as tf

print(" GPU memory config")
# Per-device memory growth (belt + suspenders on top of TF_FORCE_GPU_ALLOW_GROWTH).
gpus = tf.config.list_physical_devices("GPU")
for g in gpus:
    try:
        tf.config.experimental.set_memory_growth(g, True)
    except (RuntimeError, ValueError) as err:
        # Best effort: set_memory_growth raises these for an invalid device or
        # an already-initialised runtime. Report it and keep going.
        print(f" Could not enable memory growth on {g.name}: {err}")
print(f" GPUs: {gpus}")
gc.collect()
|
|
|
|
import numpy as np
|
|
import librosa
|
|
from mmap_ninja.ragged import RaggedMmap
|
|
from microwakeword.audio.augmentation import Augmentation
|
|
from microwakeword.audio.clips import Clips
|
|
from microwakeword.audio.spectrograms import SpectrogramGeneration
|
|
from microwakeword.audio.audio_utils import save_clip
|
|
|
|
START_TIME = datetime.now(timezone.utc).replace(microsecond=0)

# Augmentation sources: room impulse responses (used for the "RIR" transform)
# and the background-noise datasets (FMA and AudioSet).
impulse_paths = [args.mit_rirs_16k_dir]
background_paths = [args.fma_16k_dir, args.audioset_16k_dir]

# Clip source for the wake-word WAVs.
# NOTE(review): clips.audio_generator is replaced further down by a plain WAV
# loader, so any per-clip processing Clips applies inside its own generator
# (e.g. remove_silence / max_clip_duration_s) may not take effect. Confirm.
clip_settings = {
    "input_directory": args.input_dir,
    "file_pattern": "*.wav",
    "max_clip_duration_s": 5,
    "remove_silence": True,
    "random_split_seed": 10,
    "split_count": 0.1,
}
clips = Clips(**clip_settings)

# Per-transform probabilities handed to Augmentation (insertion order preserved).
augmentation_probabilities = {
    "SevenBandParametricEQ": 0.1,
    "TanhDistortion": 0.05,
    "PitchShift": 0.15,
    "BandStopFilter": 0.1,
    "AddColorNoise": 0.1,
    "AddBackgroundNoise": 0.7,
    "Gain": 0.8,
    "RIR": 0.7,
}
augmenter = Augmentation(
    augmentation_duration_s=3.2,
    augmentation_probabilities=augmentation_probabilities,
    impulse_paths=impulse_paths,
    background_paths=background_paths,
    background_min_snr_db=5,
    background_max_snr_db=10,
    min_jitter_s=0.2,
    max_jitter_s=0.3,
)
|
|
|
|
# Augment samples and save the training, validation, and testing sets.
|
|
|
|
def audio_generator_from_wavs(self, split="train", repeat=1):
|
|
"""
|
|
Yield 1-D float32 arrays loaded via librosa from input_dir/*.wav.
|
|
Deterministic 80/10/10 split with seed 10 to mirror original Clips behavior.
|
|
"""
|
|
files = sorted(glob.glob(args.input_dir + "/*.wav"))
|
|
if not files:
|
|
raise RuntimeError("❌ No WAVs in wake_word_samples.")
|
|
|
|
rng = random.Random(10) # deterministic shuffling like Clips(random_split_seed=10)
|
|
files_shuf = files[:]
|
|
rng.shuffle(files_shuf)
|
|
|
|
n = len(files_shuf)
|
|
n_val = max(1, int(0.10 * n))
|
|
n_test = max(1, int(0.10 * n))
|
|
n_train = max(0, n - n_val - n_test)
|
|
splits = {
|
|
"train": files_shuf[:n_train],
|
|
"validation": files_shuf[n_train:n_train + n_val],
|
|
"test": files_shuf[n_train + n_val:],
|
|
}
|
|
file_list = splits.get(split, [])
|
|
if not file_list:
|
|
return # nothing to yield
|
|
|
|
for _ in range(max(1, int(repeat))):
|
|
for p in file_list:
|
|
y, sr = librosa.load(p, sr=16000, mono=True)
|
|
yield y.astype(np.float32, copy=False)
|
|
|
|
# Route this Clips instance's audio through audio_generator_from_wavs. It is
# bound as a method so it receives `clips` as `self`; the Clips class itself
# is left untouched.
setattr(clips, "audio_generator", types.MethodType(audio_generator_from_wavs, clips))
|
|
|
|
# ---- Per-split settings ----
# Keyed by output sub-folder. Each entry gives the split name passed to the
# generator, the `repeat` count, and the `slide_frames` used by
# SpectrogramGeneration.
split_cfg = {
    "training": {"name": "train", "repetition": 2, "slide_frames": 10},
    "validation": {"name": "validation", "repetition": 1, "slide_frames": 10},
    "testing": {"name": "test", "repetition": 1, "slide_frames": 1},
}

# ---- Generate features: one RaggedMmap per split under out_path/<split>/wakeword_mmap ----
for split, cfg in split_cfg.items():
    split_dir = out_path / split
    split_dir.mkdir(parents=True, exist_ok=True)
    print(f" Augmenting {split}")

    print(" Generating spectrograms")
    # `clips` yields audio via the patched WAV loader; `augmenter` is built above.
    spectros = SpectrogramGeneration(
        clips=clips,
        augmenter=augmenter,
        slide_frames=cfg["slide_frames"],
        step_ms=10,
    )

    print(" Generating files")
    features = spectros.spectrogram_generator(split=cfg["name"], repeat=cfg["repetition"])
    RaggedMmap.from_generator(
        out_dir=str(split_dir / "wakeword_mmap"),
        sample_generator=features,
        batch_size=100,
        verbose=False,
    )
    print(f" {split} augmentation complete")
|
|
|
|
# Final banner: number of input samples and wall-clock time since START_TIME
# (both timestamps are UTC, truncated to whole seconds).
END_TIME = datetime.now(timezone.utc).replace(microsecond=0)
et = END_TIME - START_TIME
rule = "=" * 80
msg = f"Augmented {max_samples} wake word samples."

print(f"\n{rule}")
print(f"{msg:>50s} Elapsed time: {et!s}")
print(f"{rule}\n")
|