mirror of
https://github.com/TaterTotterson/microWakeWord-Trainer-Nvidia-Docker.git
synced 2026-06-12 20:10:19 -06:00
personal samples
This commit is contained in:
@@ -1,7 +1,7 @@
|
|||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
|
|
||||||
import sys, os, gc, glob, random
|
import sys, os, gc, glob, random
|
||||||
import types, shutil, json
|
import types
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from argparse import ArgumentParser as ArgParser, ArgumentError
|
from argparse import ArgumentParser as ArgParser, ArgumentError
|
||||||
@@ -9,12 +9,20 @@ from argparse import ArgumentParser as ArgParser, ArgumentError
|
|||||||
default_data_dir = os.getcwd() if os.path.exists(".mww-data-dir") else "/data"
|
default_data_dir = os.getcwd() if os.path.exists(".mww-data-dir") else "/data"
|
||||||
|
|
||||||
parser = ArgParser(exit_on_error=False)
|
parser = ArgParser(exit_on_error=False)
|
||||||
parser.add_argument("--data-dir", type=str, help=f"Data directory. Default: {default_data_dir}", required=False, default=default_data_dir)
|
parser.add_argument("--data-dir", type=str, help=f"Data directory. Default: {default_data_dir}", required=False, default=default_data_dir)
|
||||||
parser.add_argument("--input-dir", type=str, help="Sample input directory. Default: <data-dir>/work/wake_word_samples", required=False)
|
|
||||||
parser.add_argument("--output-dir", type=str, help="Sample output directory. Default: <input-dir>_augmented", required=False)
|
# Wake word (TTS/generated) inputs/outputs
|
||||||
parser.add_argument("--mit-rirs-16k-dir", type=str, help="MIT RIR input directory. Default: <data-dir>/training_datasets/mit_rirs_16k", required=False)
|
parser.add_argument("--input-dir", type=str, help="Wake word input dir. Default: <data-dir>/work/wake_word_samples", required=False)
|
||||||
parser.add_argument("--fma-16k-dir", type=str, help="FMA input directory. Default: <data-dir>/training_datasets/fma_16k", required=False)
|
parser.add_argument("--output-dir", type=str, help="Wake word output dir. Default: <input-dir>_augmented", required=False)
|
||||||
parser.add_argument("--audioset-16k-dir", type=str, help="Audioset input directory. Default: <data-dir>/training_datasets/audioset_16k", required=False)
|
|
||||||
|
# Personal inputs/outputs (NEW)
|
||||||
|
parser.add_argument("--personal-dir", type=str, help="Personal WAV dir. Default: <data-dir>/personal_samples", required=False)
|
||||||
|
parser.add_argument("--personal-output-dir", type=str, help="Personal features output dir. Default: <data-dir>/work/personal_augmented_features", required=False)
|
||||||
|
|
||||||
|
# Dataset dirs
|
||||||
|
parser.add_argument("--mit-rirs-16k-dir", type=str, help="MIT RIR input directory. Default: <data-dir>/training_datasets/mit_rirs_16k", required=False)
|
||||||
|
parser.add_argument("--fma-16k-dir", type=str, help="FMA input directory. Default: <data-dir>/training_datasets/fma_16k", required=False)
|
||||||
|
parser.add_argument("--audioset-16k-dir", type=str, help="Audioset input directory. Default: <data-dir>/training_datasets/audioset_16k", required=False)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
@@ -23,10 +31,11 @@ except ArgumentError:
|
|||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
args.data_dir = os.path.realpath(args.data_dir)
|
args.data_dir = os.path.realpath(args.data_dir)
|
||||||
work_dir = args.data_dir + "/work"
|
work_dir = os.path.join(args.data_dir, "work")
|
||||||
|
|
||||||
|
# Wake word defaults
|
||||||
if not args.input_dir:
|
if not args.input_dir:
|
||||||
args.input_dir = work_dir + "/wake_word_samples"
|
args.input_dir = os.path.join(work_dir, "wake_word_samples")
|
||||||
else:
|
else:
|
||||||
args.input_dir = os.path.realpath(args.input_dir)
|
args.input_dir = os.path.realpath(args.input_dir)
|
||||||
|
|
||||||
@@ -35,24 +44,33 @@ if not args.output_dir:
|
|||||||
else:
|
else:
|
||||||
args.output_dir = os.path.realpath(args.output_dir)
|
args.output_dir = os.path.realpath(args.output_dir)
|
||||||
|
|
||||||
|
# Personal defaults (NEW)
|
||||||
|
if not args.personal_dir:
|
||||||
|
args.personal_dir = os.path.join(args.data_dir, "personal_samples")
|
||||||
|
else:
|
||||||
|
args.personal_dir = os.path.realpath(args.personal_dir)
|
||||||
|
|
||||||
|
if not args.personal_output_dir:
|
||||||
|
args.personal_output_dir = os.path.join(work_dir, "personal_augmented_features")
|
||||||
|
else:
|
||||||
|
args.personal_output_dir = os.path.realpath(args.personal_output_dir)
|
||||||
|
|
||||||
|
# Dataset defaults
|
||||||
if not args.mit_rirs_16k_dir:
|
if not args.mit_rirs_16k_dir:
|
||||||
args.mit_rirs_16k_dir = args.data_dir + "/training_datasets/mit_rirs_16k"
|
args.mit_rirs_16k_dir = os.path.join(args.data_dir, "training_datasets", "mit_rirs_16k")
|
||||||
else:
|
else:
|
||||||
args.mit_rirs_16k_dir = os.path.realpath(args.mit_rirs_16k_dir)
|
args.mit_rirs_16k_dir = os.path.realpath(args.mit_rirs_16k_dir)
|
||||||
|
|
||||||
if not args.fma_16k_dir:
|
if not args.fma_16k_dir:
|
||||||
args.fma_16k_dir = args.data_dir + "/training_datasets/fma_16k"
|
args.fma_16k_dir = os.path.join(args.data_dir, "training_datasets", "fma_16k")
|
||||||
else:
|
else:
|
||||||
args.fma_16k_dir = os.path.realpath(args.fma_16k_dir)
|
args.fma_16k_dir = os.path.realpath(args.fma_16k_dir)
|
||||||
|
|
||||||
if not args.audioset_16k_dir:
|
if not args.audioset_16k_dir:
|
||||||
args.audioset_16k_dir = args.data_dir + "/training_datasets/audioset_16k"
|
args.audioset_16k_dir = os.path.join(args.data_dir, "training_datasets", "audioset_16k")
|
||||||
else:
|
else:
|
||||||
args.audioset_16k_dir = os.path.realpath(args.audioset_16k_dir)
|
args.audioset_16k_dir = os.path.realpath(args.audioset_16k_dir)
|
||||||
|
|
||||||
out_path = Path(args.output_dir)
|
|
||||||
out_path.mkdir(exist_ok=True)
|
|
||||||
|
|
||||||
def validate_directories(paths):
|
def validate_directories(paths):
|
||||||
for path in paths:
|
for path in paths:
|
||||||
if not os.path.exists(path):
|
if not os.path.exists(path):
|
||||||
@@ -60,17 +78,12 @@ def validate_directories(paths):
|
|||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
paths = [ work_dir, args.input_dir, args.output_dir, args.mit_rirs_16k_dir, args.fma_16k_dir, args.audioset_16k_dir ]
|
required = [work_dir, args.input_dir, args.mit_rirs_16k_dir, args.fma_16k_dir, args.audioset_16k_dir]
|
||||||
if not validate_directories(paths):
|
if not validate_directories(required):
|
||||||
parser.print_help()
|
parser.print_help()
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
files = glob.glob(args.input_dir + "/*.wav")
|
# -------------------- TF + libs --------------------
|
||||||
if not files:
|
|
||||||
raise RuntimeError("❌ No WAVs in wake_word_samples.")
|
|
||||||
max_samples = len(files)
|
|
||||||
|
|
||||||
print(f"\n===== Augmenting {max_samples} wake word samples =====")
|
|
||||||
print(" Initializing libraries")
|
print(" Initializing libraries")
|
||||||
|
|
||||||
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
||||||
@@ -86,7 +99,6 @@ print(" Loading Tensorflow")
|
|||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
print(" GPU memory config")
|
print(" GPU memory config")
|
||||||
# Per-device memory growth (belt + suspenders)
|
|
||||||
for g in tf.config.list_physical_devices("GPU"):
|
for g in tf.config.list_physical_devices("GPU"):
|
||||||
try:
|
try:
|
||||||
tf.config.experimental.set_memory_growth(g, True)
|
tf.config.experimental.set_memory_growth(g, True)
|
||||||
@@ -97,27 +109,15 @@ gc.collect()
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import librosa
|
import librosa
|
||||||
from tqdm import tqdm
|
|
||||||
from mmap_ninja.ragged import RaggedMmap
|
from mmap_ninja.ragged import RaggedMmap
|
||||||
from microwakeword.audio.augmentation import Augmentation
|
from microwakeword.audio.augmentation import Augmentation
|
||||||
from microwakeword.audio.clips import Clips
|
from microwakeword.audio.clips import Clips
|
||||||
from microwakeword.audio.spectrograms import SpectrogramGeneration
|
from microwakeword.audio.spectrograms import SpectrogramGeneration
|
||||||
from microwakeword.audio.audio_utils import save_clip
|
|
||||||
|
|
||||||
START_TIME = datetime.now(timezone.utc).replace(microsecond=0)
|
START_TIME = datetime.now(timezone.utc).replace(microsecond=0)
|
||||||
|
|
||||||
# Paths to augmented data
|
impulse_paths = [args.mit_rirs_16k_dir]
|
||||||
impulse_paths = [ args.mit_rirs_16k_dir ]
|
background_paths = [args.fma_16k_dir, args.audioset_16k_dir]
|
||||||
background_paths = [ args.fma_16k_dir, args.audioset_16k_dir ]
|
|
||||||
|
|
||||||
clips = Clips(
|
|
||||||
input_directory=args.input_dir,
|
|
||||||
file_pattern='*.wav',
|
|
||||||
max_clip_duration_s=5,
|
|
||||||
remove_silence=True,
|
|
||||||
random_split_seed=10,
|
|
||||||
split_count=0.1,
|
|
||||||
)
|
|
||||||
|
|
||||||
augmenter = Augmentation(
|
augmenter = Augmentation(
|
||||||
augmentation_duration_s=3.2,
|
augmentation_duration_s=3.2,
|
||||||
@@ -139,81 +139,107 @@ augmenter = Augmentation(
|
|||||||
max_jitter_s=0.3,
|
max_jitter_s=0.3,
|
||||||
)
|
)
|
||||||
|
|
||||||
def audio_generator_from_wavs(self, split="train", repeat=1):
|
|
||||||
"""
|
|
||||||
Yield 1-D float32 arrays loaded via librosa from input_dir/*.wav.
|
|
||||||
Deterministic 80/10/10 split with seed 10 to mirror original Clips behavior.
|
|
||||||
"""
|
|
||||||
files = sorted(glob.glob(args.input_dir + "/*.wav"))
|
|
||||||
if not files:
|
|
||||||
raise RuntimeError("❌ No WAVs in wake_word_samples.")
|
|
||||||
|
|
||||||
rng = random.Random(10) # deterministic shuffling like Clips(random_split_seed=10)
|
|
||||||
files_shuf = files[:]
|
|
||||||
rng.shuffle(files_shuf)
|
|
||||||
|
|
||||||
n = len(files_shuf)
|
|
||||||
n_val = max(1, int(0.10 * n))
|
|
||||||
n_test = max(1, int(0.10 * n))
|
|
||||||
n_train = max(0, n - n_val - n_test)
|
|
||||||
splits = {
|
|
||||||
"train": files_shuf[:n_train],
|
|
||||||
"validation": files_shuf[n_train:n_train + n_val],
|
|
||||||
"test": files_shuf[n_train + n_val:],
|
|
||||||
}
|
|
||||||
file_list = splits.get(split, [])
|
|
||||||
if not file_list:
|
|
||||||
return # nothing to yield
|
|
||||||
|
|
||||||
for _ in range(max(1, int(repeat))):
|
|
||||||
for p in file_list:
|
|
||||||
y, sr = librosa.load(p, sr=16000, mono=True)
|
|
||||||
yield y.astype(np.float32, copy=False)
|
|
||||||
|
|
||||||
# Bind the patched generator to your existing `clips` instance
|
|
||||||
clips.audio_generator = types.MethodType(audio_generator_from_wavs, clips)
|
|
||||||
|
|
||||||
# ---- Split config ----
|
|
||||||
split_cfg = {
|
split_cfg = {
|
||||||
"training": {"name": "train", "repetition": 2, "slide_frames": 10},
|
"training": {"name": "train", "repetition": 2, "slide_frames": 10},
|
||||||
"validation": {"name": "validation", "repetition": 1, "slide_frames": 10},
|
"validation": {"name": "validation", "repetition": 1, "slide_frames": 10},
|
||||||
"testing": {"name": "test", "repetition": 1, "slide_frames": 1},
|
"testing": {"name": "test", "repetition": 1, "slide_frames": 1},
|
||||||
}
|
}
|
||||||
|
|
||||||
# ---- Generate features ----
|
def bind_wav_generator(clips_obj: Clips, wav_dir: str):
|
||||||
for split, cfg in split_cfg.items():
|
"""
|
||||||
out_dir = out_path / split
|
Patch clips.audio_generator so we load WAVs directly (deterministic 80/10/10 split, seed=10).
|
||||||
out_dir.mkdir(parents=True, exist_ok=True)
|
Matches the notebook behavior you posted.
|
||||||
print(f" Augmenting {split}")
|
"""
|
||||||
|
def audio_generator_from_wavs(self, split="train", repeat=1):
|
||||||
|
files = sorted(glob.glob(os.path.join(wav_dir, "*.wav")))
|
||||||
|
if not files:
|
||||||
|
return
|
||||||
|
|
||||||
print(" Generating spectrograms")
|
rng = random.Random(10)
|
||||||
spectros = SpectrogramGeneration(
|
files_shuf = files[:]
|
||||||
clips=clips,
|
rng.shuffle(files_shuf)
|
||||||
augmenter=augmenter,
|
|
||||||
slide_frames=cfg["slide_frames"],
|
n = len(files_shuf)
|
||||||
step_ms=10,
|
n_val = max(1, int(0.10 * n))
|
||||||
|
n_test = max(1, int(0.10 * n))
|
||||||
|
n_train = max(0, n - n_val - n_test)
|
||||||
|
|
||||||
|
splits = {
|
||||||
|
"train": files_shuf[:n_train],
|
||||||
|
"validation": files_shuf[n_train:n_train + n_val],
|
||||||
|
"test": files_shuf[n_train + n_val:],
|
||||||
|
}
|
||||||
|
file_list = splits.get(split, [])
|
||||||
|
if not file_list:
|
||||||
|
return
|
||||||
|
|
||||||
|
for _ in range(max(1, int(repeat))):
|
||||||
|
for p in file_list:
|
||||||
|
y, _sr = librosa.load(p, sr=16000, mono=True)
|
||||||
|
yield y.astype(np.float32, copy=False)
|
||||||
|
|
||||||
|
clips_obj.audio_generator = types.MethodType(audio_generator_from_wavs, clips_obj)
|
||||||
|
|
||||||
|
def generate_feature_set(input_wav_dir: str, out_root_dir: str, label: str):
|
||||||
|
files = glob.glob(os.path.join(input_wav_dir, "*.wav"))
|
||||||
|
if not files:
|
||||||
|
print(f"ℹ️ No WAVs found for {label} in: {input_wav_dir} (skipping)")
|
||||||
|
return False
|
||||||
|
|
||||||
|
max_samples = len(files)
|
||||||
|
print(f"\n===== Augmenting {max_samples} wake word samples ({label}) =====")
|
||||||
|
|
||||||
|
clips = Clips(
|
||||||
|
input_directory=input_wav_dir,
|
||||||
|
file_pattern="*.wav",
|
||||||
|
max_clip_duration_s=5,
|
||||||
|
remove_silence=True,
|
||||||
|
random_split_seed=10,
|
||||||
|
split_count=0.1,
|
||||||
)
|
)
|
||||||
|
|
||||||
print(" Generating files")
|
bind_wav_generator(clips, input_wav_dir)
|
||||||
print(" Sit tight — this step can take a while.")
|
|
||||||
|
|
||||||
gen = spectros.spectrogram_generator(
|
out_root = Path(out_root_dir)
|
||||||
split=cfg["name"],
|
out_root.mkdir(parents=True, exist_ok=True)
|
||||||
repeat=cfg["repetition"],
|
|
||||||
)
|
|
||||||
|
|
||||||
RaggedMmap.from_generator(
|
for split, cfg in split_cfg.items():
|
||||||
out_dir=str(out_dir / "wakeword_mmap"),
|
out_dir = out_root / split
|
||||||
sample_generator=gen,
|
out_dir.mkdir(parents=True, exist_ok=True)
|
||||||
batch_size=100,
|
print(f" Augmenting {split} ({label})")
|
||||||
verbose=False, # keep mmap quiet
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f" {split} augmentation complete")
|
spectros = SpectrogramGeneration(
|
||||||
|
clips=clips,
|
||||||
|
augmenter=augmenter,
|
||||||
|
slide_frames=cfg["slide_frames"],
|
||||||
|
step_ms=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
gen = spectros.spectrogram_generator(
|
||||||
|
split=cfg["name"],
|
||||||
|
repeat=cfg["repetition"],
|
||||||
|
)
|
||||||
|
|
||||||
|
RaggedMmap.from_generator(
|
||||||
|
out_dir=str(out_dir / "wakeword_mmap"),
|
||||||
|
sample_generator=gen,
|
||||||
|
batch_size=100,
|
||||||
|
verbose=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f" {split} augmentation complete ({label})")
|
||||||
|
|
||||||
|
print(f"✅ Features ready: {out_root_dir}/*/wakeword_mmap")
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Wake word generated/TTS features (existing behavior)
|
||||||
|
generate_feature_set(args.input_dir, args.output_dir, "generated")
|
||||||
|
|
||||||
|
# Personal features (NEW)
|
||||||
|
generate_feature_set(args.personal_dir, args.personal_output_dir, "personal")
|
||||||
|
|
||||||
END_TIME = datetime.now(timezone.utc).replace(microsecond=0)
|
END_TIME = datetime.now(timezone.utc).replace(microsecond=0)
|
||||||
et = END_TIME - START_TIME
|
et = END_TIME - START_TIME
|
||||||
print(f"\n{'=' * 80}")
|
print(f"\n{'=' * 80}")
|
||||||
msg = f"Augmented {max_samples} wake word samples."
|
print(f"{'Augmentation completed.':>50s} Elapsed time: {et!s}")
|
||||||
print(f"{msg:>50s} Elapsed time: {et!s}")
|
|
||||||
print(f"{'=' * 80}\n")
|
print(f"{'=' * 80}\n")
|
||||||
@@ -60,43 +60,57 @@ check_directories() {
|
|||||||
check_directories ${WORK_DIR}/wake_word_samples_augmented \
|
check_directories ${WORK_DIR}/wake_word_samples_augmented \
|
||||||
${TRAINING_DS}/negative_datasets/{speech,dinner_party,no_speech,dinner_party_eval}
|
${TRAINING_DS}/negative_datasets/{speech,dinner_party,no_speech,dinner_party_eval}
|
||||||
|
|
||||||
|
# Personal features are optional, but if present they MUST have /training
|
||||||
|
PERSONAL_FEATURES_DIR="${WORK_DIR}/personal_augmented_features"
|
||||||
|
HAS_PERSONAL="false"
|
||||||
|
if [ -d "${PERSONAL_FEATURES_DIR}/training" ] ; then
|
||||||
|
HAS_PERSONAL="true"
|
||||||
|
echo "✅ Found personal features: ${PERSONAL_FEATURES_DIR}/training (will weight sampling_weight=3.0)"
|
||||||
|
else
|
||||||
|
echo "ℹ️ No personal features found at ${PERSONAL_FEATURES_DIR}/training (continuing without personal weighting)"
|
||||||
|
fi
|
||||||
|
|
||||||
cd "${WORK_DIR}"
|
cd "${WORK_DIR}"
|
||||||
|
|
||||||
echo "===== Starting ${TRAINING_STEPS} training steps ====="
|
echo "===== Starting ${TRAINING_STEPS} training steps ====="
|
||||||
|
|
||||||
START_TS=$EPOCHSECONDS
|
START_TS=$EPOCHSECONDS
|
||||||
|
|
||||||
mkdir -p "${WORK_DIR}/trained_models" || :
|
mkdir -p "${WORK_DIR}/trained_models" || :
|
||||||
cat <<EOF >"${WORK_DIR}/trained_models/training_parameters.yaml"
|
|
||||||
|
# We write a YAML with a marker, then splice personal feature block in if it exists.
|
||||||
|
YAML_PATH="${WORK_DIR}/trained_models/training_parameters.yaml"
|
||||||
|
|
||||||
|
cat <<'EOF' > "${YAML_PATH}"
|
||||||
batch_size: 16
|
batch_size: 16
|
||||||
clip_duration_ms: 1500
|
clip_duration_ms: 1500
|
||||||
eval_step_interval: 500
|
eval_step_interval: 500
|
||||||
features:
|
features:
|
||||||
- features_dir: ${WORK_DIR}/wake_word_samples_augmented
|
- features_dir: __WAKEWORD_FEATURES__
|
||||||
penalty_weight: 1.0
|
penalty_weight: 1.0
|
||||||
sampling_weight: 2.0
|
sampling_weight: 2.0
|
||||||
truncation_strategy: truncate_start
|
truncation_strategy: truncate_start
|
||||||
truth: true
|
truth: true
|
||||||
type: mmap
|
type: mmap
|
||||||
- features_dir: ${TRAINING_DS}/negative_datasets/speech
|
__PERSONAL_FEATURE_MARKER__
|
||||||
|
- features_dir: __NEG_SPEECH__
|
||||||
penalty_weight: 1.0
|
penalty_weight: 1.0
|
||||||
sampling_weight: 12.0
|
sampling_weight: 12.0
|
||||||
truncation_strategy: random
|
truncation_strategy: random
|
||||||
truth: false
|
truth: false
|
||||||
type: mmap
|
type: mmap
|
||||||
- features_dir: ${TRAINING_DS}/negative_datasets/dinner_party
|
- features_dir: __NEG_DINNER__
|
||||||
penalty_weight: 1.0
|
penalty_weight: 1.0
|
||||||
sampling_weight: 12.0
|
sampling_weight: 12.0
|
||||||
truncation_strategy: random
|
truncation_strategy: random
|
||||||
truth: false
|
truth: false
|
||||||
type: mmap
|
type: mmap
|
||||||
- features_dir: ${TRAINING_DS}/negative_datasets/no_speech
|
- features_dir: __NEG_NOSPEECH__
|
||||||
penalty_weight: 1.0
|
penalty_weight: 1.0
|
||||||
sampling_weight: 5.0
|
sampling_weight: 5.0
|
||||||
truncation_strategy: random
|
truncation_strategy: random
|
||||||
truth: false
|
truth: false
|
||||||
type: mmap
|
type: mmap
|
||||||
- features_dir: ${TRAINING_DS}/negative_datasets/dinner_party_eval
|
- features_dir: __NEG_DINNER_EVAL__
|
||||||
penalty_weight: 1.0
|
penalty_weight: 1.0
|
||||||
sampling_weight: 0.0
|
sampling_weight: 0.0
|
||||||
truncation_strategy: split
|
truncation_strategy: split
|
||||||
@@ -119,25 +133,46 @@ time_mask_count:
|
|||||||
- 0
|
- 0
|
||||||
time_mask_max_size:
|
time_mask_max_size:
|
||||||
- 0
|
- 0
|
||||||
train_dir: ${WORK_DIR}/trained_models/wakeword
|
train_dir: __TRAIN_DIR__
|
||||||
training_steps:
|
training_steps:
|
||||||
- ${TRAINING_STEPS}
|
- __TRAINING_STEPS__
|
||||||
window_step_ms: 10
|
window_step_ms: 10
|
||||||
|
|
||||||
EOF
|
EOF
|
||||||
|
|
||||||
|
# Replace placeholders (portable)
|
||||||
|
sed -i \
|
||||||
|
-e "s|__WAKEWORD_FEATURES__|${WORK_DIR}/wake_word_samples_augmented|g" \
|
||||||
|
-e "s|__NEG_SPEECH__|${TRAINING_DS}/negative_datasets/speech|g" \
|
||||||
|
-e "s|__NEG_DINNER__|${TRAINING_DS}/negative_datasets/dinner_party|g" \
|
||||||
|
-e "s|__NEG_NOSPEECH__|${TRAINING_DS}/negative_datasets/no_speech|g" \
|
||||||
|
-e "s|__NEG_DINNER_EVAL__|${TRAINING_DS}/negative_datasets/dinner_party_eval|g" \
|
||||||
|
-e "s|__TRAIN_DIR__|${WORK_DIR}/trained_models/wakeword|g" \
|
||||||
|
-e "s|__TRAINING_STEPS__|${TRAINING_STEPS}|g" \
|
||||||
|
"${YAML_PATH}"
|
||||||
|
|
||||||
|
# Insert/remove personal block
|
||||||
|
if [ "${HAS_PERSONAL}" = "true" ]; then
|
||||||
|
# Insert directly after the wakeword feature block (matches your notebook: insert(1, ...))
|
||||||
|
perl -0777 -i -pe 's/__PERSONAL_FEATURE_MARKER__/\n- features_dir: '"${PERSONAL_FEATURES_DIR}"'\n penalty_weight: 1.0\n sampling_weight: 3.0\n truncation_strategy: truncate_start\n truth: true\n type: mmap\n/g' "${YAML_PATH}"
|
||||||
|
else
|
||||||
|
# Remove marker line entirely
|
||||||
|
sed -i -e "/__PERSONAL_FEATURE_MARKER__/d" "${YAML_PATH}"
|
||||||
|
fi
|
||||||
|
|
||||||
echo " Wrote training_parameters.yaml"
|
echo " Wrote training_parameters.yaml"
|
||||||
rm -rf "${WORK_DIR}/trained_models/wakeword"
|
rm -rf "${WORK_DIR}/trained_models/wakeword"
|
||||||
|
|
||||||
wake_word_filename="${WAKE_WORD//[ \`~\!\$&*$begin:math:text$$end:math:text$\{\}$begin:math:display$$end:math:display$\|\;\'\"<>.?\/]/_}"
|
wake_word_filename="$(
|
||||||
|
echo "${WAKE_WORD}" \
|
||||||
|
| tr '[:upper:]' '[:lower:]' \
|
||||||
|
| sed -E 's/[^a-z0-9]+/_/g; s/^_+//; s/_+$//'
|
||||||
|
)"
|
||||||
|
[ -n "${wake_word_filename}" ] || wake_word_filename="wakeword"
|
||||||
|
|
||||||
OUTPUT_DIR="${DATA_DIR}/output/$(date +'%Y-%m-%d-%H-%M-%S')-${wake_word_filename}-${SAMPLES}-${TRAINING_STEPS}"
|
OUTPUT_DIR="${DATA_DIR}/output/$(date +'%Y-%m-%d-%H-%M-%S')-${wake_word_filename}-${SAMPLES}-${TRAINING_STEPS}"
|
||||||
mkdir -p "${OUTPUT_DIR}/logs" || :
|
mkdir -p "${OUTPUT_DIR}/logs" || :
|
||||||
|
|
||||||
TRAIN_LOG="${OUTPUT_DIR}/logs/training.log"
|
TRAIN_LOG="${OUTPUT_DIR}/logs/training.log"
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# Training args (same as before)
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
TRAIN_ARGS=(
|
TRAIN_ARGS=(
|
||||||
-m microwakeword.model_train_eval
|
-m microwakeword.model_train_eval
|
||||||
--training_config "${WORK_DIR}/trained_models/training_parameters.yaml"
|
--training_config "${WORK_DIR}/trained_models/training_parameters.yaml"
|
||||||
@@ -159,10 +194,6 @@ TRAIN_ARGS=(
|
|||||||
--stride 2
|
--stride 2
|
||||||
)
|
)
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# GPU failure markers that should trigger CPU fallback
|
|
||||||
# (OOM + known GPU runtime/copy/init failures)
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
GPU_FALLBACK_MARKERS=(
|
GPU_FALLBACK_MARKERS=(
|
||||||
"resourceexhaustederror"
|
"resourceexhaustederror"
|
||||||
"resource exhausted"
|
"resource exhausted"
|
||||||
@@ -189,7 +220,6 @@ run_attempt() {
|
|||||||
echo "→ ${PYTHON_BIN:-python} ${TRAIN_ARGS[*]}"
|
echo "→ ${PYTHON_BIN:-python} ${TRAIN_ARGS[*]}"
|
||||||
echo
|
echo
|
||||||
|
|
||||||
# stream everything except validation minibatch spam
|
|
||||||
"${PYTHON_BIN:-python}" "${TRAIN_ARGS[@]}" 2>&1 \
|
"${PYTHON_BIN:-python}" "${TRAIN_ARGS[@]}" 2>&1 \
|
||||||
| tr '\r' '\n' \
|
| tr '\r' '\n' \
|
||||||
| stdbuf -i0 -o0 sed -r -e "/^Validation Batch/d" \
|
| stdbuf -i0 -o0 sed -r -e "/^Validation Batch/d" \
|
||||||
@@ -199,20 +229,17 @@ run_attempt() {
|
|||||||
return ${PIPESTATUS[0]}
|
return ${PIPESTATUS[0]}
|
||||||
}
|
}
|
||||||
|
|
||||||
# ---- Common TF env (mirrors your notebook) ----
|
|
||||||
export TF_CPP_MIN_LOG_LEVEL="${TF_CPP_MIN_LOG_LEVEL:-2}"
|
export TF_CPP_MIN_LOG_LEVEL="${TF_CPP_MIN_LOG_LEVEL:-2}"
|
||||||
export TF_XLA_FLAGS="${TF_XLA_FLAGS:---tf_xla_auto_jit=0}"
|
export TF_XLA_FLAGS="${TF_XLA_FLAGS:---tf_xla_auto_jit=0}"
|
||||||
export NVIDIA_TF32_OVERRIDE="${NVIDIA_TF32_OVERRIDE:-1}"
|
export NVIDIA_TF32_OVERRIDE="${NVIDIA_TF32_OVERRIDE:-1}"
|
||||||
export TF_FORCE_GPU_ALLOW_GROWTH="${TF_FORCE_GPU_ALLOW_GROWTH:-true}"
|
export TF_FORCE_GPU_ALLOW_GROWTH="${TF_FORCE_GPU_ALLOW_GROWTH:-true}"
|
||||||
export TF_GPU_ALLOCATOR="${TF_GPU_ALLOCATOR:-cuda_malloc_async}"
|
export TF_GPU_ALLOCATOR="${TF_GPU_ALLOCATOR:-cuda_malloc_async}"
|
||||||
|
|
||||||
# Attempt 1: GPU
|
|
||||||
if run_attempt "Attempt 1/2: GPU training (allow_growth + cuda_malloc_async)" ; then
|
if run_attempt "Attempt 1/2: GPU training (allow_growth + cuda_malloc_async)" ; then
|
||||||
echo "✅ Training complete (GPU path)."
|
echo "✅ Training complete (GPU path)."
|
||||||
else
|
else
|
||||||
echo "⚠️ GPU attempt failed. Checking whether this looks like a GPU/OOM/runtime failure…"
|
echo "⚠️ GPU attempt failed. Checking whether this looks like a GPU/OOM/runtime failure…"
|
||||||
|
|
||||||
# Check log for GPU/OOM/runtime markers
|
|
||||||
log_lc="$(tr '[:upper:]' '[:lower:]' < "${TRAIN_LOG}" || true)"
|
log_lc="$(tr '[:upper:]' '[:lower:]' < "${TRAIN_LOG}" || true)"
|
||||||
looks_like_gpu_fail="false"
|
looks_like_gpu_fail="false"
|
||||||
for m in "${GPU_FALLBACK_MARKERS[@]}"; do
|
for m in "${GPU_FALLBACK_MARKERS[@]}"; do
|
||||||
@@ -225,7 +252,6 @@ else
|
|||||||
if [ "${looks_like_gpu_fail}" = "true" ]; then
|
if [ "${looks_like_gpu_fail}" = "true" ]; then
|
||||||
echo "↪️ Detected GPU/OOM/runtime failure markers. Falling back to CPU."
|
echo "↪️ Detected GPU/OOM/runtime failure markers. Falling back to CPU."
|
||||||
|
|
||||||
# Attempt 2: CPU (hide GPU completely)
|
|
||||||
export CUDA_VISIBLE_DEVICES=""
|
export CUDA_VISIBLE_DEVICES=""
|
||||||
unset TF_GPU_ALLOCATOR
|
unset TF_GPU_ALLOCATOR
|
||||||
if run_attempt "Attempt 2/2: CPU fallback (CUDA_VISIBLE_DEVICES='')" ; then
|
if run_attempt "Attempt 2/2: CPU fallback (CUDA_VISIBLE_DEVICES='')" ; then
|
||||||
@@ -256,7 +282,6 @@ echo " Full log: ${TRAIN_LOG}"
|
|||||||
|
|
||||||
tflite_filename="${wake_word_filename}.tflite"
|
tflite_filename="${wake_word_filename}.tflite"
|
||||||
tflite_path="${OUTPUT_DIR}/${tflite_filename}"
|
tflite_path="${OUTPUT_DIR}/${tflite_filename}"
|
||||||
|
|
||||||
cp "${source_path}" "${tflite_path}"
|
cp "${source_path}" "${tflite_path}"
|
||||||
|
|
||||||
json_path="${OUTPUT_DIR}/${wake_word_filename}.json"
|
json_path="${OUTPUT_DIR}/${wake_word_filename}.json"
|
||||||
|
|||||||
Reference in New Issue
Block a user