mirror of
https://github.com/TaterTotterson/microWakeWord-Trainer-Nvidia-Docker.git
synced 2026-06-12 20:10:19 -06:00
personal samples
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
import sys, os, gc, glob, random
|
||||
import types, shutil, json
|
||||
import types
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from argparse import ArgumentParser as ArgParser, ArgumentError
|
||||
@@ -9,12 +9,20 @@ from argparse import ArgumentParser as ArgParser, ArgumentError
|
||||
default_data_dir = os.getcwd() if os.path.exists(".mww-data-dir") else "/data"
|
||||
|
||||
parser = ArgParser(exit_on_error=False)
|
||||
parser.add_argument("--data-dir", type=str, help=f"Data directory. Default: {default_data_dir}", required=False, default=default_data_dir)
|
||||
parser.add_argument("--input-dir", type=str, help="Sample input directory. Default: <data-dir>/work/wake_word_samples", required=False)
|
||||
parser.add_argument("--output-dir", type=str, help="Sample output directory. Default: <input-dir>_augmented", required=False)
|
||||
parser.add_argument("--mit-rirs-16k-dir", type=str, help="MIT RIR input directory. Default: <data-dir>/training_datasets/mit_rirs_16k", required=False)
|
||||
parser.add_argument("--fma-16k-dir", type=str, help="FMA input directory. Default: <data-dir>/training_datasets/fma_16k", required=False)
|
||||
parser.add_argument("--audioset-16k-dir", type=str, help="Audioset input directory. Default: <data-dir>/training_datasets/audioset_16k", required=False)
|
||||
parser.add_argument("--data-dir", type=str, help=f"Data directory. Default: {default_data_dir}", required=False, default=default_data_dir)
|
||||
|
||||
# Wake word (TTS/generated) inputs/outputs
|
||||
parser.add_argument("--input-dir", type=str, help="Wake word input dir. Default: <data-dir>/work/wake_word_samples", required=False)
|
||||
parser.add_argument("--output-dir", type=str, help="Wake word output dir. Default: <input-dir>_augmented", required=False)
|
||||
|
||||
# Personal inputs/outputs (NEW)
|
||||
parser.add_argument("--personal-dir", type=str, help="Personal WAV dir. Default: <data-dir>/personal_samples", required=False)
|
||||
parser.add_argument("--personal-output-dir", type=str, help="Personal features output dir. Default: <data-dir>/work/personal_augmented_features", required=False)
|
||||
|
||||
# Dataset dirs
|
||||
parser.add_argument("--mit-rirs-16k-dir", type=str, help="MIT RIR input directory. Default: <data-dir>/training_datasets/mit_rirs_16k", required=False)
|
||||
parser.add_argument("--fma-16k-dir", type=str, help="FMA input directory. Default: <data-dir>/training_datasets/fma_16k", required=False)
|
||||
parser.add_argument("--audioset-16k-dir", type=str, help="Audioset input directory. Default: <data-dir>/training_datasets/audioset_16k", required=False)
|
||||
|
||||
try:
|
||||
args = parser.parse_args()
|
||||
@@ -23,10 +31,11 @@ except ArgumentError:
|
||||
sys.exit(1)
|
||||
|
||||
args.data_dir = os.path.realpath(args.data_dir)
|
||||
work_dir = args.data_dir + "/work"
|
||||
work_dir = os.path.join(args.data_dir, "work")
|
||||
|
||||
# Wake word defaults
|
||||
if not args.input_dir:
|
||||
args.input_dir = work_dir + "/wake_word_samples"
|
||||
args.input_dir = os.path.join(work_dir, "wake_word_samples")
|
||||
else:
|
||||
args.input_dir = os.path.realpath(args.input_dir)
|
||||
|
||||
@@ -35,24 +44,33 @@ if not args.output_dir:
|
||||
else:
|
||||
args.output_dir = os.path.realpath(args.output_dir)
|
||||
|
||||
# Personal defaults (NEW)
|
||||
if not args.personal_dir:
|
||||
args.personal_dir = os.path.join(args.data_dir, "personal_samples")
|
||||
else:
|
||||
args.personal_dir = os.path.realpath(args.personal_dir)
|
||||
|
||||
if not args.personal_output_dir:
|
||||
args.personal_output_dir = os.path.join(work_dir, "personal_augmented_features")
|
||||
else:
|
||||
args.personal_output_dir = os.path.realpath(args.personal_output_dir)
|
||||
|
||||
# Dataset defaults
|
||||
if not args.mit_rirs_16k_dir:
|
||||
args.mit_rirs_16k_dir = args.data_dir + "/training_datasets/mit_rirs_16k"
|
||||
args.mit_rirs_16k_dir = os.path.join(args.data_dir, "training_datasets", "mit_rirs_16k")
|
||||
else:
|
||||
args.mit_rirs_16k_dir = os.path.realpath(args.mit_rirs_16k_dir)
|
||||
|
||||
if not args.fma_16k_dir:
|
||||
args.fma_16k_dir = args.data_dir + "/training_datasets/fma_16k"
|
||||
args.fma_16k_dir = os.path.join(args.data_dir, "training_datasets", "fma_16k")
|
||||
else:
|
||||
args.fma_16k_dir = os.path.realpath(args.fma_16k_dir)
|
||||
|
||||
if not args.audioset_16k_dir:
|
||||
args.audioset_16k_dir = args.data_dir + "/training_datasets/audioset_16k"
|
||||
args.audioset_16k_dir = os.path.join(args.data_dir, "training_datasets", "audioset_16k")
|
||||
else:
|
||||
args.audioset_16k_dir = os.path.realpath(args.audioset_16k_dir)
|
||||
|
||||
out_path = Path(args.output_dir)
|
||||
out_path.mkdir(exist_ok=True)
|
||||
|
||||
def validate_directories(paths):
|
||||
for path in paths:
|
||||
if not os.path.exists(path):
|
||||
@@ -60,17 +78,12 @@ def validate_directories(paths):
|
||||
return False
|
||||
return True
|
||||
|
||||
paths = [ work_dir, args.input_dir, args.output_dir, args.mit_rirs_16k_dir, args.fma_16k_dir, args.audioset_16k_dir ]
|
||||
if not validate_directories(paths):
|
||||
required = [work_dir, args.input_dir, args.mit_rirs_16k_dir, args.fma_16k_dir, args.audioset_16k_dir]
|
||||
if not validate_directories(required):
|
||||
parser.print_help()
|
||||
sys.exit(1)
|
||||
|
||||
files = glob.glob(args.input_dir + "/*.wav")
|
||||
if not files:
|
||||
raise RuntimeError("❌ No WAVs in wake_word_samples.")
|
||||
max_samples = len(files)
|
||||
|
||||
print(f"\n===== Augmenting {max_samples} wake word samples =====")
|
||||
# -------------------- TF + libs --------------------
|
||||
print(" Initializing libraries")
|
||||
|
||||
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
||||
@@ -86,7 +99,6 @@ print(" Loading Tensorflow")
|
||||
import tensorflow as tf
|
||||
|
||||
print(" GPU memory config")
|
||||
# Per-device memory growth (belt + suspenders)
|
||||
for g in tf.config.list_physical_devices("GPU"):
|
||||
try:
|
||||
tf.config.experimental.set_memory_growth(g, True)
|
||||
@@ -97,27 +109,15 @@ gc.collect()
|
||||
|
||||
import numpy as np
|
||||
import librosa
|
||||
from tqdm import tqdm
|
||||
from mmap_ninja.ragged import RaggedMmap
|
||||
from microwakeword.audio.augmentation import Augmentation
|
||||
from microwakeword.audio.clips import Clips
|
||||
from microwakeword.audio.spectrograms import SpectrogramGeneration
|
||||
from microwakeword.audio.audio_utils import save_clip
|
||||
|
||||
START_TIME = datetime.now(timezone.utc).replace(microsecond=0)
|
||||
|
||||
# Paths to augmented data
|
||||
impulse_paths = [ args.mit_rirs_16k_dir ]
|
||||
background_paths = [ args.fma_16k_dir, args.audioset_16k_dir ]
|
||||
|
||||
clips = Clips(
|
||||
input_directory=args.input_dir,
|
||||
file_pattern='*.wav',
|
||||
max_clip_duration_s=5,
|
||||
remove_silence=True,
|
||||
random_split_seed=10,
|
||||
split_count=0.1,
|
||||
)
|
||||
impulse_paths = [args.mit_rirs_16k_dir]
|
||||
background_paths = [args.fma_16k_dir, args.audioset_16k_dir]
|
||||
|
||||
augmenter = Augmentation(
|
||||
augmentation_duration_s=3.2,
|
||||
@@ -139,81 +139,107 @@ augmenter = Augmentation(
|
||||
max_jitter_s=0.3,
|
||||
)
|
||||
|
||||
def audio_generator_from_wavs(self, split="train", repeat=1):
|
||||
"""
|
||||
Yield 1-D float32 arrays loaded via librosa from input_dir/*.wav.
|
||||
Deterministic 80/10/10 split with seed 10 to mirror original Clips behavior.
|
||||
"""
|
||||
files = sorted(glob.glob(args.input_dir + "/*.wav"))
|
||||
if not files:
|
||||
raise RuntimeError("❌ No WAVs in wake_word_samples.")
|
||||
|
||||
rng = random.Random(10) # deterministic shuffling like Clips(random_split_seed=10)
|
||||
files_shuf = files[:]
|
||||
rng.shuffle(files_shuf)
|
||||
|
||||
n = len(files_shuf)
|
||||
n_val = max(1, int(0.10 * n))
|
||||
n_test = max(1, int(0.10 * n))
|
||||
n_train = max(0, n - n_val - n_test)
|
||||
splits = {
|
||||
"train": files_shuf[:n_train],
|
||||
"validation": files_shuf[n_train:n_train + n_val],
|
||||
"test": files_shuf[n_train + n_val:],
|
||||
}
|
||||
file_list = splits.get(split, [])
|
||||
if not file_list:
|
||||
return # nothing to yield
|
||||
|
||||
for _ in range(max(1, int(repeat))):
|
||||
for p in file_list:
|
||||
y, sr = librosa.load(p, sr=16000, mono=True)
|
||||
yield y.astype(np.float32, copy=False)
|
||||
|
||||
# Bind the patched generator to your existing `clips` instance
|
||||
clips.audio_generator = types.MethodType(audio_generator_from_wavs, clips)
|
||||
|
||||
# ---- Split config ----
|
||||
split_cfg = {
|
||||
"training": {"name": "train", "repetition": 2, "slide_frames": 10},
|
||||
"validation": {"name": "validation", "repetition": 1, "slide_frames": 10},
|
||||
"testing": {"name": "test", "repetition": 1, "slide_frames": 1},
|
||||
}
|
||||
|
||||
# ---- Generate features ----
|
||||
for split, cfg in split_cfg.items():
|
||||
out_dir = out_path / split
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
print(f" Augmenting {split}")
|
||||
def bind_wav_generator(clips_obj: Clips, wav_dir: str):
|
||||
"""
|
||||
Patch clips.audio_generator so we load WAVs directly (deterministic 80/10/10 split, seed=10).
|
||||
Matches the notebook behavior you posted.
|
||||
"""
|
||||
def audio_generator_from_wavs(self, split="train", repeat=1):
|
||||
files = sorted(glob.glob(os.path.join(wav_dir, "*.wav")))
|
||||
if not files:
|
||||
return
|
||||
|
||||
print(" Generating spectrograms")
|
||||
spectros = SpectrogramGeneration(
|
||||
clips=clips,
|
||||
augmenter=augmenter,
|
||||
slide_frames=cfg["slide_frames"],
|
||||
step_ms=10,
|
||||
rng = random.Random(10)
|
||||
files_shuf = files[:]
|
||||
rng.shuffle(files_shuf)
|
||||
|
||||
n = len(files_shuf)
|
||||
n_val = max(1, int(0.10 * n))
|
||||
n_test = max(1, int(0.10 * n))
|
||||
n_train = max(0, n - n_val - n_test)
|
||||
|
||||
splits = {
|
||||
"train": files_shuf[:n_train],
|
||||
"validation": files_shuf[n_train:n_train + n_val],
|
||||
"test": files_shuf[n_train + n_val:],
|
||||
}
|
||||
file_list = splits.get(split, [])
|
||||
if not file_list:
|
||||
return
|
||||
|
||||
for _ in range(max(1, int(repeat))):
|
||||
for p in file_list:
|
||||
y, _sr = librosa.load(p, sr=16000, mono=True)
|
||||
yield y.astype(np.float32, copy=False)
|
||||
|
||||
clips_obj.audio_generator = types.MethodType(audio_generator_from_wavs, clips_obj)
|
||||
|
||||
def generate_feature_set(input_wav_dir: str, out_root_dir: str, label: str):
|
||||
files = glob.glob(os.path.join(input_wav_dir, "*.wav"))
|
||||
if not files:
|
||||
print(f"ℹ️ No WAVs found for {label} in: {input_wav_dir} (skipping)")
|
||||
return False
|
||||
|
||||
max_samples = len(files)
|
||||
print(f"\n===== Augmenting {max_samples} wake word samples ({label}) =====")
|
||||
|
||||
clips = Clips(
|
||||
input_directory=input_wav_dir,
|
||||
file_pattern="*.wav",
|
||||
max_clip_duration_s=5,
|
||||
remove_silence=True,
|
||||
random_split_seed=10,
|
||||
split_count=0.1,
|
||||
)
|
||||
|
||||
print(" Generating files")
|
||||
print(" Sit tight — this step can take a while.")
|
||||
bind_wav_generator(clips, input_wav_dir)
|
||||
|
||||
gen = spectros.spectrogram_generator(
|
||||
split=cfg["name"],
|
||||
repeat=cfg["repetition"],
|
||||
)
|
||||
out_root = Path(out_root_dir)
|
||||
out_root.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
RaggedMmap.from_generator(
|
||||
out_dir=str(out_dir / "wakeword_mmap"),
|
||||
sample_generator=gen,
|
||||
batch_size=100,
|
||||
verbose=False, # keep mmap quiet
|
||||
)
|
||||
for split, cfg in split_cfg.items():
|
||||
out_dir = out_root / split
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
print(f" Augmenting {split} ({label})")
|
||||
|
||||
print(f" {split} augmentation complete")
|
||||
spectros = SpectrogramGeneration(
|
||||
clips=clips,
|
||||
augmenter=augmenter,
|
||||
slide_frames=cfg["slide_frames"],
|
||||
step_ms=10,
|
||||
)
|
||||
|
||||
gen = spectros.spectrogram_generator(
|
||||
split=cfg["name"],
|
||||
repeat=cfg["repetition"],
|
||||
)
|
||||
|
||||
RaggedMmap.from_generator(
|
||||
out_dir=str(out_dir / "wakeword_mmap"),
|
||||
sample_generator=gen,
|
||||
batch_size=100,
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
print(f" {split} augmentation complete ({label})")
|
||||
|
||||
print(f"✅ Features ready: {out_root_dir}/*/wakeword_mmap")
|
||||
return True
|
||||
|
||||
# Wake word generated/TTS features (existing behavior)
|
||||
generate_feature_set(args.input_dir, args.output_dir, "generated")
|
||||
|
||||
# Personal features (NEW)
|
||||
generate_feature_set(args.personal_dir, args.personal_output_dir, "personal")
|
||||
|
||||
END_TIME = datetime.now(timezone.utc).replace(microsecond=0)
|
||||
et = END_TIME - START_TIME
|
||||
print(f"\n{'=' * 80}")
|
||||
msg = f"Augmented {max_samples} wake word samples."
|
||||
print(f"{msg:>50s} Elapsed time: {et!s}")
|
||||
print(f"{'Augmentation completed.':>50s} Elapsed time: {et!s}")
|
||||
print(f"{'=' * 80}\n")
|
||||
Reference in New Issue
Block a user