diff --git a/README.md b/README.md index db3472e..5cc8344 100644 --- a/README.md +++ b/README.md @@ -162,7 +162,7 @@ That will remove: - browser microphone recording has been removed - personal samples are optional - the server module is now `trainer_server.py` -- the launcher script is still named `run_recorder.sh` for compatibility +- the launcher script is now `run.sh` --- diff --git a/__pycache__/trainer_server.cpython-311.pyc b/__pycache__/trainer_server.cpython-311.pyc new file mode 100644 index 0000000..3893d6f Binary files /dev/null and b/__pycache__/trainer_server.cpython-311.pyc differ diff --git a/cli/calibrate_detector.py b/cli/calibrate_detector.py new file mode 100644 index 0000000..5cd56da --- /dev/null +++ b/cli/calibrate_detector.py @@ -0,0 +1,405 @@ +#!/usr/bin/env python3 +"""Choose detector metadata that better matches the trained model.""" + +from __future__ import annotations + +import argparse +import json +import math +import os +from datetime import datetime, timezone +from pathlib import Path +from typing import Iterable, Sequence + +import numpy as np +import yaml + +from microwakeword.data import FeatureHandler +from microwakeword.inference import Model + + +DEFAULT_WINDOW_SIZES = [3, 4, 5, 6, 7] +DEFAULT_TARGET_FAPH = float(os.environ.get("MWW_CALIBRATION_TARGET_FAPH", "1.0")) +DEFAULT_COOLDOWN_SLICES = int(os.environ.get("MWW_CALIBRATION_COOLDOWN_SLICES", "25")) +DEFAULT_POSITIVE_SKIP_SLICES = int( + os.environ.get("MWW_CALIBRATION_POSITIVE_SKIP_SLICES", "25") +) +DEFAULT_CUTOFF_STEP = float(os.environ.get("MWW_CALIBRATION_CUTOFF_STEP", "0.01")) +DEFAULT_CUTOFF_MIN = float(os.environ.get("MWW_CALIBRATION_CUTOFF_MIN", "0.00")) +DEFAULT_CUTOFF_MAX = float(os.environ.get("MWW_CALIBRATION_CUTOFF_MAX", "1.00")) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Calibrate microWakeWord detector metadata from validation data." + ) + parser.add_argument( + "--training-config", + default="trained_models/wakeword/training_config.yaml", + help="Path to the saved microWakeWord training_config.yaml file.", + ) + parser.add_argument( + "--model", + default=( + "trained_models/wakeword/tflite_stream_state_internal_quant/" + "stream_state_internal_quant.tflite" + ), + help="Path to the quantized streaming TFLite model.", + ) + parser.add_argument( + "--output", + default=( + "trained_models/wakeword/tflite_stream_state_internal_quant/" + "detection_calibration.json" + ), + help="Where to write the selected detector settings as JSON.", + ) + parser.add_argument( + "--window-sizes", + default=",".join(str(value) for value in DEFAULT_WINDOW_SIZES), + help="Comma-separated sliding window sizes to evaluate.", + ) + parser.add_argument( + "--target-faph", + type=float, + default=DEFAULT_TARGET_FAPH, + help="Target ambient false accepts per hour for the selected operating point.", + ) + parser.add_argument( + "--cooldown-slices", + type=int, + default=DEFAULT_COOLDOWN_SLICES, + help="Cooldown slices to use when estimating false accepts per hour.", + ) + parser.add_argument( + "--positive-skip-slices", + type=int, + default=DEFAULT_POSITIVE_SKIP_SLICES, + help="Initial streaming slices to ignore when scoring positive examples.", + ) + parser.add_argument( + "--cutoff-step", + type=float, + default=DEFAULT_CUTOFF_STEP, + help="Cutoff increment to evaluate between cutoff-min and cutoff-max.", + ) + parser.add_argument( + "--cutoff-min", + type=float, + default=DEFAULT_CUTOFF_MIN, + help="Minimum cutoff to evaluate.", + ) + parser.add_argument( + "--cutoff-max", + type=float, + default=DEFAULT_CUTOFF_MAX, + help="Maximum cutoff to evaluate.", + ) + return parser.parse_args() + + +def _parse_window_sizes(raw: str) -> list[int]: + values = [] + for item in (raw or "").split(","): + item = item.strip() + if not item: + continue + value = int(item) + if value < 1: + raise ValueError("window sizes must be >= 1") + values.append(value) + if not values: + raise ValueError("at least one window size is required") + return sorted(set(values)) + + +def _moving_average(values: Sequence[float], window_size: int) -> np.ndarray: + array = np.asarray(values, dtype=np.float32) + if array.size == 0: + return array + if window_size <= 1: + return array + if array.size < window_size: + return np.asarray([float(array.mean())], dtype=np.float32) + cumsum = np.cumsum(np.insert(array, 0, 0.0)) + averaged = (cumsum[window_size:] - cumsum[:-window_size]) / float(window_size) + return averaged.astype(np.float32) + + +def _compute_false_accepts_per_hour( + probabilities_per_track: Iterable[np.ndarray], + cutoffs: np.ndarray, + cooldown_slices: int, + stride: int, + step_seconds: float, +) -> tuple[np.ndarray, float]: + cutoffs = np.asarray(cutoffs, dtype=np.float32) + false_accepts = np.zeros(cutoffs.shape[0], dtype=np.float64) + duration_hours = 0.0 + + for track_probabilities in probabilities_per_track: + if track_probabilities.size == 0: + continue + duration_hours += ( + len(track_probabilities) * stride * step_seconds / 3600.0 + ) + cooldown = np.full(cutoffs.shape[0], cooldown_slices, dtype=np.int32) + for probability in track_probabilities: + cooldown = np.maximum(cooldown - 1, 0) + accepted = (cooldown == 0) & (probability > cutoffs) + false_accepts += accepted.astype(np.float64) + cooldown = np.where(accepted, cooldown_slices, cooldown) + + if duration_hours <= 0: + return np.full(cutoffs.shape[0], math.inf, dtype=np.float64), 0.0 + + return false_accepts / duration_hours, duration_hours + + +def _select_best_candidate( + candidates: list[dict[str, float]], + target_faph: float, +) -> tuple[dict[str, float], float]: + fallback_limits = [ + target_faph, + max(target_faph * 2.0, target_faph + 0.5), + max(target_faph * 4.0, 2.0), + ] + + def tier(candidate: dict[str, float]) -> int: + for index, limit in enumerate(fallback_limits): + if candidate["false_accepts_per_hour"] <= limit + 1e-9: + return index + return len(fallback_limits) + + best = min( + candidates, + key=lambda candidate: ( + tier(candidate), + -candidate["recall"], + candidate["false_accepts_per_hour"], + abs(candidate["sliding_window_size"] - 5), + -candidate["probability_cutoff"], + ), + ) + + tier_index = tier(best) + if tier_index < len(fallback_limits): + return best, fallback_limits[tier_index] + return best, float("inf") + + +def _load_config(config_path: Path) -> dict: + with config_path.open("r", encoding="utf-8") as handle: + return yaml.load(handle.read(), Loader=yaml.Loader) + + +def _load_eval_sets( + handler: FeatureHandler, + config: dict, +) -> tuple[str, str, list[np.ndarray], list[np.ndarray]]: + for positive_mode, ambient_mode in ( + ("validation", "validation_ambient"), + ("testing", "testing_ambient"), + ): + positive_tracks, labels, _ = handler.get_data( + positive_mode, + batch_size=config["batch_size"], + features_length=config["spectrogram_length"], + truncation_strategy="none", + ) + ambient_tracks, _, _ = handler.get_data( + ambient_mode, + batch_size=config["batch_size"], + features_length=config["spectrogram_length"], + truncation_strategy="none", + ) + positives = [ + np.asarray(track) + for track, label in zip(positive_tracks, labels) + if bool(label) + ] + ambient = [np.asarray(track) for track in ambient_tracks] + if positives and ambient: + return positive_mode, ambient_mode, positives, ambient + raise RuntimeError( + "No suitable validation/testing data was found for detector calibration." + ) + + +def _predict_tracks( + model: Model, + tracks: Sequence[np.ndarray], + label: str, +) -> list[np.ndarray]: + predictions: list[np.ndarray] = [] + total = len(tracks) + print(f"→ Running streaming inference on {total} {label} track(s)") + for index, track in enumerate(tracks, start=1): + values = np.asarray(model.predict_spectrogram(track), dtype=np.float32) + predictions.append(values) + if index == total or index % 25 == 0: + print(f" {label}: {index}/{total}") + return predictions + + +def main() -> int: + args = parse_args() + window_sizes = _parse_window_sizes(args.window_sizes) + if args.cutoff_step <= 0: + raise ValueError("cutoff-step must be > 0") + if args.cutoff_max < args.cutoff_min: + raise ValueError("cutoff-max must be >= cutoff-min") + + config_path = Path(args.training_config) + model_path = Path(args.model) + output_path = Path(args.output) + + if not config_path.exists(): + raise FileNotFoundError(f"Training config not found: {config_path}") + if not model_path.exists(): + raise FileNotFoundError(f"Streaming TFLite model not found: {model_path}") + + cutoffs = np.arange( + args.cutoff_min, + args.cutoff_max + (args.cutoff_step / 2.0), + args.cutoff_step, + dtype=np.float32, + ) + cutoffs = np.clip(cutoffs, 0.0, 1.0) + cutoffs = np.unique(np.round(cutoffs, 4)) + + print("===== Detector Calibration =====") + print(f"→ Model: {model_path}") + print(f"→ Training config: {config_path}") + print( + f"→ Evaluating window sizes {window_sizes} with target <= " + f"{args.target_faph:.2f} false accepts/hour" + ) + + config = _load_config(config_path) + config["flags"] = config.get("flags", {}) + handler = FeatureHandler(config) + + positive_mode, ambient_mode, positive_tracks, ambient_tracks = _load_eval_sets( + handler, config + ) + + print( + f"→ Using {positive_mode} positives ({len(positive_tracks)}) and " + f"{ambient_mode} ambient tracks ({len(ambient_tracks)})" + ) + + model = Model(str(model_path), stride=config["stride"]) + positive_predictions = _predict_tracks(model, positive_tracks, "positive") + ambient_predictions = _predict_tracks(model, ambient_tracks, "ambient") + + candidates: list[dict[str, float]] = [] + best_by_window: list[dict[str, float]] = [] + step_seconds = config["window_step_ms"] / 1000.0 + + for window_size in window_sizes: + ambient_averages = [ + _moving_average(track, window_size) for track in ambient_predictions + ] + positive_maxima = [] + for track in positive_predictions: + search = ( + track[args.positive_skip_slices :] + if track.size > args.positive_skip_slices + else track + ) + averaged = _moving_average(search, window_size) + if averaged.size == 0: + averaged = _moving_average(track, window_size) + positive_maxima.append(float(np.max(averaged)) if averaged.size else 0.0) + + positive_maxima_array = np.asarray(positive_maxima, dtype=np.float32) + recall_by_cutoff = np.mean( + positive_maxima_array[None, :] > cutoffs[:, None], axis=1 + ) + faph_by_cutoff, ambient_hours = _compute_false_accepts_per_hour( + ambient_averages, + cutoffs, + args.cooldown_slices, + stride=config["stride"], + step_seconds=step_seconds, + ) + + window_candidates = [] + for cutoff, recall, faph in zip(cutoffs, recall_by_cutoff, faph_by_cutoff): + candidate = { + "probability_cutoff": float(round(float(cutoff), 2)), + "sliding_window_size": int(window_size), + "recall": float(recall), + "false_accepts_per_hour": float(faph), + "ambient_hours": float(ambient_hours), + } + candidates.append(candidate) + window_candidates.append(candidate) + + best_window, _ = _select_best_candidate(window_candidates, args.target_faph) + best_by_window.append(best_window) + print( + " window={window}: cutoff={cutoff:.2f}; recall={recall:.2%}; " + "ambient_faph={faph:.3f}".format( + window=window_size, + cutoff=best_window["probability_cutoff"], + recall=best_window["recall"], + faph=best_window["false_accepts_per_hour"], + ) + ) + + best, selected_limit = _select_best_candidate(candidates, args.target_faph) + if best["false_accepts_per_hour"] > args.target_faph + 1e-9: + print( + "⚠️ No candidate met the target false accepts/hour budget; " + "using the best fallback operating point." + ) + + print( + "✓ Selected cutoff={cutoff:.2f}, window={window}, recall={recall:.2%}, " + "ambient_faph={faph:.3f}".format( + cutoff=best["probability_cutoff"], + window=best["sliding_window_size"], + recall=best["recall"], + faph=best["false_accepts_per_hour"], + ) + ) + + output = { + "probability_cutoff": best["probability_cutoff"], + "sliding_window_size": best["sliding_window_size"], + "target_false_accepts_per_hour": float(args.target_faph), + "selected_false_accepts_per_hour_limit": ( + None if math.isinf(selected_limit) else float(selected_limit) + ), + "selected_metrics": { + "recall": round(best["recall"], 6), + "false_accepts_per_hour": round(best["false_accepts_per_hour"], 6), + "ambient_hours": round(best["ambient_hours"], 6), + }, + "evaluation": { + "positive_dataset": positive_mode, + "ambient_dataset": ambient_mode, + "positive_tracks": len(positive_tracks), + "ambient_tracks": len(ambient_tracks), + "cooldown_slices": int(args.cooldown_slices), + "positive_skip_slices": int(args.positive_skip_slices), + "window_sizes": window_sizes, + "cutoff_min": round(float(cutoffs[0]), 4), + "cutoff_max": round(float(cutoffs[-1]), 4), + "cutoff_step": float(args.cutoff_step), + }, + "per_window_best": best_by_window, + "generated_at": datetime.now(timezone.utc).isoformat(), + } + + output_path.parent.mkdir(parents=True, exist_ok=True) + output_path.write_text(json.dumps(output, indent=2) + "\n", encoding="utf-8") + print(f"📝 Wrote calibration to {output_path}") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/cli/setup_audioset b/cli/setup_audioset index 00c62a1..5af3a02 100755 --- a/cli/setup_audioset +++ b/cli/setup_audioset @@ -130,6 +130,73 @@ print(f" AudioSet complete ({ok} ok, {skipped} skipped, {len(audioset_bad)} fa EOF } +converter_from_dataset_api() { + # shellcheck source=/dev/null + source "${DATA_DIR}/.venv/bin/activate" + + python - "${AUDIO16K_DIR}" <<-'EOF' +import sys +from pathlib import Path + +import librosa +import numpy as np +import scipy.io.wavfile +from datasets import load_dataset + +def write_wav(dst: Path, data: np.ndarray, sr: int): + dst.parent.mkdir(parents=True, exist_ok=True) + x = np.clip(data, -1.0, 1.0) + scipy.io.wavfile.write(dst, sr, (x * 32767).astype(np.int16)) + +audioset_out = Path(sys.argv[1]) + +print(" AudioSet FLAC tarballs are unavailable; using Hugging Face datasets API instead.") +dataset = load_dataset( + "agkphysics/AudioSet", + "balanced", + split="train", + streaming=True, +) + +audioset_bad = [] +ok = 0 +skipped = 0 +heartbeat_every = 250 + +for idx, sample in enumerate(dataset, start=1): + try: + video_id = str(sample.get("video_id") or f"audioset_{idx:06d}") + outfile = audioset_out / f"{video_id}.wav" + if outfile.exists(): + skipped += 1 + continue + + audio = sample.get("audio") or {} + y = np.asarray(audio.get("array")) + sr = int(audio.get("sampling_rate") or 0) + if y.size == 0 or sr <= 0: + raise ValueError("missing decoded audio") + if y.ndim > 1: + y = np.mean(y, axis=-1) + if sr != 16000: + y = librosa.resample(y.astype(np.float32), orig_sr=sr, target_sr=16000) + if y.size == 0: + raise ValueError("empty audio") + write_wav(outfile, y, 16000) + ok += 1 + except Exception as exc: + audioset_bad.append(f"{sample.get('video_id', idx)}:{exc}") + + if idx == 1 or (idx % heartbeat_every) == 0: + print(f" AudioSet API progress: {idx} clips processed (ok={ok}, skipped={skipped}, failed={len(audioset_bad)})") + +if audioset_bad: + (audioset_out / "audioset_corrupted_files.log").write_text("\n".join(audioset_bad)) + +print(f" AudioSet complete via datasets API ({ok} ok, {skipped} skipped, {len(audioset_bad)} failed)") +EOF +} + expected_filecount=$(get_total_filecount filecounts) actual_filecount=$(find "${AUDIO16K_DIR}" -name "*.wav" 2>/dev/null | wc -l) || : write_filecount=false @@ -139,40 +206,44 @@ if [ "${actual_filecount}" -ne 0 ] ; then echo " Existing ${AUDIO16K_DIR} present (${actual_filecount} wav); skipping extract/convert" else dl=$(find_rev) - [ -n "$dl" ] || { echo " Could not locate an AudioSet revision with FLAC tarballs still present on HF." ; exit 1 ; } - rev=${dl%%,*} - pattern=${dl##*,} + if [ -z "$dl" ] ; then + rm -rf "${AUDIO16K_DIR}/audioset_corrupted_files.log" || : + converter_from_dataset_api + else + rev=${dl%%,*} + pattern=${dl##*,} - echo " Checking 10 tarballs" - for i in {0..9} ; do - fname="downloads/bal_train0${i}.tar" - if [ ! -f "${fname}" ] ; then - echo " Downloading bal_train0${i}.tar" - url="${AUDIO_URL}/${rev}/${pattern}${i}.tar" - curl -L -s --fail "${url}" -o "${fname}" || { echo "Could not fetch ${fname} at rev ${rev}; continuing." ; continue ; } + echo " Checking 10 tarballs" + for i in {0..9} ; do + fname="downloads/bal_train0${i}.tar" + if [ ! -f "${fname}" ] ; then + echo " Downloading bal_train0${i}.tar" + url="${AUDIO_URL}/${rev}/${pattern}${i}.tar" + curl -L -s --fail "${url}" -o "${fname}" || { echo "Could not fetch ${fname} at rev ${rev}; continuing." ; continue ; } + fi + + tarball_filecount=$(tar -tvf "${fname}" | wc -l ) + filecounts["bal_train0${i}.tar"]=${tarball_filecount} + write_filecount=true + + echo " Untarring bal_train0${i}.tar" + tar -xf "${fname}" -C "${AUDIO_DIR}" + if "${CLEANUP_ARCHIVES}" && [ -f "${fname}" ] ; then + echo " Cleaning up bal_train0${i}.tar" + rm -rf "${fname}" + fi + done + + rm -rf "${AUDIO16K_DIR}/audioset_corrupted_files.log" || : + converter + + # Recompute counts and warn (but do not fail) + expected_filecount=$(get_total_filecount filecounts) + actual_filecount=$(find "${AUDIO16K_DIR}" -name "*.wav" 2>/dev/null | wc -l) || : + if [ "${actual_filecount}" -ne "${expected_filecount}" ] ; then + echo " Converted file count(${actual_filecount}) != expected file count(${expected_filecount})" >&2 + echo " WARNING: mismatch is expected if some AudioSet files are corrupted; continuing." >&2 fi - - tarball_filecount=$(tar -tvf "${fname}" | wc -l ) - filecounts["bal_train0${i}.tar"]=${tarball_filecount} - write_filecount=true - - echo " Untarring bal_train0${i}.tar" - tar -xf "${fname}" -C "${AUDIO_DIR}" - if "${CLEANUP_ARCHIVES}" && [ -f "${fname}" ] ; then - echo " Cleaning up bal_train0${i}.tar" - rm -rf "${fname}" - fi - done - - rm -rf "${AUDIO16K_DIR}/audioset_corrupted_files.log" || : - converter - - # Recompute counts and warn (but do not fail) - expected_filecount=$(get_total_filecount filecounts) - actual_filecount=$(find "${AUDIO16K_DIR}" -name "*.wav" 2>/dev/null | wc -l) || : - if [ "${actual_filecount}" -ne "${expected_filecount}" ] ; then - echo " Converted file count(${actual_filecount}) != expected file count(${expected_filecount})" >&2 - echo " WARNING: mismatch is expected if some AudioSet files are corrupted; continuing." >&2 fi fi @@ -196,4 +267,4 @@ if "${CLEANUP_INTERMEDIATE_FILES}" && [ -d "${AUDIO_DIR}" ] ; then fi echo " Audioset complete" -exit 0 \ No newline at end of file +exit 0 diff --git a/cli/setup_fma b/cli/setup_fma index fe7f090..f9bbca5 100755 --- a/cli/setup_fma +++ b/cli/setup_fma @@ -27,8 +27,11 @@ cd "${DATA_DIR}/training_datasets" echo "***** Checking FMA *****" -AUDIO_URL="https://huggingface.co/datasets/mchl914/fma_xsmall/resolve/main/fma_xs.zip" -AUDIO_ZIPFILE="fma_xs.zip" +AUDIO_URLS=( + "https://os.unil.cloud.switch.ch/fma/fma_small.zip" + "https://huggingface.co/datasets/mchl914/fma_xsmall/resolve/main/fma_xs.zip" +) +AUDIO_ZIPFILE="fma_small.zip" AUDIO_ZIP="./downloads/${AUDIO_ZIPFILE}" AUDIO_DIR="fma" mkdir -p "${AUDIO_DIR}" || : @@ -81,6 +84,52 @@ EOF } +extract_zip_with_python() { + local zip_path="$1" + local dest_dir="$2" + + "${DATA_DIR}/.venv/bin/python" - "${zip_path}" "${dest_dir}" <<-'EOF' +import sys +import zipfile +from pathlib import Path +from tqdm import tqdm + +zip_path = Path(sys.argv[1]) +dest_dir = Path(sys.argv[2]) + +if (not zip_path.exists()) or zip_path.stat().st_size == 0: + raise SystemExit(f"Archive missing or empty: {zip_path}") + +with zipfile.ZipFile(zip_path, "r") as zf: + members = zf.infolist() + size_gb = zip_path.stat().st_size / (1024 ** 3) + print(f" Extracting {zip_path.name} ({len(members)} entries, {size_gb:.1f} GiB)...") + for member in tqdm(members, desc=" FMA zip extract", unit="file"): + zf.extract(member, dest_dir) +EOF +} + +download_with_fallbacks() { + local output="$1" + shift + local urls=( "$@" ) + local rc=1 + + for url in "${urls[@]}" ; do + for attempt in 1 2 3 4 ; do + curl -sfL "${url}" -o "${output}" && [ -s "${output}" ] && return 0 + rc=$? + rm -f "${output}" || : + if [ "${attempt}" -lt 4 ] ; then + echo " Retry ${attempt}/3 after download failure" + sleep $(( attempt * 2 )) + fi + done + done + + return "${rc}" +} + expected_filecount=${filecounts[${AUDIO_ZIPFILE}]} actual_filecount=$(find ${AUDIO16K_DIR} -name '*.wav' 2>/dev/null | wc -l) || : write_filecount=false @@ -92,13 +141,16 @@ else if [ "${actual_filecount}" -eq 0 ] || [ "${actual_filecount}" -ne "${expected_filecount}" ] ; then if [ ! -f "${AUDIO_ZIP}" ] ; then echo " Downloading ${AUDIO_ZIPFILE}" - curl -sfL "${AUDIO_URL}" -o "${AUDIO_ZIP}" + download_with_fallbacks "${AUDIO_ZIP}" "${AUDIO_URLS[@]}" || { + echo " Failed to download ${AUDIO_ZIPFILE} from all configured sources." >&2 + exit 1 + } fi rm -rf "${AUDIO_DIR}" || : mkdir "${AUDIO_DIR}" - echo " Unzipping ${AUDIO_ZIPFILE}" - unzip -q -d "${AUDIO_DIR}" "${AUDIO_ZIP}" + echo " Extracting ${AUDIO_ZIPFILE}" + extract_zip_with_python "${AUDIO_ZIP}" "${AUDIO_DIR}" fi if "${CLEANUP_ARCHIVES}" && [ -f "${AUDIO_ZIP}" ] ; then echo " Cleaning up ${AUDIO_ZIPFILE}" @@ -128,4 +180,3 @@ fi echo " FMA complete" exit 0 - diff --git a/cli/setup_python_venv b/cli/setup_python_venv index 5ea750a..1122287 100755 --- a/cli/setup_python_venv +++ b/cli/setup_python_venv @@ -242,29 +242,7 @@ if [ ! -s "${MODEL_FILE}.json" ] ; then curl -sfL "${MODEL_URL}.json" -o "${MODEL_FILE}.json" fi -# --- Dutch ONNX voices (single-speaker, used with --language=nl) --- -# Working Dutch voices: pim, ronnie (nl_NL) and nathalie (nl_BE). -# nl_NL-mls-medium is intentionally excluded (known Piper issue: outputs gibberish). -HF_VOICES="https://huggingface.co/rhasspy/piper-voices/resolve/main" -declare -a NL_VOICES=( - "nl/nl_NL/pim/medium/nl_NL-pim-medium" - "nl/nl_NL/ronnie/medium/nl_NL-ronnie-medium" - "nl/nl_BE/nathalie/medium/nl_BE-nathalie-medium" -) -echo " ===== Checking Dutch Piper voices =====" -for voice_path in "${NL_VOICES[@]}" ; do - voice_name="$(basename "${voice_path}")" - onnx_file="${VOICES_DIR}/${voice_name}.onnx" - json_file="${VOICES_DIR}/${voice_name}.onnx.json" - if [ ! -f "${onnx_file}" ] ; then - echo " Downloading ${voice_name}.onnx" - curl -sfL "${HF_VOICES}/${voice_path}.onnx?download=true" -o "${onnx_file}" - fi - if [ ! -f "${json_file}" ] ; then - echo " Downloading ${voice_name}.onnx.json" - curl -sfL "${HF_VOICES}/${voice_path}.onnx.json?download=true" -o "${json_file}" - fi -done +echo " Non-English Piper voices will be downloaded on demand for the selected language." ${GPU} && onnxgpu='-gpu[cuda]' || onnxgpu="" echo " ===== Installing onnxruntime${onnxgpu} =====" diff --git a/cli/wake_word_sample_trainer b/cli/wake_word_sample_trainer index 0ce01d8..f4fcda2 100644 --- a/cli/wake_word_sample_trainer +++ b/cli/wake_word_sample_trainer @@ -317,6 +317,7 @@ fi TRAINING_DONE="false" +echo "🏋️ Starting model training and TFLite export (this is the longest stage)…" if run_attempt "Attempt 1/3: GPU training (default runtime profile)" ; then echo "✅ Training complete (GPU path)." TRAINING_DONE="true" @@ -386,12 +387,24 @@ if [ "${TRAINING_DONE}" != "true" ]; then fi source_path="${WORK_DIR}/trained_models/wakeword/tflite_stream_state_internal_quant/stream_state_internal_quant.tflite" +calibration_path="${WORK_DIR}/trained_models/wakeword/tflite_stream_state_internal_quant/detection_calibration.json" if [ ! -f "${source_path}" ] ; then echo "Output model not found! Training didn't complete successfully. See ${TRAIN_LOG}" exit 1 fi +echo "🎯 Calibrating detector settings for on-device use…" +if "${PYTHON_BIN:-python}" "${PROGDIR}/calibrate_detector.py" \ + --training-config "${WORK_DIR}/trained_models/wakeword/training_config.yaml" \ + --model "${source_path}" \ + --output "${calibration_path}"; then + echo "✅ Detector calibration complete." +else + echo "⚠️ Detector calibration failed; packaging with default detector settings." + rm -f "${calibration_path}" || : +fi + cp "${WORK_DIR}/trained_models/wakeword/model_summary.txt" "${OUTPUT_DIR}/logs/" || : cp -a "${WORK_DIR}/trained_models/wakeword/logs/train" "${OUTPUT_DIR}/logs/" || : cp -a "${WORK_DIR}/trained_models/wakeword/logs/validation" "${OUTPUT_DIR}/logs/" || : @@ -404,24 +417,49 @@ tflite_path="${OUTPUT_DIR}/${tflite_filename}" cp "${source_path}" "${tflite_path}" json_path="${OUTPUT_DIR}/${wake_word_filename}.json" -cat <<-EOF > "${json_path}" -{ +export WAKE_WORD_TITLE LANGUAGE JSON_PATH="${json_path}" TFLITE_FILENAME="${tflite_filename}" CALIBRATION_PATH="${calibration_path}" +echo "📦 Packaging final model artifacts…" +"${PYTHON_BIN:-python}" - <<'PY' +import json +import os +from pathlib import Path + +json_path = Path(os.environ["JSON_PATH"]) +calibration_path = Path(os.environ.get("CALIBRATION_PATH", "")) +language = (os.environ.get("LANGUAGE", "en") or "en").strip().lower() +probability_cutoff = 0.97 +sliding_window_size = 5 + +if calibration_path.exists(): + try: + calibration = json.loads(calibration_path.read_text(encoding="utf-8")) + probability_cutoff = float(calibration.get("probability_cutoff", probability_cutoff)) + sliding_window_size = int(calibration.get("sliding_window_size", sliding_window_size)) + print( + f"🎯 Using calibrated detector settings: " + f"cutoff={probability_cutoff:.2f}, window={sliding_window_size}" + ) + except Exception as exc: + print(f"⚠️ Failed to read detector calibration ({exc}); using defaults.") + +meta = { "type": "micro", - "wake_word": "${WAKE_WORD_TITLE}", + "wake_word": os.environ["WAKE_WORD_TITLE"], "author": "Tater Totterson", "website": "https://github.com/TaterTotterson/microWakeWord-Trainer-Nvidia-Docker.git", - "model": "${tflite_filename}", - "trained_languages": ["en"], + "model": os.environ["TFLITE_FILENAME"], + "trained_languages": [language], "version": 2, "micro": { - "probability_cutoff": 0.97, - "sliding_window_size": 5, + "probability_cutoff": round(probability_cutoff, 2), + "sliding_window_size": sliding_window_size, "feature_step_size": 10, "tensor_arena_size": 30000, - "minimum_esphome_version": "2024.7.0" - } + "minimum_esphome_version": "2024.7.0", + }, } -EOF +json_path.write_text(json.dumps(meta, indent=4) + "\n", encoding="utf-8") +PY echo "Name: ${WAKE_WORD_TITLE}" echo "Model: ${tflite_path}" diff --git a/dockerfile b/dockerfile index bf3bfeb..37bb910 100644 --- a/dockerfile +++ b/dockerfile @@ -22,7 +22,7 @@ COPY --chown=root:root --chmod=0755 .bashrc /root/ # Root-level entrypoints COPY --chown=root:root --chmod=0755 \ train_wake_word \ - run_recorder.sh \ + run.sh \ trainer_server.py \ requirements.txt \ /root/mww-scripts/ @@ -37,4 +37,4 @@ RUN chmod -R a+x /root/mww-scripts/cli COPY --chown=root:root --chmod=0644 static/index.html /root/mww-scripts/static/index.html # trainer server -CMD ["/bin/bash", "-lc", "/root/mww-scripts/run_recorder.sh"] +CMD ["/bin/bash", "-lc", "/root/mww-scripts/run.sh"] diff --git a/run_recorder.sh b/run.sh similarity index 97% rename from run_recorder.sh rename to run.sh index 75d1ec9..1ef1fa8 100644 --- a/run_recorder.sh +++ b/run.sh @@ -8,7 +8,7 @@ DATA_DIR="${DATA_DIR:-/data}" HOST="${REC_HOST:-0.0.0.0}" PORT="${REC_PORT:-8888}" -# Keep recorder deps separate from training venv +# Keep trainer UI deps separate from the training venv VENV_DIR="${DATA_DIR}/.recorder-venv" PY="${VENV_DIR}/bin/python" PIP="${PY} -m pip" diff --git a/trainer_server.py b/trainer_server.py index 4e9eb70..cb69cf3 100644 --- a/trainer_server.py +++ b/trainer_server.py @@ -42,6 +42,12 @@ PIPER_VOICES_ROOT_URL = os.environ.get( "https://huggingface.co/rhasspy/piper-voices/resolve/main", ) PIPER_CATALOG_CACHE_TTL_SECONDS = int(os.environ.get("PIPER_CATALOG_CACHE_TTL_SECONDS", "900")) +PIPER_CATALOG_CACHE_FILE = Path( + os.environ.get( + "PIPER_CATALOG_CACHE_FILE", + str(ROOT_DIR / ".cache" / "piper_voices_catalog.json"), + ) +).resolve() DATASET_CLEANUP_ARCHIVES = os.environ.get("REC_DATASET_CLEANUP_ARCHIVES", "false").lower() in ("1", "true", "yes", "y") DATASET_CLEANUP_INTERMEDIATE = os.environ.get("REC_DATASET_CLEANUP_INTERMEDIATE_FILES", "false").lower() in ("1", "true", "yes", "y") @@ -177,6 +183,27 @@ def _fetch_piper_catalog() -> Optional[Dict[str, Any]]: return data if isinstance(data, dict) else None +def _read_cached_piper_catalog_file() -> Optional[Dict[str, Any]]: + try: + if not PIPER_CATALOG_CACHE_FILE.exists(): + return None + data = json.loads(PIPER_CATALOG_CACHE_FILE.read_text(encoding="utf-8")) + return data if isinstance(data, dict) else None + except Exception: + return None + + +def _write_cached_piper_catalog_file(data: Dict[str, Any]): + try: + PIPER_CATALOG_CACHE_FILE.parent.mkdir(parents=True, exist_ok=True) + PIPER_CATALOG_CACHE_FILE.write_text( + json.dumps(data, ensure_ascii=True), + encoding="utf-8", + ) + except Exception: + pass + + def _load_piper_catalog() -> Optional[Dict[str, Any]]: now = time.time() with PIPER_CATALOG_LOCK: @@ -185,6 +212,8 @@ def _load_piper_catalog() -> Optional[Dict[str, Any]]: if cached is not None and (now - fetched_at) < PIPER_CATALOG_CACHE_TTL_SECONDS: return cached + disk_cached = _read_cached_piper_catalog_file() + try: fresh = _fetch_piper_catalog() except Exception: @@ -194,9 +223,15 @@ def _load_piper_catalog() -> Optional[Dict[str, Any]]: if fresh is not None: PIPER_CATALOG_CACHE["entries"] = fresh PIPER_CATALOG_CACHE["fetched_at"] = now + _write_cached_piper_catalog_file(fresh) return fresh - if PIPER_CATALOG_CACHE.get("entries") is None: - PIPER_CATALOG_CACHE["entries"] = {} + if PIPER_CATALOG_CACHE.get("entries") is not None: + return PIPER_CATALOG_CACHE.get("entries") + if disk_cached is not None: + PIPER_CATALOG_CACHE["entries"] = disk_cached + PIPER_CATALOG_CACHE["fetched_at"] = now + return disk_cached + PIPER_CATALOG_CACHE["entries"] = {} PIPER_CATALOG_CACHE["fetched_at"] = now return PIPER_CATALOG_CACHE.get("entries")