mirror of
https://github.com/TaterTotterson/microWakeWord-Trainer-Nvidia-Docker.git
synced 2026-06-12 20:10:19 -06:00
Compare commits
7 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| | 51cbf6fd90 | | |
| | 6e7396455a | | |
| | 2da9f7a686 | | |
| | 7b028e4420 | | |
| | b3d9f0e369 | | |
| | 240ca7682e | | |
| | 18e5fcd000 | | |
251
README.md
@@ -1,25 +1,20 @@
|
|||||||
<div align="center">
|
<div align="center">
|
||||||
<h1>🎙️ microWakeWord Nvidia Trainer & Recorder</h1>
|
<h1>microWakeWord NVIDIA Docker Trainer UI</h1>
|
||||||
<img width="1002" height="593" alt="Screenshot 2026-01-18 at 8 13 35 AM" src="https://github.com/user-attachments/assets/e1411d8a-8638-4df8-992b-09a46c6e5ddc" />
|
<img width="800" alt="Screenshot 2026-04-14 at 11 02 06 PM" src="https://github.com/user-attachments/assets/694f4cb7-e4d8-4e2b-80ec-b40fb41cbfff" />
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
Train **microWakeWord** detection models using a simple **web-based recorder + trainer UI**, packaged in a Docker container.
|
Train custom microWakeWord models in Docker with:
|
||||||
|
|
||||||
No Jupyter notebooks required. No manual cell execution. Just record your voice (optional) and train.
|
- uploaded personal voice samples
|
||||||
|
- automatically generated Piper TTS samples
|
||||||
|
- a browser-based trainer UI
|
||||||
|
- live training logs in a popup console
|
||||||
|
|
||||||
|
This project no longer records audio in the browser. The UI is now upload-first: users add their own audio files, the app validates or converts them, and training runs from the same page.
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
<img width="100" height="44" alt="unraid_logo_black-339076895" src="https://github.com/user-attachments/assets/87351bed-3321-4a43-924f-fecf2e4e700f" />
|
## Docker Image
|
||||||
|
|
||||||
**microWakeWord_Trainer-Nvidia** is available in the **Unraid Community Apps** store.
|
|
||||||
Install directly from the Unraid App Store with a one-click template.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
<img width="100" height="56" alt="unraid_logo_black-339076895" src="https://github.com/user-attachments/assets/bf959585-ae13-4b4d-ae62-4202a850d35a" />
|
|
||||||
|
|
||||||
|
|
||||||
### Pull the Docker Image
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
docker pull ghcr.io/tatertotterson/microwakeword:latest
|
docker pull ghcr.io/tatertotterson/microwakeword:latest
|
||||||
@@ -27,7 +22,7 @@ docker pull ghcr.io/tatertotterson/microwakeword:latest
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
### Run the Container
|
## Run The Container
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
docker run -d \
|
docker run -d \
|
||||||
@@ -37,133 +32,143 @@ docker run -d \
|
|||||||
ghcr.io/tatertotterson/microwakeword:latest
|
ghcr.io/tatertotterson/microwakeword:latest
|
||||||
```
|
```
|
||||||
|
|
||||||
**What these flags do:**
|
What these flags do:
|
||||||
- `--gpus all` → Enables GPU acceleration
|
|
||||||
- `-p 8888:8888` → Exposes the Recorder + Trainer WebUI
|
|
||||||
- `-v $(pwd):/data` → Persists all models, datasets, and cache
|
|
||||||
|
|
||||||
---
|
- `--gpus all` enables GPU acceleration
|
||||||
|
- `-p 8888:8888` exposes the trainer UI
|
||||||
|
- `-v $(pwd):/data` persists models, downloaded voices, datasets, and personal samples
|
||||||
|
|
||||||
### Open the Recorder WebUI
|
Then open:
|
||||||
|
|
||||||
Open your browser and go to:
|
|
||||||
|
|
||||||
👉 **http://localhost:8888**
|
|
||||||
|
|
||||||
You’ll see the **microWakeWord Recorder & Trainer UI**.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 🎤 Recording Voice Samples (Optional)
|
|
||||||
|
|
||||||
Personal voice recordings are **optional**.
|
|
||||||
|
|
||||||
- You may **record your own voice** for better accuracy
|
|
||||||
- Or simply **click “Train” without recording anything**
|
|
||||||
|
|
||||||
If no recordings are present, training will proceed using **synthetic TTS samples only**.
|
|
||||||
|
|
||||||
### Remote systems (important)
|
|
||||||
If you are running this on a **remote PC / server**, browser-based recording will not work unless:
|
|
||||||
- You use a **reverse proxy** (HTTPS + mic permissions), **or**
|
|
||||||
- You access the UI via **localhost** on the same machine
|
|
||||||
|
|
||||||
Training itself works fine remotely — only recording requires local microphone access.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### 🎙️ Recording Flow
|
|
||||||
|
|
||||||
1. Enter your wake word
|
|
||||||
2. Test pronunciation with **Test TTS**
|
|
||||||
3. Choose:
|
|
||||||
- Number of speakers (e.g. family members)
|
|
||||||
- Takes per speaker (default: 10)
|
|
||||||
4. Click **Begin recording**
|
|
||||||
5. Speak naturally — recording:
|
|
||||||
- Starts when you talk
|
|
||||||
- Stops automatically after silence
|
|
||||||
6. Repeat for each speaker
|
|
||||||
|
|
||||||
Files are saved automatically to:
|
|
||||||
|
|
||||||
```
|
|
||||||
personal_samples/
|
|
||||||
speaker01_take01.wav
|
|
||||||
speaker01_take02.wav
|
|
||||||
speaker02_take01.wav
|
|
||||||
...
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 🧠 Training Behavior (Important Notes)
|
|
||||||
|
|
||||||
### ⏬ First training run
|
|
||||||
The **first time you click Train**, the system will download **large training datasets** (background noise, speech corpora, etc.).
|
|
||||||
|
|
||||||
- This can take **several minutes**
|
|
||||||
- This happens **only once**
|
|
||||||
- Data is cached inside `/data`
|
|
||||||
|
|
||||||
You **will NOT need to download these again** unless you delete `/data`.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### 🔁 Re-training is safe and incremental
|
|
||||||
|
|
||||||
- You can train **multiple wake words** back-to-back
|
|
||||||
- You do **NOT** need to clear any folders between runs
|
|
||||||
- Old models are preserved in timestamped output directories
|
|
||||||
- All required cleanup and reuse logic is handled automatically
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 📦 Output Files
|
|
||||||
|
|
||||||
When training completes, you’ll get:
|
|
||||||
- `<wake_word>.tflite` – quantized streaming model
|
|
||||||
- `<wake_word>.json` – ESPHome-compatible metadata
|
|
||||||
|
|
||||||
Both are saved under:
|
|
||||||
|
|
||||||
```text
|
```text
|
||||||
/data/output/
|
http://localhost:8888
|
||||||
```
|
```
|
||||||
|
|
||||||
Each run is placed in its own timestamped folder.
|
---
|
||||||
|
|
||||||
|
## What The UI Does
|
||||||
|
|
||||||
|
- Start a wake word session
|
||||||
|
- Test TTS pronunciation
|
||||||
|
- Upload one or many personal samples
|
||||||
|
- Normalize uploads to `16 kHz / mono / 16-bit PCM WAV`
|
||||||
|
- Train with or without personal samples
|
||||||
|
- Show a popup console with live progress and logs
|
||||||
|
|
||||||
|
Personal samples are optional. If none are uploaded, the trainer can still proceed with TTS-only data after confirmation.
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## 🎤 Optional: Personal Voice Samples (Advanced)
|
## Personal Samples
|
||||||
|
|
||||||
If you record personal samples:
|
Accepted upload formats include:
|
||||||
- They are automatically augmented
|
|
||||||
- They are **up-weighted during training**
|
|
||||||
- This significantly improves real-world accuracy
|
|
||||||
|
|
||||||
No configuration required — detection is automatic.
|
- WAV
|
||||||
|
- MP3
|
||||||
|
- M4A
|
||||||
|
- FLAC
|
||||||
|
- OGG
|
||||||
|
- AAC
|
||||||
|
- OPUS
|
||||||
|
- WEBM
|
||||||
|
|
||||||
|
The backend validates or converts uploads with `ffmpeg` and stores the normalized files in:
|
||||||
|
|
||||||
|
```text
|
||||||
|
/data/personal_samples/
|
||||||
|
```
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
|
||||||
|
- starting a new session does not clear personal samples
|
||||||
|
- use the `Clear personal samples` button if you want to wipe them
|
||||||
|
- any uploaded personal samples are automatically included in training
|
||||||
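The normalization step described above (any accepted format in, `16 kHz / mono / 16-bit PCM WAV` out) can be sketched as an `ffmpeg` command line. This is only an illustration: the function name `build_normalize_command` and the example file name are hypothetical, and the trainer's actual `ffmpeg` invocation is not shown in this diff. The flags themselves (`-ac`, `-ar`, `-c:a pcm_s16le`) are standard `ffmpeg` options.

```python
from pathlib import Path


def build_normalize_command(src: Path, dst_dir: Path) -> list[str]:
    """Return an ffmpeg argv converting src to 16 kHz mono 16-bit PCM WAV.

    Hypothetical helper for illustration; not the project's own code.
    """
    dst = dst_dir / (src.stem + ".wav")
    return [
        "ffmpeg",
        "-y",                  # overwrite an existing output file
        "-i", str(src),        # input in any format ffmpeg can decode
        "-ac", "1",            # downmix to a single (mono) channel
        "-ar", "16000",        # resample to 16 kHz
        "-c:a", "pcm_s16le",   # encode as 16-bit little-endian PCM
        str(dst),
    ]


if __name__ == "__main__":
    cmd = build_normalize_command(
        Path("my_wake_word.mp3"), Path("/data/personal_samples")
    )
    print(" ".join(cmd))
    # To actually convert, pass the list to subprocess.run(cmd, check=True).
```

Building the command as a list rather than a shell string avoids quoting problems with uploaded file names that contain spaces.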
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## 🔄 Resetting Everything (Optional)
|
## Language Support
|
||||||
|
|
||||||
If you want a **completely clean slate**:
|
The language selector is dynamic.
|
||||||
|
|
||||||
Delete the /data folder
|
- `en` is always available
|
||||||
|
- non-English languages are populated from Piper voice metadata
|
||||||
|
- when you train with a non-English language, the backend downloads all Piper ONNX voices for that selected language only
|
||||||
|
- it does not pre-download every language
|
||||||
|
- already-downloaded voices are reused on later runs
|
||||||
|
|
||||||
Then restart the container.
|
English stays on its existing dedicated generator model path. Non-English languages use the selected language's ONNX Piper voices.
|
||||||
|
|
||||||
⚠️ This will:
|
If the Piper catalog is unavailable, already-installed local voices can still be used.
|
||||||
- Remove cached datasets
|
|
||||||
- Require re-downloading training data
|
|
||||||
- Delete trained models
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## 🙌 Credits
|
## Training Behavior
|
||||||
|
|
||||||
Built on top of the excellent
|
1. Enter the wake word
|
||||||
**https://github.com/kahrendt/microWakeWord**
|
2. Optionally test pronunciation
|
||||||
|
3. Optionally upload personal samples
|
||||||
|
4. Click `Start training`
|
||||||
|
5. Watch the popup console for:
|
||||||
|
- selected-language voice downloads when needed
|
||||||
|
- sample generation progress
|
||||||
|
- dataset setup
|
||||||
|
- training progress and completion
|
||||||
|
|
||||||
Huge thanks to the original authors ❤️
|
The `Open console` button lets you reopen the log window after closing it.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## First Run Notes
|
||||||
|
|
||||||
|
The first real training run may download large training assets into `/data`, such as:
|
||||||
|
|
||||||
|
- Piper voices for the selected language
|
||||||
|
- training datasets and background data
|
||||||
|
- Python training environment dependencies
|
||||||
|
|
||||||
|
These are reused later unless you delete `/data`.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Output Files
|
||||||
|
|
||||||
|
Successful runs produce:
|
||||||
|
|
||||||
|
```text
|
||||||
|
/data/output/<wake_word>.tflite
|
||||||
|
/data/output/<wake_word>.json
|
||||||
|
```
|
||||||
|
|
||||||
|
If those files already exist, the trainer creates timestamped backups before replacing them.
|
||||||
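A minimal sketch of the backup-before-replace behavior mentioned above, assuming a timestamp is spliced into the file name. The helper name `backup_if_exists` and the `<stem>.<timestamp><suffix>` naming scheme are assumptions made for illustration; the README does not say how the trainer names its backups.

```python
import shutil
import tempfile
from datetime import datetime
from pathlib import Path
from typing import Optional


def backup_if_exists(path: Path) -> Optional[Path]:
    """Copy path to a timestamped sibling file; return None if nothing to back up."""
    if not path.exists():
        return None
    stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    # Hypothetical naming: hey_tater.tflite -> hey_tater.<stamp>.tflite
    backup = path.with_name(f"{path.stem}.{stamp}{path.suffix}")
    shutil.copy2(path, backup)
    return backup


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        model = Path(tmp) / "hey_tater.tflite"
        model.write_bytes(b"previous model")
        print(backup_if_exists(model))
```

Calling such a helper on both the `.tflite` and `.json` outputs before writing new ones would keep earlier models recoverable.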
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Resetting Everything
|
||||||
|
|
||||||
|
If you want a clean slate, stop the container and remove the contents of your mounted `/data` directory.
|
||||||
|
|
||||||
|
That will remove:
|
||||||
|
|
||||||
|
- personal samples
|
||||||
|
- downloaded Piper voices
|
||||||
|
- cached datasets
|
||||||
|
- training environments
|
||||||
|
- trained models
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Notes
|
||||||
|
|
||||||
|
- browser microphone recording has been removed
|
||||||
|
- personal samples are optional
|
||||||
|
- the server module is now `trainer_server.py`
|
||||||
|
- the launcher script is now `run.sh`
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Credits
|
||||||
|
|
||||||
|
Built on top of:
|
||||||
|
|
||||||
|
- [microWakeWord](https://github.com/kahrendt/microWakeWord)
|
||||||
|
- [piper-sample-generator](https://github.com/rhasspy/piper-sample-generator)
|
||||||
|
|||||||
405
cli/calibrate_detector.py
Normal file
@@ -0,0 +1,405 @@
#!/usr/bin/env python3
"""Choose detector metadata that better matches the trained model."""

from __future__ import annotations

import argparse
import json
import math
import os
from datetime import datetime, timezone
from pathlib import Path
from typing import Iterable, Sequence

import numpy as np
import yaml

from microwakeword.data import FeatureHandler
from microwakeword.inference import Model


DEFAULT_WINDOW_SIZES = [3, 4, 5, 6, 7]
DEFAULT_TARGET_FAPH = float(os.environ.get("MWW_CALIBRATION_TARGET_FAPH", "1.0"))
DEFAULT_COOLDOWN_SLICES = int(os.environ.get("MWW_CALIBRATION_COOLDOWN_SLICES", "25"))
DEFAULT_POSITIVE_SKIP_SLICES = int(
    os.environ.get("MWW_CALIBRATION_POSITIVE_SKIP_SLICES", "25")
)
DEFAULT_CUTOFF_STEP = float(os.environ.get("MWW_CALIBRATION_CUTOFF_STEP", "0.01"))
DEFAULT_CUTOFF_MIN = float(os.environ.get("MWW_CALIBRATION_CUTOFF_MIN", "0.00"))
DEFAULT_CUTOFF_MAX = float(os.environ.get("MWW_CALIBRATION_CUTOFF_MAX", "1.00"))


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Calibrate microWakeWord detector metadata from validation data."
    )
    parser.add_argument(
        "--training-config",
        default="trained_models/wakeword/training_config.yaml",
        help="Path to the saved microWakeWord training_config.yaml file.",
    )
    parser.add_argument(
        "--model",
        default=(
            "trained_models/wakeword/tflite_stream_state_internal_quant/"
            "stream_state_internal_quant.tflite"
        ),
        help="Path to the quantized streaming TFLite model.",
    )
    parser.add_argument(
        "--output",
        default=(
            "trained_models/wakeword/tflite_stream_state_internal_quant/"
            "detection_calibration.json"
        ),
        help="Where to write the selected detector settings as JSON.",
    )
    parser.add_argument(
        "--window-sizes",
        default=",".join(str(value) for value in DEFAULT_WINDOW_SIZES),
        help="Comma-separated sliding window sizes to evaluate.",
    )
    parser.add_argument(
        "--target-faph",
        type=float,
        default=DEFAULT_TARGET_FAPH,
        help="Target ambient false accepts per hour for the selected operating point.",
    )
    parser.add_argument(
        "--cooldown-slices",
        type=int,
        default=DEFAULT_COOLDOWN_SLICES,
        help="Cooldown slices to use when estimating false accepts per hour.",
    )
    parser.add_argument(
        "--positive-skip-slices",
        type=int,
        default=DEFAULT_POSITIVE_SKIP_SLICES,
        help="Initial streaming slices to ignore when scoring positive examples.",
    )
    parser.add_argument(
        "--cutoff-step",
        type=float,
        default=DEFAULT_CUTOFF_STEP,
        help="Cutoff increment to evaluate between cutoff-min and cutoff-max.",
    )
    parser.add_argument(
        "--cutoff-min",
        type=float,
        default=DEFAULT_CUTOFF_MIN,
        help="Minimum cutoff to evaluate.",
    )
    parser.add_argument(
        "--cutoff-max",
        type=float,
        default=DEFAULT_CUTOFF_MAX,
        help="Maximum cutoff to evaluate.",
    )
    return parser.parse_args()


def _parse_window_sizes(raw: str) -> list[int]:
    values = []
    for item in (raw or "").split(","):
        item = item.strip()
        if not item:
            continue
        value = int(item)
        if value < 1:
            raise ValueError("window sizes must be >= 1")
        values.append(value)
    if not values:
        raise ValueError("at least one window size is required")
    return sorted(set(values))


def _moving_average(values: Sequence[float], window_size: int) -> np.ndarray:
    array = np.asarray(values, dtype=np.float32)
    if array.size == 0:
        return array
    if window_size <= 1:
        return array
    if array.size < window_size:
        return np.asarray([float(array.mean())], dtype=np.float32)
    cumsum = np.cumsum(np.insert(array, 0, 0.0))
    averaged = (cumsum[window_size:] - cumsum[:-window_size]) / float(window_size)
    return averaged.astype(np.float32)


def _compute_false_accepts_per_hour(
    probabilities_per_track: Iterable[np.ndarray],
    cutoffs: np.ndarray,
    cooldown_slices: int,
    stride: int,
    step_seconds: float,
) -> tuple[np.ndarray, float]:
    cutoffs = np.asarray(cutoffs, dtype=np.float32)
    false_accepts = np.zeros(cutoffs.shape[0], dtype=np.float64)
    duration_hours = 0.0

    for track_probabilities in probabilities_per_track:
        if track_probabilities.size == 0:
            continue
        duration_hours += (
            len(track_probabilities) * stride * step_seconds / 3600.0
        )
        cooldown = np.full(cutoffs.shape[0], cooldown_slices, dtype=np.int32)
        for probability in track_probabilities:
            cooldown = np.maximum(cooldown - 1, 0)
            accepted = (cooldown == 0) & (probability > cutoffs)
            false_accepts += accepted.astype(np.float64)
            cooldown = np.where(accepted, cooldown_slices, cooldown)

    if duration_hours <= 0:
        return np.full(cutoffs.shape[0], math.inf, dtype=np.float64), 0.0

    return false_accepts / duration_hours, duration_hours


def _select_best_candidate(
    candidates: list[dict[str, float]],
    target_faph: float,
) -> tuple[dict[str, float], float]:
    fallback_limits = [
        target_faph,
        max(target_faph * 2.0, target_faph + 0.5),
        max(target_faph * 4.0, 2.0),
    ]

    def tier(candidate: dict[str, float]) -> int:
        for index, limit in enumerate(fallback_limits):
            if candidate["false_accepts_per_hour"] <= limit + 1e-9:
                return index
        return len(fallback_limits)

    best = min(
        candidates,
        key=lambda candidate: (
            tier(candidate),
            -candidate["recall"],
            candidate["false_accepts_per_hour"],
            abs(candidate["sliding_window_size"] - 5),
            -candidate["probability_cutoff"],
        ),
    )

    tier_index = tier(best)
    if tier_index < len(fallback_limits):
        return best, fallback_limits[tier_index]
    return best, float("inf")


def _load_config(config_path: Path) -> dict:
    with config_path.open("r", encoding="utf-8") as handle:
        return yaml.load(handle.read(), Loader=yaml.Loader)


def _load_eval_sets(
    handler: FeatureHandler,
    config: dict,
) -> tuple[str, str, list[np.ndarray], list[np.ndarray]]:
    for positive_mode, ambient_mode in (
        ("validation", "validation_ambient"),
        ("testing", "testing_ambient"),
    ):
        positive_tracks, labels, _ = handler.get_data(
            positive_mode,
            batch_size=config["batch_size"],
            features_length=config["spectrogram_length"],
            truncation_strategy="none",
        )
        ambient_tracks, _, _ = handler.get_data(
            ambient_mode,
            batch_size=config["batch_size"],
            features_length=config["spectrogram_length"],
            truncation_strategy="none",
        )
        positives = [
            np.asarray(track)
            for track, label in zip(positive_tracks, labels)
            if bool(label)
        ]
        ambient = [np.asarray(track) for track in ambient_tracks]
        if positives and ambient:
            return positive_mode, ambient_mode, positives, ambient
    raise RuntimeError(
        "No suitable validation/testing data was found for detector calibration."
    )


def _predict_tracks(
    model: Model,
    tracks: Sequence[np.ndarray],
    label: str,
) -> list[np.ndarray]:
    predictions: list[np.ndarray] = []
    total = len(tracks)
    print(f"→ Running streaming inference on {total} {label} track(s)")
    for index, track in enumerate(tracks, start=1):
        values = np.asarray(model.predict_spectrogram(track), dtype=np.float32)
        predictions.append(values)
        if index == total or index % 25 == 0:
            print(f" {label}: {index}/{total}")
    return predictions


def main() -> int:
    args = parse_args()
    window_sizes = _parse_window_sizes(args.window_sizes)
    if args.cutoff_step <= 0:
        raise ValueError("cutoff-step must be > 0")
    if args.cutoff_max < args.cutoff_min:
        raise ValueError("cutoff-max must be >= cutoff-min")

    config_path = Path(args.training_config)
    model_path = Path(args.model)
    output_path = Path(args.output)

    if not config_path.exists():
        raise FileNotFoundError(f"Training config not found: {config_path}")
    if not model_path.exists():
        raise FileNotFoundError(f"Streaming TFLite model not found: {model_path}")

    cutoffs = np.arange(
        args.cutoff_min,
        args.cutoff_max + (args.cutoff_step / 2.0),
        args.cutoff_step,
        dtype=np.float32,
    )
    cutoffs = np.clip(cutoffs, 0.0, 1.0)
    cutoffs = np.unique(np.round(cutoffs, 4))

    print("===== Detector Calibration =====")
    print(f"→ Model: {model_path}")
    print(f"→ Training config: {config_path}")
    print(
        f"→ Evaluating window sizes {window_sizes} with target <= "
        f"{args.target_faph:.2f} false accepts/hour"
    )

    config = _load_config(config_path)
    config["flags"] = config.get("flags", {})
    handler = FeatureHandler(config)

    positive_mode, ambient_mode, positive_tracks, ambient_tracks = _load_eval_sets(
        handler, config
    )

    print(
        f"→ Using {positive_mode} positives ({len(positive_tracks)}) and "
        f"{ambient_mode} ambient tracks ({len(ambient_tracks)})"
    )

    model = Model(str(model_path), stride=config["stride"])
    positive_predictions = _predict_tracks(model, positive_tracks, "positive")
    ambient_predictions = _predict_tracks(model, ambient_tracks, "ambient")

    candidates: list[dict[str, float]] = []
    best_by_window: list[dict[str, float]] = []
    step_seconds = config["window_step_ms"] / 1000.0

    for window_size in window_sizes:
        ambient_averages = [
            _moving_average(track, window_size) for track in ambient_predictions
        ]
        positive_maxima = []
        for track in positive_predictions:
            search = (
                track[args.positive_skip_slices :]
                if track.size > args.positive_skip_slices
                else track
            )
            averaged = _moving_average(search, window_size)
            if averaged.size == 0:
                averaged = _moving_average(track, window_size)
            positive_maxima.append(float(np.max(averaged)) if averaged.size else 0.0)

        positive_maxima_array = np.asarray(positive_maxima, dtype=np.float32)
        recall_by_cutoff = np.mean(
            positive_maxima_array[None, :] > cutoffs[:, None], axis=1
        )
        faph_by_cutoff, ambient_hours = _compute_false_accepts_per_hour(
            ambient_averages,
            cutoffs,
            args.cooldown_slices,
            stride=config["stride"],
            step_seconds=step_seconds,
        )

        window_candidates = []
        for cutoff, recall, faph in zip(cutoffs, recall_by_cutoff, faph_by_cutoff):
            candidate = {
                "probability_cutoff": float(round(float(cutoff), 2)),
                "sliding_window_size": int(window_size),
                "recall": float(recall),
                "false_accepts_per_hour": float(faph),
                "ambient_hours": float(ambient_hours),
            }
            candidates.append(candidate)
            window_candidates.append(candidate)

        best_window, _ = _select_best_candidate(window_candidates, args.target_faph)
        best_by_window.append(best_window)
        print(
            " window={window}: cutoff={cutoff:.2f}; recall={recall:.2%}; "
            "ambient_faph={faph:.3f}".format(
                window=window_size,
                cutoff=best_window["probability_cutoff"],
                recall=best_window["recall"],
                faph=best_window["false_accepts_per_hour"],
            )
        )

    best, selected_limit = _select_best_candidate(candidates, args.target_faph)
    if best["false_accepts_per_hour"] > args.target_faph + 1e-9:
        print(
            "⚠️ No candidate met the target false accepts/hour budget; "
            "using the best fallback operating point."
        )

    print(
        "✓ Selected cutoff={cutoff:.2f}, window={window}, recall={recall:.2%}, "
        "ambient_faph={faph:.3f}".format(
            cutoff=best["probability_cutoff"],
            window=best["sliding_window_size"],
            recall=best["recall"],
            faph=best["false_accepts_per_hour"],
        )
    )

    output = {
        "probability_cutoff": best["probability_cutoff"],
        "sliding_window_size": best["sliding_window_size"],
        "target_false_accepts_per_hour": float(args.target_faph),
        "selected_false_accepts_per_hour_limit": (
            None if math.isinf(selected_limit) else float(selected_limit)
        ),
        "selected_metrics": {
            "recall": round(best["recall"], 6),
            "false_accepts_per_hour": round(best["false_accepts_per_hour"], 6),
            "ambient_hours": round(best["ambient_hours"], 6),
        },
        "evaluation": {
            "positive_dataset": positive_mode,
            "ambient_dataset": ambient_mode,
            "positive_tracks": len(positive_tracks),
            "ambient_tracks": len(ambient_tracks),
            "cooldown_slices": int(args.cooldown_slices),
            "positive_skip_slices": int(args.positive_skip_slices),
            "window_sizes": window_sizes,
            "cutoff_min": round(float(cutoffs[0]), 4),
            "cutoff_max": round(float(cutoffs[-1]), 4),
            "cutoff_step": float(args.cutoff_step),
        },
        "per_window_best": best_by_window,
        "generated_at": datetime.now(timezone.utc).isoformat(),
    }

    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.write_text(json.dumps(output, indent=2) + "\n", encoding="utf-8")
    print(f"📝 Wrote calibration to {output_path}")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
|
||||||
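The tiered selection in `_select_best_candidate` is easier to follow on a small example. The sketch below restates the same ordering (tier first, then higher recall, lower false accepts per hour, window size closest to 5, higher cutoff) with the standard library only. The function names `fallback_limits` and `select`, and the candidate values, are made up for illustration and do not come from a real calibration run.

```python
def fallback_limits(target_faph: float) -> list[float]:
    """False-accept budgets tried in order: the target, then two looser tiers."""
    return [
        target_faph,
        max(target_faph * 2.0, target_faph + 0.5),
        max(target_faph * 4.0, 2.0),
    ]


def select(candidates: list[dict], target_faph: float) -> tuple[dict, float]:
    """Pick the best operating point using the same sort key as the diff above."""
    limits = fallback_limits(target_faph)

    def tier(candidate: dict) -> int:
        # Index of the first budget the candidate fits; len(limits) if none.
        for index, limit in enumerate(limits):
            if candidate["false_accepts_per_hour"] <= limit + 1e-9:
                return index
        return len(limits)

    best = min(
        candidates,
        key=lambda c: (
            tier(c),
            -c["recall"],
            c["false_accepts_per_hour"],
            abs(c["sliding_window_size"] - 5),
            -c["probability_cutoff"],
        ),
    )
    index = tier(best)
    return best, (limits[index] if index < len(limits) else float("inf"))


# Illustrative candidates only.
CANDIDATES = [
    {"probability_cutoff": 0.90, "sliding_window_size": 5,
     "recall": 0.80, "false_accepts_per_hour": 0.5},
    {"probability_cutoff": 0.70, "sliding_window_size": 5,
     "recall": 0.95, "false_accepts_per_hour": 1.8},
    {"probability_cutoff": 0.80, "sliding_window_size": 3,
     "recall": 0.80, "false_accepts_per_hour": 0.5},
]

if __name__ == "__main__":
    best, limit = select(CANDIDATES, target_faph=1.0)
    print(best["probability_cutoff"], best["sliding_window_size"], limit)
```

With a target of 1.0 the budgets are 1.0, 2.0, and 4.0. The 0.95-recall candidate falls in the second tier, so it loses to the two first-tier candidates even though its recall is higher; those two tie on recall and false accepts, and the window size of 5 breaks the tie.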
@@ -130,6 +130,73 @@ print(f" AudioSet complete ({ok} ok, {skipped} skipped, {len(audioset_bad)} fa
|
|||||||
EOF
|
EOF
|
||||||
}
|
}
|
||||||
|
|
||||||
|
converter_from_dataset_api() {
    # shellcheck source=/dev/null
    source "${DATA_DIR}/.venv/bin/activate"

    python - "${AUDIO16K_DIR}" <<-'EOF'
import sys
from pathlib import Path

import librosa
import numpy as np
import scipy.io.wavfile
from datasets import load_dataset

def write_wav(dst: Path, data: np.ndarray, sr: int):
    dst.parent.mkdir(parents=True, exist_ok=True)
    x = np.clip(data, -1.0, 1.0)
    scipy.io.wavfile.write(dst, sr, (x * 32767).astype(np.int16))

audioset_out = Path(sys.argv[1])

print(" AudioSet FLAC tarballs are unavailable; using Hugging Face datasets API instead.")
dataset = load_dataset(
    "agkphysics/AudioSet",
    "balanced",
    split="train",
    streaming=True,
)

audioset_bad = []
ok = 0
skipped = 0
heartbeat_every = 250

for idx, sample in enumerate(dataset, start=1):
    try:
        video_id = str(sample.get("video_id") or f"audioset_{idx:06d}")
        outfile = audioset_out / f"{video_id}.wav"
        if outfile.exists():
            skipped += 1
            continue

        audio = sample.get("audio") or {}
        y = np.asarray(audio.get("array"))
        sr = int(audio.get("sampling_rate") or 0)
        if y.size == 0 or sr <= 0:
            raise ValueError("missing decoded audio")
        if y.ndim > 1:
            y = np.mean(y, axis=-1)
        if sr != 16000:
            y = librosa.resample(y.astype(np.float32), orig_sr=sr, target_sr=16000)
        if y.size == 0:
            raise ValueError("empty audio")
        write_wav(outfile, y, 16000)
        ok += 1
    except Exception as exc:
        audioset_bad.append(f"{sample.get('video_id', idx)}:{exc}")

    if idx == 1 or (idx % heartbeat_every) == 0:
        print(f" AudioSet API progress: {idx} clips processed (ok={ok}, skipped={skipped}, failed={len(audioset_bad)})")

if audioset_bad:
    (audioset_out / "audioset_corrupted_files.log").write_text("\n".join(audioset_bad))

print(f" AudioSet complete via datasets API ({ok} ok, {skipped} skipped, {len(audioset_bad)} failed)")
EOF
}
|
||||||
|
|
||||||
expected_filecount=$(get_total_filecount filecounts)
|
expected_filecount=$(get_total_filecount filecounts)
|
||||||
actual_filecount=$(find "${AUDIO16K_DIR}" -name "*.wav" 2>/dev/null | wc -l) || :
|
actual_filecount=$(find "${AUDIO16K_DIR}" -name "*.wav" 2>/dev/null | wc -l) || :
|
||||||
write_filecount=false
|
write_filecount=false
|
||||||
@@ -139,40 +206,44 @@ if [ "${actual_filecount}" -ne 0 ] ; then
     echo " Existing ${AUDIO16K_DIR} present (${actual_filecount} wav); skipping extract/convert"
 else
     dl=$(find_rev)
-    [ -n "$dl" ] || { echo " Could not locate an AudioSet revision with FLAC tarballs still present on HF." ; exit 1 ; }
-    rev=${dl%%,*}
-    pattern=${dl##*,}
+    if [ -z "$dl" ] ; then
+        rm -rf "${AUDIO16K_DIR}/audioset_corrupted_files.log" || :
+        converter_from_dataset_api
+    else
+        rev=${dl%%,*}
+        pattern=${dl##*,}
 
-    echo " Checking 10 tarballs"
-    for i in {0..9} ; do
-        fname="downloads/bal_train0${i}.tar"
-        if [ ! -f "${fname}" ] ; then
-            echo " Downloading bal_train0${i}.tar"
-            url="${AUDIO_URL}/${rev}/${pattern}${i}.tar"
-            curl -L -s --fail "${url}" -o "${fname}" || { echo "Could not fetch ${fname} at rev ${rev}; continuing." ; continue ; }
+        echo " Checking 10 tarballs"
+        for i in {0..9} ; do
+            fname="downloads/bal_train0${i}.tar"
+            if [ ! -f "${fname}" ] ; then
+                echo " Downloading bal_train0${i}.tar"
+                url="${AUDIO_URL}/${rev}/${pattern}${i}.tar"
+                curl -L -s --fail "${url}" -o "${fname}" || { echo "Could not fetch ${fname} at rev ${rev}; continuing." ; continue ; }
+            fi
+
+            tarball_filecount=$(tar -tvf "${fname}" | wc -l )
+            filecounts["bal_train0${i}.tar"]=${tarball_filecount}
+            write_filecount=true
+
+            echo " Untarring bal_train0${i}.tar"
+            tar -xf "${fname}" -C "${AUDIO_DIR}"
+            if "${CLEANUP_ARCHIVES}" && [ -f "${fname}" ] ; then
+                echo " Cleaning up bal_train0${i}.tar"
+                rm -rf "${fname}"
+            fi
+        done
+
+        rm -rf "${AUDIO16K_DIR}/audioset_corrupted_files.log" || :
+        converter
+
+        # Recompute counts and warn (but do not fail)
+        expected_filecount=$(get_total_filecount filecounts)
+        actual_filecount=$(find "${AUDIO16K_DIR}" -name "*.wav" 2>/dev/null | wc -l) || :
+        if [ "${actual_filecount}" -ne "${expected_filecount}" ] ; then
+            echo " Converted file count(${actual_filecount}) != expected file count(${expected_filecount})" >&2
+            echo " WARNING: mismatch is expected if some AudioSet files are corrupted; continuing." >&2
         fi
-
-        tarball_filecount=$(tar -tvf "${fname}" | wc -l )
-        filecounts["bal_train0${i}.tar"]=${tarball_filecount}
-        write_filecount=true
-
-        echo " Untarring bal_train0${i}.tar"
-        tar -xf "${fname}" -C "${AUDIO_DIR}"
-        if "${CLEANUP_ARCHIVES}" && [ -f "${fname}" ] ; then
-            echo " Cleaning up bal_train0${i}.tar"
-            rm -rf "${fname}"
-        fi
-    done
-
-    rm -rf "${AUDIO16K_DIR}/audioset_corrupted_files.log" || :
-    converter
-
-    # Recompute counts and warn (but do not fail)
-    expected_filecount=$(get_total_filecount filecounts)
-    actual_filecount=$(find "${AUDIO16K_DIR}" -name "*.wav" 2>/dev/null | wc -l) || :
-    if [ "${actual_filecount}" -ne "${expected_filecount}" ] ; then
-        echo " Converted file count(${actual_filecount}) != expected file count(${expected_filecount})" >&2
-        echo " WARNING: mismatch is expected if some AudioSet files are corrupted; continuing." >&2
     fi
 fi
 
@@ -27,8 +27,11 @@ cd "${DATA_DIR}/training_datasets"
 
 echo "***** Checking FMA *****"
 
-AUDIO_URL="https://huggingface.co/datasets/mchl914/fma_xsmall/resolve/main/fma_xs.zip"
-AUDIO_ZIPFILE="fma_xs.zip"
+AUDIO_URLS=(
+    "https://os.unil.cloud.switch.ch/fma/fma_small.zip"
+    "https://huggingface.co/datasets/mchl914/fma_xsmall/resolve/main/fma_xs.zip"
+)
+AUDIO_ZIPFILE="fma_small.zip"
 AUDIO_ZIP="./downloads/${AUDIO_ZIPFILE}"
 AUDIO_DIR="fma"
 mkdir -p "${AUDIO_DIR}" || :
@@ -81,6 +84,52 @@ EOF
 
 }
 
+extract_zip_with_python() {
+    local zip_path="$1"
+    local dest_dir="$2"
+
+    "${DATA_DIR}/.venv/bin/python" - "${zip_path}" "${dest_dir}" <<-'EOF'
+import sys
+import zipfile
+from pathlib import Path
+from tqdm import tqdm
+
+zip_path = Path(sys.argv[1])
+dest_dir = Path(sys.argv[2])
+
+if (not zip_path.exists()) or zip_path.stat().st_size == 0:
+    raise SystemExit(f"Archive missing or empty: {zip_path}")
+
+with zipfile.ZipFile(zip_path, "r") as zf:
+    members = zf.infolist()
+    size_gb = zip_path.stat().st_size / (1024 ** 3)
+    print(f" Extracting {zip_path.name} ({len(members)} entries, {size_gb:.1f} GiB)...")
+    for member in tqdm(members, desc=" FMA zip extract", unit="file"):
+        zf.extract(member, dest_dir)
+EOF
+}
+
+download_with_fallbacks() {
+    local output="$1"
+    shift
+    local urls=( "$@" )
+    local rc=1
+
+    for url in "${urls[@]}" ; do
+        for attempt in 1 2 3 4 ; do
+            curl -sfL "${url}" -o "${output}" && [ -s "${output}" ] && return 0
+            rc=$?
+            rm -f "${output}" || :
+            if [ "${attempt}" -lt 4 ] ; then
+                echo " Retry ${attempt}/3 after download failure"
+                sleep $(( attempt * 2 ))
+            fi
+        done
+    done
+
+    return "${rc}"
+}
+
 expected_filecount=${filecounts[${AUDIO_ZIPFILE}]}
 actual_filecount=$(find ${AUDIO16K_DIR} -name '*.wav' 2>/dev/null | wc -l) || :
 write_filecount=false
||||||
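For readers more at home in Python, the retry policy of `download_with_fallbacks` in the hunk above can be sketched as follows. This is an illustrative analogue, not code from the commit: `fetch_with_fallbacks`, its `fetch` callback, and the injectable `sleep` are hypothetical names, introduced so the policy can be exercised without network access.

```python
import time
from typing import Callable, Iterable, Optional


def fetch_with_fallbacks(
    urls: Iterable[str],
    fetch: Callable[[str], bytes],
    attempts: int = 4,
    sleep: Callable[[float], None] = time.sleep,
) -> bytes:
    """Try each URL in order, up to `attempts` times each, with linear backoff."""
    last_error: Optional[Exception] = None
    for url in urls:
        for attempt in range(1, attempts + 1):
            try:
                data = fetch(url)
                if data:  # mirrors the shell's `[ -s file ]` non-empty check
                    return data
                last_error = ValueError(f"empty response from {url}")
            except Exception as exc:  # any failure counts as a failed attempt
                last_error = exc
            if attempt < attempts:
                sleep(attempt * 2)  # 2 s, 4 s, 6 s, as in the shell loop
    raise RuntimeError("all download sources failed") from last_error
```

As in the shell version, an empty payload is treated as a failure, and the next URL is tried only after every attempt on the current one is exhausted.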
@@ -92,13 +141,16 @@ else
     if [ "${actual_filecount}" -eq 0 ] || [ "${actual_filecount}" -ne "${expected_filecount}" ] ; then
         if [ ! -f "${AUDIO_ZIP}" ] ; then
             echo " Downloading ${AUDIO_ZIPFILE}"
-            curl -sfL "${AUDIO_URL}" -o "${AUDIO_ZIP}"
+            download_with_fallbacks "${AUDIO_ZIP}" "${AUDIO_URLS[@]}" || {
+                echo " Failed to download ${AUDIO_ZIPFILE} from all configured sources." >&2
+                exit 1
+            }
         fi
 
         rm -rf "${AUDIO_DIR}" || :
         mkdir "${AUDIO_DIR}"
-        echo " Unzipping ${AUDIO_ZIPFILE}"
-        unzip -q -d "${AUDIO_DIR}" "${AUDIO_ZIP}"
+        echo " Extracting ${AUDIO_ZIPFILE}"
+        extract_zip_with_python "${AUDIO_ZIP}" "${AUDIO_DIR}"
     fi
     if "${CLEANUP_ARCHIVES}" && [ -f "${AUDIO_ZIP}" ] ; then
         echo " Cleaning up ${AUDIO_ZIPFILE}"
@@ -128,4 +180,3 @@ fi
 
 echo " FMA complete"
 exit 0
-
@@ -242,29 +242,7 @@ if [ ! -s "${MODEL_FILE}.json" ] ; then
     curl -sfL "${MODEL_URL}.json" -o "${MODEL_FILE}.json"
 fi
 
-# --- Dutch ONNX voices (single-speaker, used with --language=nl) ---
-# Working Dutch voices: pim, ronnie (nl_NL) and nathalie (nl_BE).
-# nl_NL-mls-medium is intentionally excluded (known Piper issue: outputs gibberish).
-HF_VOICES="https://huggingface.co/rhasspy/piper-voices/resolve/main"
-declare -a NL_VOICES=(
-    "nl/nl_NL/pim/medium/nl_NL-pim-medium"
-    "nl/nl_NL/ronnie/medium/nl_NL-ronnie-medium"
-    "nl/nl_BE/nathalie/medium/nl_BE-nathalie-medium"
-)
-echo " ===== Checking Dutch Piper voices ====="
-for voice_path in "${NL_VOICES[@]}" ; do
-    voice_name="$(basename "${voice_path}")"
-    onnx_file="${VOICES_DIR}/${voice_name}.onnx"
-    json_file="${VOICES_DIR}/${voice_name}.onnx.json"
-    if [ ! -f "${onnx_file}" ] ; then
-        echo " Downloading ${voice_name}.onnx"
-        curl -sfL "${HF_VOICES}/${voice_path}.onnx?download=true" -o "${onnx_file}"
-    fi
-    if [ ! -f "${json_file}" ] ; then
-        echo " Downloading ${voice_name}.onnx.json"
-        curl -sfL "${HF_VOICES}/${voice_path}.onnx.json?download=true" -o "${json_file}"
-    fi
-done
+echo " Non-English Piper voices will be downloaded on demand for the selected language."
 
 ${GPU} && onnxgpu='-gpu[cuda]' || onnxgpu=""
 echo " ===== Installing onnxruntime${onnxgpu} ====="
@@ -317,6 +317,7 @@ fi
 
 TRAINING_DONE="false"
 
+echo "🏋️ Starting model training and TFLite export (this is the longest stage)…"
 if run_attempt "Attempt 1/3: GPU training (default runtime profile)" ; then
     echo "✅ Training complete (GPU path)."
     TRAINING_DONE="true"
@@ -386,12 +387,24 @@ if [ "${TRAINING_DONE}" != "true" ]; then
 fi
 
 source_path="${WORK_DIR}/trained_models/wakeword/tflite_stream_state_internal_quant/stream_state_internal_quant.tflite"
+calibration_path="${WORK_DIR}/trained_models/wakeword/tflite_stream_state_internal_quant/detection_calibration.json"
 
 if [ ! -f "${source_path}" ] ; then
     echo "Output model not found! Training didn't complete successfully. See ${TRAIN_LOG}"
     exit 1
 fi
 
+echo "🎯 Calibrating detector settings for on-device use…"
+if "${PYTHON_BIN:-python}" "${PROGDIR}/calibrate_detector.py" \
+    --training-config "${WORK_DIR}/trained_models/wakeword/training_config.yaml" \
+    --model "${source_path}" \
+    --output "${calibration_path}"; then
+    echo "✅ Detector calibration complete."
+else
+    echo "⚠️ Detector calibration failed; packaging with default detector settings."
+    rm -f "${calibration_path}" || :
+fi
+
 cp "${WORK_DIR}/trained_models/wakeword/model_summary.txt" "${OUTPUT_DIR}/logs/" || :
 cp -a "${WORK_DIR}/trained_models/wakeword/logs/train" "${OUTPUT_DIR}/logs/" || :
 cp -a "${WORK_DIR}/trained_models/wakeword/logs/validation" "${OUTPUT_DIR}/logs/" || :
@@ -404,24 +417,49 @@ tflite_path="${OUTPUT_DIR}/${tflite_filename}"
 cp "${source_path}" "${tflite_path}"
 
 json_path="${OUTPUT_DIR}/${wake_word_filename}.json"
-cat <<-EOF > "${json_path}"
-{
+export WAKE_WORD_TITLE LANGUAGE JSON_PATH="${json_path}" TFLITE_FILENAME="${tflite_filename}" CALIBRATION_PATH="${calibration_path}"
+echo "📦 Packaging final model artifacts…"
+"${PYTHON_BIN:-python}" - <<'PY'
+import json
+import os
+from pathlib import Path
+
+json_path = Path(os.environ["JSON_PATH"])
+calibration_path = Path(os.environ.get("CALIBRATION_PATH", ""))
+language = (os.environ.get("LANGUAGE", "en") or "en").strip().lower()
+probability_cutoff = 0.97
+sliding_window_size = 5
+
+if calibration_path.exists():
+    try:
+        calibration = json.loads(calibration_path.read_text(encoding="utf-8"))
+        probability_cutoff = float(calibration.get("probability_cutoff", probability_cutoff))
+        sliding_window_size = int(calibration.get("sliding_window_size", sliding_window_size))
+        print(
+            f"🎯 Using calibrated detector settings: "
+            f"cutoff={probability_cutoff:.2f}, window={sliding_window_size}"
+        )
+    except Exception as exc:
+        print(f"⚠️ Failed to read detector calibration ({exc}); using defaults.")
+
+meta = {
     "type": "micro",
-    "wake_word": "${WAKE_WORD_TITLE}",
+    "wake_word": os.environ["WAKE_WORD_TITLE"],
     "author": "Tater Totterson",
     "website": "https://github.com/TaterTotterson/microWakeWord-Trainer-Nvidia-Docker.git",
-    "model": "${tflite_filename}",
-    "trained_languages": ["en"],
+    "model": os.environ["TFLITE_FILENAME"],
+    "trained_languages": [language],
     "version": 2,
     "micro": {
-        "probability_cutoff": 0.97,
-        "sliding_window_size": 5,
+        "probability_cutoff": round(probability_cutoff, 2),
+        "sliding_window_size": sliding_window_size,
         "feature_step_size": 10,
         "tensor_arena_size": 30000,
-        "minimum_esphome_version": "2024.7.0"
-    }
+        "minimum_esphome_version": "2024.7.0",
+    },
 }
-EOF
+json_path.write_text(json.dumps(meta, indent=4) + "\n", encoding="utf-8")
+PY
 
 echo "Name: ${WAKE_WORD_TITLE}"
 echo "Model: ${tflite_path}"
|||||||
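The packaging step in the hunk above starts from the defaults (`0.97`, `5`) and lets a calibration file override them. A minimal standalone sketch of that idea follows. `resolve_detector_settings` is a hypothetical helper name, not a function from the repository, and unlike the inline script it resets both values to the defaults if any part of the file is unusable.

```python
import json
from pathlib import Path
from typing import Any, Dict, Optional

DEFAULT_CUTOFF = 0.97  # defaults taken from the packaging script in the diff
DEFAULT_WINDOW = 5


def resolve_detector_settings(calibration_file: Optional[Path]) -> Dict[str, Any]:
    """Return detector settings, preferring calibrated values when readable."""
    cutoff, window = DEFAULT_CUTOFF, DEFAULT_WINDOW
    if calibration_file is not None and calibration_file.is_file():
        try:
            calibration = json.loads(calibration_file.read_text(encoding="utf-8"))
            cutoff = float(calibration.get("probability_cutoff", cutoff))
            window = int(calibration.get("sliding_window_size", window))
        except (OSError, ValueError, TypeError, AttributeError):
            # Unreadable or malformed calibration: fall back to both defaults.
            cutoff, window = DEFAULT_CUTOFF, DEFAULT_WINDOW
    return {"probability_cutoff": round(cutoff, 2), "sliding_window_size": window}
```

The returned dictionary has the same two keys that the script writes into the `micro` section of the model manifest.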
@@ -22,7 +22,7 @@ COPY --chown=root:root --chmod=0755 .bashrc /root/
 # Root-level entrypoints
 COPY --chown=root:root --chmod=0755 \
     train_wake_word \
-    run_recorder.sh \
+    run.sh \
     trainer_server.py \
     requirements.txt \
     /root/mww-scripts/
@@ -37,4 +37,4 @@ RUN chmod -R a+x /root/mww-scripts/cli
 COPY --chown=root:root --chmod=0644 static/index.html /root/mww-scripts/static/index.html
 
 # trainer server
-CMD ["/bin/bash", "-lc", "/root/mww-scripts/run_recorder.sh"]
+CMD ["/bin/bash", "-lc", "/root/mww-scripts/run.sh"]
@@ -8,7 +8,7 @@ DATA_DIR="${DATA_DIR:-/data}"
 HOST="${REC_HOST:-0.0.0.0}"
 PORT="${REC_PORT:-8888}"
 
-# Keep recorder deps separate from training venv
+# Keep trainer UI deps separate from the training venv
 VENV_DIR="${DATA_DIR}/.recorder-venv"
 PY="${VENV_DIR}/bin/python"
 PIP="${PY} -m pip"
@@ -42,6 +42,12 @@ PIPER_VOICES_ROOT_URL = os.environ.get(
     "https://huggingface.co/rhasspy/piper-voices/resolve/main",
 )
 PIPER_CATALOG_CACHE_TTL_SECONDS = int(os.environ.get("PIPER_CATALOG_CACHE_TTL_SECONDS", "900"))
+PIPER_CATALOG_CACHE_FILE = Path(
+    os.environ.get(
+        "PIPER_CATALOG_CACHE_FILE",
+        str(ROOT_DIR / ".cache" / "piper_voices_catalog.json"),
+    )
+).resolve()
 
 DATASET_CLEANUP_ARCHIVES = os.environ.get("REC_DATASET_CLEANUP_ARCHIVES", "false").lower() in ("1", "true", "yes", "y")
 DATASET_CLEANUP_INTERMEDIATE = os.environ.get("REC_DATASET_CLEANUP_INTERMEDIATE_FILES", "false").lower() in ("1", "true", "yes", "y")
@@ -177,6 +183,27 @@ def _fetch_piper_catalog() -> Optional[Dict[str, Any]]:
     return data if isinstance(data, dict) else None
 
 
+def _read_cached_piper_catalog_file() -> Optional[Dict[str, Any]]:
+    try:
+        if not PIPER_CATALOG_CACHE_FILE.exists():
+            return None
+        data = json.loads(PIPER_CATALOG_CACHE_FILE.read_text(encoding="utf-8"))
+        return data if isinstance(data, dict) else None
+    except Exception:
+        return None
+
+
+def _write_cached_piper_catalog_file(data: Dict[str, Any]):
+    try:
+        PIPER_CATALOG_CACHE_FILE.parent.mkdir(parents=True, exist_ok=True)
+        PIPER_CATALOG_CACHE_FILE.write_text(
+            json.dumps(data, ensure_ascii=True),
+            encoding="utf-8",
+        )
+    except Exception:
+        pass
+
+
 def _load_piper_catalog() -> Optional[Dict[str, Any]]:
     now = time.time()
     with PIPER_CATALOG_LOCK:
@@ -185,6 +212,8 @@ def _load_piper_catalog() -> Optional[Dict[str, Any]]:
         if cached is not None and (now - fetched_at) < PIPER_CATALOG_CACHE_TTL_SECONDS:
             return cached
 
+    disk_cached = _read_cached_piper_catalog_file()
+
     try:
         fresh = _fetch_piper_catalog()
     except Exception:
@@ -194,9 +223,15 @@ def _load_piper_catalog() -> Optional[Dict[str, Any]]:
         if fresh is not None:
             PIPER_CATALOG_CACHE["entries"] = fresh
             PIPER_CATALOG_CACHE["fetched_at"] = now
+            _write_cached_piper_catalog_file(fresh)
             return fresh
-        if PIPER_CATALOG_CACHE.get("entries") is None:
-            PIPER_CATALOG_CACHE["entries"] = {}
+        if PIPER_CATALOG_CACHE.get("entries") is not None:
+            return PIPER_CATALOG_CACHE.get("entries")
+        if disk_cached is not None:
+            PIPER_CATALOG_CACHE["entries"] = disk_cached
+            PIPER_CATALOG_CACHE["fetched_at"] = now
+            return disk_cached
+        PIPER_CATALOG_CACHE["entries"] = {}
         PIPER_CATALOG_CACHE["fetched_at"] = now
         return PIPER_CATALOG_CACHE.get("entries")
 
|||||||
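Taken together, the `_load_piper_catalog` hunks above give the following lookup order: fresh in-memory copy, network fetch, stale in-memory copy, on-disk copy, then an empty catalog. The sketch below restates that order as a standalone function under simplifying assumptions: `load_catalog` and its injected callables are hypothetical names, and the lock that the real module holds around the cache is left out.

```python
import time
from typing import Any, Callable, Dict, Optional

Catalog = Dict[str, Any]


def load_catalog(
    cache: Dict[str, Any],
    fetch: Callable[[], Optional[Catalog]],
    read_disk: Callable[[], Optional[Catalog]],
    write_disk: Callable[[Catalog], None],
    ttl_seconds: float = 900.0,
    now: Optional[float] = None,
) -> Catalog:
    """Resolve the catalog using the fallback order described in the diff."""
    now = time.time() if now is None else now
    cached = cache.get("entries")
    if cached is not None and (now - cache.get("fetched_at", 0.0)) < ttl_seconds:
        return cached  # 1. in-memory copy still within its TTL

    disk_cached = read_disk()
    try:
        fresh = fetch()
    except Exception:
        fresh = None

    if fresh is not None:  # 2. network fetch succeeded: refresh memory and disk
        cache["entries"], cache["fetched_at"] = fresh, now
        write_disk(fresh)
        return fresh
    if cached is not None:  # 3. stale in-memory copy beats nothing
        return cached
    if disk_cached is not None:  # 4. copy persisted by an earlier run
        cache["entries"], cache["fetched_at"] = disk_cached, now
        return disk_cached
    cache["entries"], cache["fetched_at"] = {}, now  # 5. empty catalog
    return cache["entries"]
```

Passing `now` explicitly keeps the TTL logic deterministic in tests. As in the diff, a stale in-memory hit does not update `fetched_at`, so the next call attempts the network again.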