mirror of
https://github.com/TaterTotterson/microWakeWord-Trainer-Nvidia-Docker.git
synced 2026-06-12 20:10:19 -06:00
cli + web recorder ui
This commit is contained in:
@@ -71,17 +71,16 @@ if not files:
|
||||
max_samples = len(files)
|
||||
|
||||
print(f"\n===== Augmenting {max_samples} wake word samples =====")
|
||||
|
||||
print(" Initializing libraries")
|
||||
|
||||
os.environ["TF_CPP_MIN_LOG_LEVEL"]="3"
|
||||
os.environ["TF_FORCE_GPU_ALLOW_GROWTH"]="true"
|
||||
os.environ["TF_GPU_ALLOCATOR"]="cuda_malloc_async"
|
||||
os.environ["TF_XLA_FLAGS"]="--tf_xla_auto_jit=0"
|
||||
os.environ["NVIDIA_TF32_OVERRIDE"]="1"
|
||||
os.environ["TF_CUDNN_WORKSPACE_LIMIT_IN_MB"]="512"
|
||||
os.environ["GLOG_minloglevel"]="9"
|
||||
os.environ["GRPC_VERBOSITY"]="ERROR"
|
||||
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
||||
os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"
|
||||
os.environ["TF_GPU_ALLOCATOR"] = "cuda_malloc_async"
|
||||
os.environ["TF_XLA_FLAGS"] = "--tf_xla_auto_jit=0"
|
||||
os.environ["NVIDIA_TF32_OVERRIDE"] = "1"
|
||||
os.environ["TF_CUDNN_WORKSPACE_LIMIT_IN_MB"] = "512"
|
||||
os.environ["GLOG_minloglevel"] = "9"
|
||||
os.environ["GRPC_VERBOSITY"] = "ERROR"
|
||||
|
||||
print(" Loading Tensorflow")
|
||||
import tensorflow as tf
|
||||
@@ -98,6 +97,7 @@ gc.collect()
|
||||
|
||||
import numpy as np
|
||||
import librosa
|
||||
from tqdm import tqdm
|
||||
from mmap_ninja.ragged import RaggedMmap
|
||||
from microwakeword.audio.augmentation import Augmentation
|
||||
from microwakeword.audio.clips import Clips
|
||||
@@ -108,7 +108,7 @@ START_TIME = datetime.now(timezone.utc).replace(microsecond=0)
|
||||
|
||||
# Paths to augmented data
|
||||
impulse_paths = [ args.mit_rirs_16k_dir ]
|
||||
background_paths = [ args.fma_16k_dir, args.audioset_16k_dir]
|
||||
background_paths = [ args.fma_16k_dir, args.audioset_16k_dir ]
|
||||
|
||||
clips = Clips(
|
||||
input_directory=args.input_dir,
|
||||
@@ -139,8 +139,6 @@ augmenter = Augmentation(
|
||||
max_jitter_s=0.3,
|
||||
)
|
||||
|
||||
# Augment samples and save the training, validation, and testing sets.
|
||||
|
||||
def audio_generator_from_wavs(self, split="train", repeat=1):
|
||||
"""
|
||||
Yield 1-D float32 arrays loaded via librosa from input_dir/*.wav.
|
||||
@@ -175,7 +173,7 @@ def audio_generator_from_wavs(self, split="train", repeat=1):
|
||||
# Bind the patched generator to your existing `clips` instance
|
||||
clips.audio_generator = types.MethodType(audio_generator_from_wavs, clips)
|
||||
|
||||
# ---- Split config (same as before) ----
|
||||
# ---- Split config ----
|
||||
split_cfg = {
|
||||
"training": {"name": "train", "repetition": 2, "slide_frames": 10},
|
||||
"validation": {"name": "validation", "repetition": 1, "slide_frames": 10},
|
||||
@@ -188,28 +186,34 @@ for split, cfg in split_cfg.items():
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
print(f" Augmenting {split}")
|
||||
|
||||
print(f" Generating spectrograms")
|
||||
print(" Generating spectrograms")
|
||||
spectros = SpectrogramGeneration(
|
||||
clips=clips, # now backed by our WAV loader
|
||||
augmenter=augmenter, # your existing augmenter
|
||||
clips=clips,
|
||||
augmenter=augmenter,
|
||||
slide_frames=cfg["slide_frames"],
|
||||
step_ms=10,
|
||||
)
|
||||
|
||||
print(f" Generating files")
|
||||
print(" Generating files")
|
||||
print(" Sit tight — this step can take a while.")
|
||||
|
||||
gen = spectros.spectrogram_generator(
|
||||
split=cfg["name"],
|
||||
repeat=cfg["repetition"],
|
||||
)
|
||||
|
||||
RaggedMmap.from_generator(
|
||||
out_dir=str(out_dir / "wakeword_mmap"),
|
||||
sample_generator=spectros.spectrogram_generator(
|
||||
split=cfg["name"], repeat=cfg["repetition"]
|
||||
),
|
||||
sample_generator=gen,
|
||||
batch_size=100,
|
||||
verbose=False,
|
||||
verbose=False, # keep mmap quiet
|
||||
)
|
||||
|
||||
print(f" {split} augmentation complete")
|
||||
|
||||
END_TIME = datetime.now(timezone.utc).replace(microsecond=0)
|
||||
et = END_TIME - START_TIME
|
||||
print(f"\n{'=' * 80}")
|
||||
msg=f"Augmented {max_samples} wake word samples."
|
||||
msg = f"Augmented {max_samples} wake word samples."
|
||||
print(f"{msg:>50s} Elapsed time: {et!s}")
|
||||
print(f"{'=' * 80}\n")
|
||||
print(f"{'=' * 80}\n")
|
||||
Reference in New Issue
Block a user