mirror of
https://github.com/TaterTotterson/microWakeWord-Trainer-Nvidia-Docker.git
synced 2026-06-12 20:10:19 -06:00
Compare commits
12 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| | ee6ff6e9d5 | | |
| | 251b0280b6 | | |
| | dd2bdda431 | | |
| | 9d8e0afe1b | | |
| | 318a4ad3b5 | | |
| | 51cbf6fd90 | | |
| | 6e7396455a | | |
| | 2da9f7a686 | | |
| | 7b028e4420 | | |
| | b3d9f0e369 | | |
| | 240ca7682e | | |
| | 18e5fcd000 | | |
322
README.md
@@ -1,25 +1,15 @@
|
||||
<div align="center">
|
||||
<h1>🎙️ microWakeWord Nvidia Trainer & Recorder</h1>
|
||||
<img width="1002" height="593" alt="Screenshot 2026-01-18 at 8 13 35 AM" src="https://github.com/user-attachments/assets/e1411d8a-8638-4df8-992b-09a46c6e5ddc" />
|
||||
<h1>microWakeWord NVIDIA Docker Trainer UI</h1>
|
||||
<img width="800" alt="microWakeWord NVIDIA trainer screenshot" src="https://github.com/user-attachments/assets/694f4cb7-e4d8-4e2b-80ec-b40fb41cbfff" />
|
||||
</div>
|
||||
|
||||
Train **microWakeWord** detection models using a simple **web-based recorder + trainer UI**, packaged in a Docker container.
|
||||
Train custom microWakeWord models in Docker with NVIDIA/CUDA acceleration, generated Piper samples, device-captured samples, reviewed false-wake negatives, live training logs, and ESPHome firmware flashing.
|
||||
|
||||
No Jupyter notebooks required. No manual cell execution. Just record your voice (optional) and train.
|
||||
Real samples come from device-captured wake audio, close misses, or manual uploads. Every saved sample is normalized to `16 kHz / mono / 16-bit PCM WAV` before training.
|
||||
|
||||
---
|
||||
|
||||
<img width="100" height="44" alt="unraid_logo_black-339076895" src="https://github.com/user-attachments/assets/87351bed-3321-4a43-924f-fecf2e4e700f" />
|
||||
|
||||
**microWakeWord_Trainer-Nvidia** is available in the **Unraid Community Apps** store.
|
||||
Install directly from the Unraid App Store with a one-click template.
|
||||
|
||||
---
|
||||
|
||||
<img width="100" height="56" alt="unraid_logo_black-339076895" src="https://github.com/user-attachments/assets/bf959585-ae13-4b4d-ae62-4202a850d35a" />
|
||||
|
||||
|
||||
### Pull the Docker Image
|
||||
## Docker Image
|
||||
|
||||
```bash
|
||||
docker pull ghcr.io/tatertotterson/microwakeword:latest
|
||||
@@ -27,143 +17,227 @@ docker pull ghcr.io/tatertotterson/microwakeword:latest
|
||||
|
||||
---
|
||||
|
||||
### Run the Container
|
||||
## Run The Container
|
||||
|
||||
```bash
|
||||
docker run -d \
|
||||
--gpus all \
|
||||
-p 8888:8888 \
|
||||
--network host \
|
||||
-e REC_PORT=8789 \
|
||||
-v $(pwd):/data \
|
||||
ghcr.io/tatertotterson/microwakeword:latest
|
||||
```
|
||||
|
||||
**What these flags do:**
|
||||
- `--gpus all` → Enables GPU acceleration
|
||||
- `-p 8888:8888` → Exposes the Recorder + Trainer WebUI
|
||||
- `-v $(pwd):/data` → Persists all models, datasets, and cache
|
||||
The flags:
|
||||
|
||||
---
|
||||
- `--gpus all` enables GPU acceleration.
|
||||
- `--network host` lets the container receive mDNS/zeroconf traffic for ESPHome auto-detect.
|
||||
- `-e REC_PORT=8789` sets the trainer web UI and captured-audio port. Change this value if `8789` is already in use.
|
||||
- `-v $(pwd):/data` persists models, downloaded voices, datasets, samples, and firmware caches.
|
||||
|
||||
### Open the Recorder WebUI
|
||||
Host networking is recommended for the Firmware tab's mDNS device discovery. Manual IP flashing and captured-audio uploads can still work without host networking if the trainer port is reachable, but auto-detect may not see devices from Docker bridge networking.
|
||||
|
||||
Open your browser and go to:
|
||||
|
||||
👉 **http://localhost:8888**
|
||||
|
||||
You’ll see the **microWakeWord Recorder & Trainer UI**.
|
||||
|
||||
---
|
||||
|
||||
## 🎤 Recording Voice Samples (Optional)
|
||||
|
||||
Personal voice recordings are **optional**.
|
||||
|
||||
- You may **record your own voice** for better accuracy
|
||||
- Or simply **click “Train” without recording anything**
|
||||
|
||||
If no recordings are present, training will proceed using **synthetic TTS samples only**.
|
||||
|
||||
### Remote systems (important)
|
||||
If you are running this on a **remote PC / server**, browser-based recording will not work unless:
|
||||
- You use a **reverse proxy** (HTTPS + mic permissions), **or**
|
||||
- You access the UI via **localhost** on the same machine
|
||||
|
||||
Training itself works fine remotely — only recording requires local microphone access.
|
||||
|
||||
---
|
||||
|
||||
### 🎙️ Recording Flow
|
||||
|
||||
1. Enter your wake word
|
||||
2. Test pronunciation with **Test TTS**
|
||||
3. Choose:
|
||||
- Number of speakers (e.g. family members)
|
||||
- Takes per speaker (default: 10)
|
||||
4. Click **Begin recording**
|
||||
5. Speak naturally — recording:
|
||||
- Starts when you talk
|
||||
- Stops automatically after silence
|
||||
6. Repeat for each speaker
|
||||
|
||||
Files are saved automatically to:
|
||||
|
||||
```
|
||||
personal_samples/
|
||||
speaker01_take01.wav
|
||||
speaker01_take02.wav
|
||||
speaker02_take01.wav
|
||||
...
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 🧠 Training Behavior (Important Notes)
|
||||
|
||||
### ⏬ First training run
|
||||
The **first time you click Train**, the system will download **large training datasets** (background noise, speech corpora, etc.).
|
||||
|
||||
- This can take **several minutes**
|
||||
- This happens **only once**
|
||||
- Data is cached inside `/data`
|
||||
|
||||
You **will NOT need to download these again** unless you delete `/data`.
|
||||
|
||||
---
|
||||
|
||||
### 🔁 Re-training is safe and incremental
|
||||
|
||||
- You can train **multiple wake words** back-to-back
|
||||
- You do **NOT** need to clear any folders between runs
|
||||
- Old models are preserved in timestamped output directories
|
||||
- All required cleanup and reuse logic is handled automatically
|
||||
|
||||
---
|
||||
|
||||
## 📦 Output Files
|
||||
|
||||
When training completes, you’ll get:
|
||||
- `<wake_word>.tflite` – quantized streaming model
|
||||
- `<wake_word>.json` – ESPHome-compatible metadata
|
||||
|
||||
Both are saved under:
|
||||
Open:
|
||||
|
||||
```text
|
||||
/data/output/
|
||||
http://localhost:8789
|
||||
```
|
||||
|
||||
Each run is placed in its own timestamped folder.
|
||||
If you change `REC_PORT`, open that port instead and use the same port in the ESPHome `Trainer App URL`.
|
||||
|
||||
---
|
||||
|
||||
## 🎤 Optional: Personal Voice Samples (Advanced)
|
||||
## What The UI Does
|
||||
|
||||
If you record personal samples:
|
||||
- They are automatically augmented
|
||||
- They are **up-weighted during training**
|
||||
- This significantly improves real-world accuracy
|
||||
|
||||
No configuration required — detection is automatic.
|
||||
- `Trainer` starts a wake-word session, shows positive/negative sample counts, and launches training.
|
||||
- `Captured Audio` reviews clips sent by ESPHome sats, including wake hits, close misses, and false wakes.
|
||||
- `Samples` plays, removes, clears, and manually imports personal or negative samples.
|
||||
- `Firmware` builds the latest `microWakeWords` ESPHome YAMLs from GitHub and flashes VoicePE or Satellite1 over OTA.
|
||||
- Popup consoles show colorized training and firmware logs while long-running jobs are active.
|
||||
|
||||
---
|
||||
|
||||
## 🔄 Resetting Everything (Optional)
|
||||
## Captured Audio Workflow
|
||||
|
||||
If you want a **completely clean slate**:
|
||||
To collect samples from a sat, flash it with the Tater firmware from [TaterTotterson/microWakeWords](https://github.com/TaterTotterson/microWakeWords). The `Firmware` tab can build and flash the VoicePE or Satellite1 YAMLs directly from that repo.
|
||||
|
||||
Delete the /data folder
|
||||
After flashing, the device exposes ESPHome entities for capture setup:
|
||||
|
||||
Then restart the container.
|
||||
- `Capture Wake Audio` toggles upload of wake-word triggers.
|
||||
- `Capture Close Misses` toggles upload of near misses.
|
||||
- `Trainer App URL` sets the trainer address, for example `http://<trainer-ip>:8789`.
|
||||
|
||||
⚠️ This will:
|
||||
- Remove cached datasets
|
||||
- Require re-downloading training data
|
||||
- Delete trained models
|
||||
ESPHome devices can send raw captured audio to:
|
||||
|
||||
```text
|
||||
/api/upload_captured_audio_raw
|
||||
```
|
||||
|
||||
Keep the training app running and reachable at the `Trainer App URL` while capture is enabled. The sats upload clips live; if the app is stopped or the URL is wrong, captured audio will not be saved.
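As a rough illustration of how the port setting and the `Trainer App URL` relate, the sketch below derives both URLs from `REC_PORT`. The helper name `trainer_urls` and the fallback to `8789` when the variable is unset are assumptions made for this example; only the endpoint path and the example port come from this README.

```python
import os
from typing import Mapping, Optional

# Endpoint path quoted in the README.
UPLOAD_PATH = "/api/upload_captured_audio_raw"


def trainer_urls(host: str, env: Optional[Mapping[str, str]] = None) -> dict:
    """Return the Trainer App URL and the raw-upload URL for a given host."""
    env = os.environ if env is None else env
    # Assumption: default to 8789, the value used in the docker run example.
    port = int(env.get("REC_PORT", "8789"))
    if not 1 <= port <= 65535:
        raise ValueError(f"REC_PORT out of range: {port}")
    base = f"http://{host}:{port}"
    return {"trainer_app_url": base, "upload_url": base + UPLOAD_PATH}
```

The value under `trainer_app_url` is what would go into the device's `Trainer App URL` entity; the request body format for the upload endpoint is not described here, so the sketch stops at the URL.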
|
||||
|
||||
In the `Captured Audio` tab:
|
||||
|
||||
- play each clip from the inbox
|
||||
- mark good wake-word clips as `This is good`
|
||||
- mark bad triggers as `False wake`
|
||||
- discard clips that should not be used
|
||||
|
||||
Approved clips move into:
|
||||
|
||||
```text
|
||||
/data/personal_samples/
|
||||
```
|
||||
|
||||
False wakes move into:
|
||||
|
||||
```text
|
||||
/data/negative_samples/
|
||||
```
|
||||
|
||||
Captured audio is boosted for easier playback in the UI, then kept in the correct training format.
|
||||
|
||||
---
|
||||
|
||||
## 🙌 Credits
|
||||
## Samples
|
||||
|
||||
Built on top of the excellent
|
||||
**https://github.com/kahrendt/microWakeWord**
|
||||
The `Samples` tab is the sample library.
|
||||
|
||||
Huge thanks to the original authors ❤️
|
||||
- `Personal` samples are positive examples of the wake word.
|
||||
- `Negative` samples are reviewed false wakes or hard negatives.
|
||||
- Both can be played back and removed one at a time.
|
||||
- Manual upload is available here as an optional seed path.
|
||||
|
||||
Accepted manual upload formats include:
|
||||
|
||||
- WAV
|
||||
- MP3
|
||||
- M4A
|
||||
- FLAC
|
||||
- OGG
|
||||
- AAC
|
||||
- OPUS
|
||||
- WEBM
|
||||
|
||||
Uploads are validated or converted with `ffmpeg` into:
|
||||
|
||||
```text
|
||||
16 kHz / mono / 16-bit PCM WAV
|
||||
```
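The validate-or-convert step can be sketched as follows. This is not the trainer's actual code: the function names are hypothetical, and the exact `ffmpeg` arguments the trainer passes are not shown in this README. The flags used here (`-ac`, `-ar`, `-c:a pcm_s16le`) are standard `ffmpeg` options that produce the target format.

```python
from __future__ import annotations

import wave
from pathlib import Path


def is_target_format(path: Path) -> bool:
    """True if the file is already a 16 kHz, mono, 16-bit PCM WAV."""
    try:
        with wave.open(str(path), "rb") as handle:
            return (
                handle.getnchannels() == 1
                and handle.getsampwidth() == 2  # 2 bytes = 16-bit samples
                and handle.getframerate() == 16000
            )
    except (wave.Error, EOFError):
        # Not a PCM WAV that the stdlib reader understands.
        return False


def ffmpeg_normalize_command(src: Path, dst: Path) -> list[str]:
    """Build (but do not run) an ffmpeg command that converts src to the target format."""
    return [
        "ffmpeg", "-y",        # overwrite dst if it exists
        "-i", str(src),
        "-ac", "1",            # mono
        "-ar", "16000",        # 16 kHz sample rate
        "-c:a", "pcm_s16le",   # 16-bit little-endian PCM
        str(dst),
    ]
```

A caller would skip conversion when `is_target_format` returns true and otherwise hand the command list to `subprocess.run(..., check=True)`.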
|
||||
|
||||
Starting a new session does not clear samples. Use the clear buttons in `Samples` if you want to remove saved personal or negative clips.
|
||||
|
||||
---
|
||||
|
||||
## Training Flow
|
||||
|
||||
1. Enter the wake phrase in `Trainer`.
|
||||
2. Choose the language.
|
||||
3. Optionally test pronunciation with `Test TTS`.
|
||||
4. Review the positive and negative sample counts.
|
||||
5. Click `Start training`.
|
||||
6. Watch the popup training console.
|
||||
|
||||
Personal samples are optional. Training can run with zero personal samples after confirmation, using generated TTS samples and the stock negative datasets.
|
||||
|
||||
Reviewed negative samples are converted into `/data/work/reviewed_negative_features/` and inserted into the training YAML as a hard-negative feature set when present.
|
||||
|
||||
---
|
||||
|
||||
## Language Support
|
||||
|
||||
The language picker is dynamic.
|
||||
|
||||
- `en` is always available.
|
||||
- English keeps the existing dedicated generator model path.
|
||||
- Non-English languages are discovered from the Piper voices catalog and any local Piper voice metadata.
|
||||
- When a non-English language is selected, the trainer downloads all voices for that selected language only.
|
||||
- Already-downloaded voices are reused.
|
||||
- It does not download every language up front.
|
||||
|
||||
If the upstream Piper catalog is unavailable, already-installed local voices are used when available.
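The discovery rules above can be sketched like this. The function name and arguments are hypothetical, and the sketch assumes each catalog entry carries a `language.family` field, which is how the public Piper `voices.json` appeared to be structured at the time of writing; the trainer's real lookup may differ.

```python
from __future__ import annotations


def available_languages(catalog: dict | None, local_families: set[str]) -> list[str]:
    """Merge language codes from the catalog and from locally installed voices.

    catalog is the parsed voices catalog, or None when it could not be fetched.
    """
    languages = {"en"}  # `en` is always available
    languages |= {family.lower() for family in local_families}
    if catalog:
        for entry in catalog.values():
            # Assumed entry shape: {"language": {"family": "de", ...}, ...}
            family = (entry.get("language") or {}).get("family")
            if family:
                languages.add(str(family).lower())
    return sorted(languages)
```

When the catalog is unavailable, passing `None` leaves only `en` plus whatever is already installed locally, which matches the fallback described above.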
|
||||
|
||||
---
|
||||
|
||||
## Dataset Behavior
|
||||
|
||||
The first training run downloads and prepares missing training assets into `/data`, including:
|
||||
|
||||
- Piper voices for the selected language
|
||||
- negative datasets and background data
|
||||
- the Python training environment
|
||||
- generated samples and augmented feature caches
|
||||
|
||||
After those assets are prepared, later runs reuse the local copies unless the mounted `/data` contents are deleted.
|
||||
|
||||
---
|
||||
|
||||
## Firmware Flashing
|
||||
|
||||
The `Firmware` tab builds and flashes Tater firmware for supported ESPHome sats.
|
||||
|
||||
- Downloads the latest firmware YAML templates from `TaterTotterson/microWakeWords` on GitHub.
|
||||
- Lets you choose `VoicePE` or `Satellite1`.
|
||||
- Auto-detects ESPHome devices with mDNS when the container is running with host networking.
|
||||
- Allows manual IP or hostname entry if discovery does not find the device.
|
||||
- Saves firmware form values so you do not re-enter sounds and URLs every run.
|
||||
- Lists locally trained wake words from `/data/trained_wake_words/` for easy model selection.
|
||||
- Builds with ESPHome and flashes OTA.
|
||||
- Streams ESPHome output in a colorized firmware console.
|
||||
|
||||
Firmware YAMLs are intentionally pulled from GitHub each time. There is no local fallback path in the trainer UI.
|
||||
|
||||
---
|
||||
|
||||
## Output Files
|
||||
|
||||
Successful runs produce timestamped training output folders such as:
|
||||
|
||||
```text
|
||||
/data/output/<timestamp>-<wake_word>-<samples>-<steps>/<wake_word>.tflite
|
||||
/data/output/<timestamp>-<wake_word>-<samples>-<steps>/<wake_word>.json
|
||||
```
|
||||
|
||||
The trainer also syncs firmware-ready artifacts into:
|
||||
|
||||
```text
|
||||
/data/trained_wake_words/<wake_word>.tflite
|
||||
/data/trained_wake_words/<wake_word>.json
|
||||
```
|
||||
|
||||
The firmware tab uses `/data/trained_wake_words/` to populate the wake-word dropdown.
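A minimal sketch of how such a dropdown could be populated from that folder. The function name is hypothetical, and requiring both the `.tflite` and the `.json` file before listing a wake word is an assumption, not something this README states.

```python
from __future__ import annotations

from pathlib import Path


def list_trained_wake_words(root: Path) -> list[str]:
    """Return wake-word names that have both a model and a metadata file in root."""
    names = []
    for model in sorted(root.glob("*.tflite")):
        # Assumption: a model without its JSON metadata is not firmware-ready.
        if model.with_suffix(".json").exists():
            names.append(model.stem)
    return names
```

Inside the container this would be called as `list_trained_wake_words(Path("/data/trained_wake_words"))`.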
|
||||
|
||||
---
|
||||
|
||||
## Resetting Everything
|
||||
|
||||
If you want a clean slate, stop the container and remove the contents of the mounted `/data` directory.
|
||||
|
||||
That removes:
|
||||
|
||||
- personal samples
|
||||
- negative samples
|
||||
- captured inbox clips
|
||||
- downloaded Piper voices
|
||||
- cached datasets
|
||||
- training environments
|
||||
- trained models
|
||||
- firmware build caches
|
||||
|
||||
---
|
||||
|
||||
## Important Notes
|
||||
|
||||
- Personal samples are optional.
|
||||
- Negative samples are optional but useful for reducing false wakes.
|
||||
- The UI server is `trainer_server.py`.
|
||||
- The launcher is `run.sh`.
|
||||
- Firmware capture settings live on the ESPHome device and can be toggled from the device entities after flashing.
|
||||
|
||||
---
|
||||
|
||||
## Credits
|
||||
|
||||
Built on top of:
|
||||
|
||||
- [microWakeWord](https://github.com/kahrendt/microWakeWord)
|
||||
- [piper-sample-generator](https://github.com/rhasspy/piper-sample-generator)
|
||||
|
||||
405
cli/calibrate_detector.py
Normal file
@@ -0,0 +1,405 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Choose detector metadata that better matches the trained model."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Sequence
|
||||
|
||||
import numpy as np
|
||||
import yaml
|
||||
|
||||
from microwakeword.data import FeatureHandler
|
||||
from microwakeword.inference import Model
|
||||
|
||||
|
||||
DEFAULT_WINDOW_SIZES = [3, 4, 5, 6, 7]
|
||||
DEFAULT_TARGET_FAPH = float(os.environ.get("MWW_CALIBRATION_TARGET_FAPH", "1.0"))
|
||||
DEFAULT_COOLDOWN_SLICES = int(os.environ.get("MWW_CALIBRATION_COOLDOWN_SLICES", "25"))
|
||||
DEFAULT_POSITIVE_SKIP_SLICES = int(
|
||||
os.environ.get("MWW_CALIBRATION_POSITIVE_SKIP_SLICES", "25")
|
||||
)
|
||||
DEFAULT_CUTOFF_STEP = float(os.environ.get("MWW_CALIBRATION_CUTOFF_STEP", "0.01"))
|
||||
DEFAULT_CUTOFF_MIN = float(os.environ.get("MWW_CALIBRATION_CUTOFF_MIN", "0.00"))
|
||||
DEFAULT_CUTOFF_MAX = float(os.environ.get("MWW_CALIBRATION_CUTOFF_MAX", "1.00"))
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Calibrate microWakeWord detector metadata from validation data."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--training-config",
|
||||
default="trained_models/wakeword/training_config.yaml",
|
||||
help="Path to the saved microWakeWord training_config.yaml file.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
default=(
|
||||
"trained_models/wakeword/tflite_stream_state_internal_quant/"
|
||||
"stream_state_internal_quant.tflite"
|
||||
),
|
||||
help="Path to the quantized streaming TFLite model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
default=(
|
||||
"trained_models/wakeword/tflite_stream_state_internal_quant/"
|
||||
"detection_calibration.json"
|
||||
),
|
||||
help="Where to write the selected detector settings as JSON.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--window-sizes",
|
||||
default=",".join(str(value) for value in DEFAULT_WINDOW_SIZES),
|
||||
help="Comma-separated sliding window sizes to evaluate.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--target-faph",
|
||||
type=float,
|
||||
default=DEFAULT_TARGET_FAPH,
|
||||
help="Target ambient false accepts per hour for the selected operating point.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cooldown-slices",
|
||||
type=int,
|
||||
default=DEFAULT_COOLDOWN_SLICES,
|
||||
help="Cooldown slices to use when estimating false accepts per hour.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--positive-skip-slices",
|
||||
type=int,
|
||||
default=DEFAULT_POSITIVE_SKIP_SLICES,
|
||||
help="Initial streaming slices to ignore when scoring positive examples.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cutoff-step",
|
||||
type=float,
|
||||
default=DEFAULT_CUTOFF_STEP,
|
||||
help="Cutoff increment to evaluate between cutoff-min and cutoff-max.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cutoff-min",
|
||||
type=float,
|
||||
default=DEFAULT_CUTOFF_MIN,
|
||||
help="Minimum cutoff to evaluate.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cutoff-max",
|
||||
type=float,
|
||||
default=DEFAULT_CUTOFF_MAX,
|
||||
help="Maximum cutoff to evaluate.",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def _parse_window_sizes(raw: str) -> list[int]:
|
||||
values = []
|
||||
for item in (raw or "").split(","):
|
||||
item = item.strip()
|
||||
if not item:
|
||||
continue
|
||||
value = int(item)
|
||||
if value < 1:
|
||||
raise ValueError("window sizes must be >= 1")
|
||||
values.append(value)
|
||||
if not values:
|
||||
raise ValueError("at least one window size is required")
|
||||
return sorted(set(values))
|
||||
|
||||
|
||||
def _moving_average(values: Sequence[float], window_size: int) -> np.ndarray:
|
||||
array = np.asarray(values, dtype=np.float32)
|
||||
if array.size == 0:
|
||||
return array
|
||||
if window_size <= 1:
|
||||
return array
|
||||
if array.size < window_size:
|
||||
return np.asarray([float(array.mean())], dtype=np.float32)
|
||||
cumsum = np.cumsum(np.insert(array, 0, 0.0))
|
||||
averaged = (cumsum[window_size:] - cumsum[:-window_size]) / float(window_size)
|
||||
return averaged.astype(np.float32)
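For reference, the prefix-sum trick used by `_moving_average` can be reproduced without NumPy. The sketch below is a pure-Python stand-in written for illustration (the name `moving_average` is not part of the file); it follows the same branches: empty input and `window_size <= 1` pass through, and a track shorter than the window collapses to its mean.

```python
from itertools import accumulate


def moving_average(values, window_size):
    """Sliding mean over `values` using prefix sums (one output per full window)."""
    values = [float(v) for v in values]
    if not values or window_size <= 1:
        return values
    if len(values) < window_size:
        # Too short for a full window: fall back to a single overall mean.
        return [sum(values) / len(values)]
    # prefix[i] holds the sum of the first i values, so a window sum is a difference.
    prefix = [0.0] + list(accumulate(values))
    return [
        (prefix[i + window_size] - prefix[i]) / window_size
        for i in range(len(values) - window_size + 1)
    ]
```

An input of length `n` yields `n - window_size + 1` averages, the same length as `cumsum[window_size:] - cumsum[:-window_size]` in the NumPy version.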
|
||||
|
||||
|
||||
def _compute_false_accepts_per_hour(
|
||||
probabilities_per_track: Iterable[np.ndarray],
|
||||
cutoffs: np.ndarray,
|
||||
cooldown_slices: int,
|
||||
stride: int,
|
||||
step_seconds: float,
|
||||
) -> tuple[np.ndarray, float]:
|
||||
cutoffs = np.asarray(cutoffs, dtype=np.float32)
|
||||
false_accepts = np.zeros(cutoffs.shape[0], dtype=np.float64)
|
||||
duration_hours = 0.0
|
||||
|
||||
for track_probabilities in probabilities_per_track:
|
||||
if track_probabilities.size == 0:
|
||||
continue
|
||||
duration_hours += (
|
||||
len(track_probabilities) * stride * step_seconds / 3600.0
|
||||
)
|
||||
cooldown = np.full(cutoffs.shape[0], cooldown_slices, dtype=np.int32)
|
||||
for probability in track_probabilities:
|
||||
cooldown = np.maximum(cooldown - 1, 0)
|
||||
accepted = (cooldown == 0) & (probability > cutoffs)
|
||||
false_accepts += accepted.astype(np.float64)
|
||||
cooldown = np.where(accepted, cooldown_slices, cooldown)
|
||||
|
||||
if duration_hours <= 0:
|
||||
return np.full(cutoffs.shape[0], math.inf, dtype=np.float64), 0.0
|
||||
|
||||
return false_accepts / duration_hours, duration_hours
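A worked, single-cutoff version of the cooldown logic may help when reading the vectorized function above. This is a pure-Python sketch for illustration only; the name `false_accepts_per_hour` and the scalar `cutoff` argument are not part of the file.

```python
import math


def false_accepts_per_hour(tracks, cutoff, cooldown_slices, stride, step_seconds):
    """Count detections on ambient tracks for one cutoff, honoring a cooldown."""
    accepts = 0
    hours = 0.0
    for track in tracks:
        if len(track) == 0:
            continue
        # Each prediction covers `stride` feature steps of `step_seconds` each.
        hours += len(track) * stride * step_seconds / 3600.0
        cooldown = cooldown_slices  # every track starts inside a cooldown
        for probability in track:
            cooldown = max(cooldown - 1, 0)
            if cooldown == 0 and probability > cutoff:
                accepts += 1
                cooldown = cooldown_slices  # suppress re-triggers for a while
    if hours <= 0:
        return math.inf, 0.0
    return accepts / hours, hours
```

With a track of 20 predictions that all exceed the cutoff and `cooldown_slices=10`, only the 10th and 20th predictions count, so a sustained burst of high scores registers as two false accepts rather than twenty.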
|
||||
|
||||
|
||||
def _select_best_candidate(
|
||||
candidates: list[dict[str, float]],
|
||||
target_faph: float,
|
||||
) -> tuple[dict[str, float], float]:
|
||||
fallback_limits = [
|
||||
target_faph,
|
||||
max(target_faph * 2.0, target_faph + 0.5),
|
||||
max(target_faph * 4.0, 2.0),
|
||||
]
|
||||
|
||||
def tier(candidate: dict[str, float]) -> int:
|
||||
for index, limit in enumerate(fallback_limits):
|
||||
if candidate["false_accepts_per_hour"] <= limit + 1e-9:
|
||||
return index
|
||||
return len(fallback_limits)
|
||||
|
||||
best = min(
|
||||
candidates,
|
||||
key=lambda candidate: (
|
||||
tier(candidate),
|
||||
-candidate["recall"],
|
||||
candidate["false_accepts_per_hour"],
|
||||
abs(candidate["sliding_window_size"] - 5),
|
||||
-candidate["probability_cutoff"],
|
||||
),
|
||||
)
|
||||
|
||||
tier_index = tier(best)
|
||||
if tier_index < len(fallback_limits):
|
||||
return best, fallback_limits[tier_index]
|
||||
return best, float("inf")
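To make the tiered ranking concrete, here is a compact restatement with made-up candidates. The data below is invented purely for illustration and is not output from any real calibration run; `select_best` mirrors the function above under a different name.

```python
def select_best(candidates, target_faph):
    """Pick the candidate in the tightest FAPH tier, then by recall and tie-breakers."""
    limits = [
        target_faph,
        max(target_faph * 2.0, target_faph + 0.5),
        max(target_faph * 4.0, 2.0),
    ]

    def tier(candidate):
        for index, limit in enumerate(limits):
            if candidate["false_accepts_per_hour"] <= limit + 1e-9:
                return index
        return len(limits)  # worse than every fallback limit

    best = min(
        candidates,
        key=lambda c: (
            tier(c),                             # tighter tier wins first
            -c["recall"],                        # then higher recall
            c["false_accepts_per_hour"],         # then fewer false accepts
            abs(c["sliding_window_size"] - 5),   # then window size closest to 5
            -c["probability_cutoff"],            # then the stricter cutoff
        ),
    )
    index = tier(best)
    return best, (limits[index] if index < len(limits) else float("inf"))


# Invented example candidates.
CANDIDATES = [
    {"recall": 0.95, "false_accepts_per_hour": 1.5, "sliding_window_size": 5, "probability_cutoff": 0.50},
    {"recall": 0.90, "false_accepts_per_hour": 0.8, "sliding_window_size": 5, "probability_cutoff": 0.70},
    {"recall": 0.90, "false_accepts_per_hour": 0.5, "sliding_window_size": 3, "probability_cutoff": 0.80},
]
```

With a target of 1.0 false accepts per hour, the first candidate has the highest recall but falls into the second tier, so it loses to the two candidates that meet the target; between those, equal recall is broken by the lower false-accept rate.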
|
||||
|
||||
|
||||
def _load_config(config_path: Path) -> dict:
|
||||
with config_path.open("r", encoding="utf-8") as handle:
|
||||
return yaml.load(handle.read(), Loader=yaml.Loader)
|
||||
|
||||
|
||||
def _load_eval_sets(
|
||||
handler: FeatureHandler,
|
||||
config: dict,
|
||||
) -> tuple[str, str, list[np.ndarray], list[np.ndarray]]:
|
||||
for positive_mode, ambient_mode in (
|
||||
("validation", "validation_ambient"),
|
||||
("testing", "testing_ambient"),
|
||||
):
|
||||
positive_tracks, labels, _ = handler.get_data(
|
||||
positive_mode,
|
||||
batch_size=config["batch_size"],
|
||||
features_length=config["spectrogram_length"],
|
||||
truncation_strategy="none",
|
||||
)
|
||||
ambient_tracks, _, _ = handler.get_data(
|
||||
ambient_mode,
|
||||
batch_size=config["batch_size"],
|
||||
features_length=config["spectrogram_length"],
|
||||
truncation_strategy="none",
|
||||
)
|
||||
positives = [
|
||||
np.asarray(track)
|
||||
for track, label in zip(positive_tracks, labels)
|
||||
if bool(label)
|
||||
]
|
||||
ambient = [np.asarray(track) for track in ambient_tracks]
|
||||
if positives and ambient:
|
||||
return positive_mode, ambient_mode, positives, ambient
|
||||
raise RuntimeError(
|
||||
"No suitable validation/testing data was found for detector calibration."
|
||||
)
|
||||
|
||||
|
||||
def _predict_tracks(
|
||||
model: Model,
|
||||
tracks: Sequence[np.ndarray],
|
||||
label: str,
|
||||
) -> list[np.ndarray]:
|
||||
predictions: list[np.ndarray] = []
|
||||
total = len(tracks)
|
||||
print(f"→ Running streaming inference on {total} {label} track(s)")
|
||||
for index, track in enumerate(tracks, start=1):
|
||||
values = np.asarray(model.predict_spectrogram(track), dtype=np.float32)
|
||||
predictions.append(values)
|
||||
if index == total or index % 25 == 0:
|
||||
print(f" {label}: {index}/{total}")
|
||||
return predictions
|
||||
|
||||
|
||||
def main() -> int:
|
||||
args = parse_args()
|
||||
window_sizes = _parse_window_sizes(args.window_sizes)
|
||||
if args.cutoff_step <= 0:
|
||||
raise ValueError("cutoff-step must be > 0")
|
||||
if args.cutoff_max < args.cutoff_min:
|
||||
raise ValueError("cutoff-max must be >= cutoff-min")
|
||||
|
||||
config_path = Path(args.training_config)
|
||||
model_path = Path(args.model)
|
||||
output_path = Path(args.output)
|
||||
|
||||
if not config_path.exists():
|
||||
raise FileNotFoundError(f"Training config not found: {config_path}")
|
||||
if not model_path.exists():
|
||||
raise FileNotFoundError(f"Streaming TFLite model not found: {model_path}")
|
||||
|
||||
cutoffs = np.arange(
|
||||
args.cutoff_min,
|
||||
args.cutoff_max + (args.cutoff_step / 2.0),
|
||||
args.cutoff_step,
|
||||
dtype=np.float32,
|
||||
)
|
||||
cutoffs = np.clip(cutoffs, 0.0, 1.0)
|
||||
cutoffs = np.unique(np.round(cutoffs, 4))
|
||||
|
||||
print("===== Detector Calibration =====")
|
||||
print(f"→ Model: {model_path}")
|
||||
print(f"→ Training config: {config_path}")
|
||||
print(
|
||||
f"→ Evaluating window sizes {window_sizes} with target <= "
|
||||
f"{args.target_faph:.2f} false accepts/hour"
|
||||
)
|
||||
|
||||
config = _load_config(config_path)
|
||||
config["flags"] = config.get("flags", {})
|
||||
handler = FeatureHandler(config)
|
||||
|
||||
positive_mode, ambient_mode, positive_tracks, ambient_tracks = _load_eval_sets(
|
||||
handler, config
|
||||
)
|
||||
|
||||
print(
|
||||
f"→ Using {positive_mode} positives ({len(positive_tracks)}) and "
|
||||
f"{ambient_mode} ambient tracks ({len(ambient_tracks)})"
|
||||
)
|
||||
|
||||
model = Model(str(model_path), stride=config["stride"])
|
||||
positive_predictions = _predict_tracks(model, positive_tracks, "positive")
|
||||
ambient_predictions = _predict_tracks(model, ambient_tracks, "ambient")
|
||||
|
||||
candidates: list[dict[str, float]] = []
|
||||
best_by_window: list[dict[str, float]] = []
|
||||
step_seconds = config["window_step_ms"] / 1000.0
|
||||
|
||||
for window_size in window_sizes:
|
||||
ambient_averages = [
|
||||
_moving_average(track, window_size) for track in ambient_predictions
|
||||
]
|
||||
positive_maxima = []
|
||||
for track in positive_predictions:
|
||||
search = (
|
||||
track[args.positive_skip_slices :]
|
||||
if track.size > args.positive_skip_slices
|
||||
else track
|
||||
)
|
||||
averaged = _moving_average(search, window_size)
|
||||
if averaged.size == 0:
|
||||
averaged = _moving_average(track, window_size)
|
||||
positive_maxima.append(float(np.max(averaged)) if averaged.size else 0.0)
|
||||
|
||||
positive_maxima_array = np.asarray(positive_maxima, dtype=np.float32)
|
||||
recall_by_cutoff = np.mean(
|
||||
positive_maxima_array[None, :] > cutoffs[:, None], axis=1
|
||||
)
|
||||
faph_by_cutoff, ambient_hours = _compute_false_accepts_per_hour(
|
||||
ambient_averages,
|
||||
cutoffs,
|
||||
args.cooldown_slices,
|
||||
stride=config["stride"],
|
||||
step_seconds=step_seconds,
|
||||
)
|
||||
|
||||
window_candidates = []
|
||||
for cutoff, recall, faph in zip(cutoffs, recall_by_cutoff, faph_by_cutoff):
|
||||
candidate = {
|
||||
"probability_cutoff": float(round(float(cutoff), 2)),
|
||||
"sliding_window_size": int(window_size),
|
||||
"recall": float(recall),
|
||||
"false_accepts_per_hour": float(faph),
|
||||
"ambient_hours": float(ambient_hours),
|
||||
}
|
||||
candidates.append(candidate)
|
||||
window_candidates.append(candidate)
|
||||
|
||||
best_window, _ = _select_best_candidate(window_candidates, args.target_faph)
|
||||
best_by_window.append(best_window)
|
||||
print(
|
||||
" window={window}: cutoff={cutoff:.2f}; recall={recall:.2%}; "
|
||||
"ambient_faph={faph:.3f}".format(
|
||||
window=window_size,
|
||||
cutoff=best_window["probability_cutoff"],
|
||||
recall=best_window["recall"],
|
||||
faph=best_window["false_accepts_per_hour"],
|
||||
)
|
||||
)
|
||||
|
||||
best, selected_limit = _select_best_candidate(candidates, args.target_faph)
|
||||
if best["false_accepts_per_hour"] > args.target_faph + 1e-9:
|
||||
print(
|
||||
"⚠️ No candidate met the target false accepts/hour budget; "
|
||||
"using the best fallback operating point."
|
||||
)
|
||||
|
||||
print(
|
||||
"✓ Selected cutoff={cutoff:.2f}, window={window}, recall={recall:.2%}, "
|
||||
"ambient_faph={faph:.3f}".format(
|
||||
cutoff=best["probability_cutoff"],
|
||||
window=best["sliding_window_size"],
|
||||
recall=best["recall"],
|
||||
faph=best["false_accepts_per_hour"],
|
||||
)
|
||||
)
|
||||
|
||||
output = {
|
||||
"probability_cutoff": best["probability_cutoff"],
|
||||
"sliding_window_size": best["sliding_window_size"],
|
||||
"target_false_accepts_per_hour": float(args.target_faph),
|
||||
"selected_false_accepts_per_hour_limit": (
|
||||
None if math.isinf(selected_limit) else float(selected_limit)
|
||||
),
|
||||
"selected_metrics": {
|
||||
"recall": round(best["recall"], 6),
|
||||
"false_accepts_per_hour": round(best["false_accepts_per_hour"], 6),
|
||||
"ambient_hours": round(best["ambient_hours"], 6),
|
||||
},
|
||||
"evaluation": {
|
||||
"positive_dataset": positive_mode,
|
||||
"ambient_dataset": ambient_mode,
|
||||
"positive_tracks": len(positive_tracks),
|
||||
"ambient_tracks": len(ambient_tracks),
|
||||
"cooldown_slices": int(args.cooldown_slices),
|
||||
"positive_skip_slices": int(args.positive_skip_slices),
|
||||
"window_sizes": window_sizes,
|
||||
"cutoff_min": round(float(cutoffs[0]), 4),
|
||||
"cutoff_max": round(float(cutoffs[-1]), 4),
|
||||
"cutoff_step": float(args.cutoff_step),
|
||||
},
|
||||
"per_window_best": best_by_window,
|
||||
"generated_at": datetime.now(timezone.utc).isoformat(),
|
||||
}
|
||||
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
output_path.write_text(json.dumps(output, indent=2) + "\n", encoding="utf-8")
|
||||
print(f"📝 Wrote calibration to {output_path}")
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
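How the trainer consumes `detection_calibration.json` is not shown in this diff. As one possible use, the sketch below copies the two selected values into a wake-word manifest. It assumes the manifest keeps detector settings under a `micro` object with `probability_cutoff` and `sliding_window_size` keys, as ESPHome microWakeWord v2 manifests do; the function name `apply_calibration` is hypothetical.

```python
import copy


def apply_calibration(manifest: dict, calibration: dict) -> dict:
    """Return a copy of the manifest with the calibrated detector settings applied."""
    updated = copy.deepcopy(manifest)  # leave the caller's manifest untouched
    micro = updated.setdefault("micro", {})
    micro["probability_cutoff"] = float(calibration["probability_cutoff"])
    micro["sliding_window_size"] = int(calibration["sliding_window_size"])
    return updated
```

Both dictionaries would normally come from `json.loads` on the respective files, and the result would be written back with `json.dumps(updated, indent=2)`.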
|
||||
@@ -130,6 +130,73 @@ print(f" AudioSet complete ({ok} ok, {skipped} skipped, {len(audioset_bad)} fa
|
||||
EOF
|
||||
}
|
||||
|
||||
converter_from_dataset_api() {
|
||||
# shellcheck source=/dev/null
|
||||
source "${DATA_DIR}/.venv/bin/activate"
|
||||
|
||||
python - "${AUDIO16K_DIR}" <<-'EOF'
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import librosa
|
||||
import numpy as np
|
||||
import scipy.io.wavfile
|
||||
from datasets import load_dataset
|
||||
|
||||
def write_wav(dst: Path, data: np.ndarray, sr: int):
|
||||
dst.parent.mkdir(parents=True, exist_ok=True)
|
||||
x = np.clip(data, -1.0, 1.0)
|
||||
scipy.io.wavfile.write(dst, sr, (x * 32767).astype(np.int16))
|
||||
|
||||
audioset_out = Path(sys.argv[1])
|
||||
|
||||
print(" AudioSet FLAC tarballs are unavailable; using Hugging Face datasets API instead.")
|
||||
dataset = load_dataset(
|
||||
"agkphysics/AudioSet",
|
||||
"balanced",
|
||||
split="train",
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
audioset_bad = []
|
||||
ok = 0
|
||||
skipped = 0
|
||||
heartbeat_every = 250
|
||||
|
||||
for idx, sample in enumerate(dataset, start=1):
|
||||
try:
|
||||
video_id = str(sample.get("video_id") or f"audioset_{idx:06d}")
|
||||
outfile = audioset_out / f"{video_id}.wav"
|
||||
if outfile.exists():
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
audio = sample.get("audio") or {}
|
||||
y = np.asarray(audio.get("array"))
|
||||
sr = int(audio.get("sampling_rate") or 0)
|
||||
if y.size == 0 or sr <= 0:
|
||||
raise ValueError("missing decoded audio")
|
||||
if y.ndim > 1:
|
||||
y = np.mean(y, axis=-1)
|
||||
if sr != 16000:
|
||||
y = librosa.resample(y.astype(np.float32), orig_sr=sr, target_sr=16000)
|
||||
if y.size == 0:
|
||||
raise ValueError("empty audio")
|
||||
write_wav(outfile, y, 16000)
|
||||
ok += 1
|
||||
except Exception as exc:
|
||||
audioset_bad.append(f"{sample.get('video_id', idx)}:{exc}")
|
||||
|
||||
if idx == 1 or (idx % heartbeat_every) == 0:
|
||||
print(f" AudioSet API progress: {idx} clips processed (ok={ok}, skipped={skipped}, failed={len(audioset_bad)})")
|
||||
|
||||
if audioset_bad:
|
||||
(audioset_out / "audioset_corrupted_files.log").write_text("\n".join(audioset_bad))
|
||||
|
||||
print(f" AudioSet complete via datasets API ({ok} ok, {skipped} skipped, {len(audioset_bad)} failed)")
|
||||
EOF
|
||||
}
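The float-to-PCM step inside `write_wav` above can also be done with the standard library alone. The sketch below is a stand-in for the NumPy/SciPy version, not a replacement proposed by this change: it clips to `[-1.0, 1.0]`, scales by 32767, truncates toward zero as the `astype(np.int16)` cast does, and writes a mono 16-bit WAV. The helper name `float_to_int16` is introduced here for illustration.

```python
import wave
from array import array
from pathlib import Path


def float_to_int16(samples):
    """Clip float samples to [-1, 1] and scale them to signed 16-bit integers."""
    pcm = array("h")  # native-endian signed 16-bit values
    for sample in samples:
        clipped = max(-1.0, min(1.0, float(sample)))
        pcm.append(int(clipped * 32767))  # int() truncates toward zero
    return pcm


def write_wav(dst: Path, samples, sample_rate: int) -> None:
    """Write mono float samples as a 16-bit PCM WAV file."""
    dst.parent.mkdir(parents=True, exist_ok=True)
    pcm = float_to_int16(samples)
    with wave.open(str(dst), "wb") as handle:
        handle.setnchannels(1)
        handle.setsampwidth(2)
        handle.setframerate(sample_rate)
        # The wave module takes native-endian frames and handles byte order itself.
        handle.writeframes(pcm.tobytes())
```

Clipping before scaling matters because a resampled signal can overshoot slightly past full scale, and an unclipped value would wrap around when cast to 16 bits.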
|
||||
|
||||
expected_filecount=$(get_total_filecount filecounts)
|
||||
actual_filecount=$(find "${AUDIO16K_DIR}" -name "*.wav" 2>/dev/null | wc -l) || :
|
||||
write_filecount=false
|
||||
@@ -139,7 +206,10 @@ if [ "${actual_filecount}" -ne 0 ] ; then
   echo " Existing ${AUDIO16K_DIR} present (${actual_filecount} wav); skipping extract/convert"
 else
   dl=$(find_rev)
-  [ -n "$dl" ] || { echo " Could not locate an AudioSet revision with FLAC tarballs still present on HF." ; exit 1 ; }
+  if [ -z "$dl" ] ; then
+    rm -rf "${AUDIO16K_DIR}/audioset_corrupted_files.log" || :
+    converter_from_dataset_api
+  else
   rev=${dl%%,*}
   pattern=${dl##*,}

@@ -174,6 +244,7 @@ else
|
||||
echo " Converted file count(${actual_filecount}) != expected file count(${expected_filecount})" >&2
|
||||
echo " WARNING: mismatch is expected if some AudioSet files are corrupted; continuing." >&2
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
|
||||
if ${write_filecount} ; then
|
||||
|
||||
@@ -27,8 +27,11 @@ cd "${DATA_DIR}/training_datasets"

 echo "***** Checking FMA *****"

-AUDIO_URL="https://huggingface.co/datasets/mchl914/fma_xsmall/resolve/main/fma_xs.zip"
-AUDIO_ZIPFILE="fma_xs.zip"
+AUDIO_URLS=(
+  "https://os.unil.cloud.switch.ch/fma/fma_small.zip"
+  "https://huggingface.co/datasets/mchl914/fma_xsmall/resolve/main/fma_xs.zip"
+)
+AUDIO_ZIPFILE="fma_small.zip"
 AUDIO_ZIP="./downloads/${AUDIO_ZIPFILE}"
 AUDIO_DIR="fma"
 mkdir -p "${AUDIO_DIR}" || :
@@ -81,6 +84,52 @@ EOF

 }

+extract_zip_with_python() {
+  local zip_path="$1"
+  local dest_dir="$2"
+
+  "${DATA_DIR}/.venv/bin/python" - "${zip_path}" "${dest_dir}" <<-'EOF'
+import sys
+import zipfile
+from pathlib import Path
+from tqdm import tqdm
+
+zip_path = Path(sys.argv[1])
+dest_dir = Path(sys.argv[2])
+
+if (not zip_path.exists()) or zip_path.stat().st_size == 0:
+    raise SystemExit(f"Archive missing or empty: {zip_path}")
+
+with zipfile.ZipFile(zip_path, "r") as zf:
+    members = zf.infolist()
+    size_gb = zip_path.stat().st_size / (1024 ** 3)
+    print(f" Extracting {zip_path.name} ({len(members)} entries, {size_gb:.1f} GiB)...")
+    for member in tqdm(members, desc=" FMA zip extract", unit="file"):
+        zf.extract(member, dest_dir)
+EOF
+}
+
+download_with_fallbacks() {
+  local output="$1"
+  shift
+  local urls=( "$@" )
+  local rc=1
+
+  for url in "${urls[@]}" ; do
+    for attempt in 1 2 3 4 ; do
+      curl -sfL "${url}" -o "${output}" && [ -s "${output}" ] && return 0
+      rc=$?
+      rm -f "${output}" || :
+      if [ "${attempt}" -lt 4 ] ; then
+        echo " Retry ${attempt}/3 after download failure"
+        sleep $(( attempt * 2 ))
+      fi
+    done
+  done
+
+  return "${rc}"
+}
+
 expected_filecount=${filecounts[${AUDIO_ZIPFILE}]}
 actual_filecount=$(find ${AUDIO16K_DIR} -name '*.wav' 2>/dev/null | wc -l) || :
 write_filecount=false
|
||||
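The retry logic in `download_with_fallbacks` (each URL tried up to four times, sleeping 2 s, 4 s, then 6 s between attempts, before moving to the next mirror) can be sketched in Python as follows. This is only an illustration under stated assumptions: `fetch` is a hypothetical caller-supplied function that returns bytes or raises `OSError`, and the `sleep` parameter exists so the backoff can be observed without waiting.

```python
import time


def download_with_fallbacks(fetch, urls, attempts=4, sleep=time.sleep):
    """Try each URL up to `attempts` times and return (url, data) on first success.

    `fetch(url)` is an assumed helper: it returns bytes or raises OSError.
    """
    last_error = None
    for url in urls:
        for attempt in range(1, attempts + 1):
            try:
                data = fetch(url)
                if data:  # mirrors the shell's `[ -s file ]` non-empty check
                    return url, data
                last_error = ValueError(f"empty response from {url}")
            except OSError as exc:
                last_error = exc
            if attempt < attempts:
                sleep(attempt * 2)  # linear backoff: 2 s, 4 s, 6 s
    raise RuntimeError(f"all sources failed: {last_error}")


# Usage with a fake fetcher: the primary host always fails, the mirror works.
PRIMARY = "https://primary.invalid/archive.zip"
MIRROR = "https://mirror.invalid/archive.zip"
calls, delays = [], []


def fake_fetch(url):
    calls.append(url)
    if url == PRIMARY:
        raise OSError("unreachable")
    return b"zipdata"


result = download_with_fallbacks(fake_fetch, [PRIMARY, MIRROR], sleep=delays.append)
```

As in the shell version, no delay follows the last attempt on a URL; the loop moves straight to the next source.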
@@ -92,13 +141,16 @@ else
   if [ "${actual_filecount}" -eq 0 ] || [ "${actual_filecount}" -ne "${expected_filecount}" ] ; then
     if [ ! -f "${AUDIO_ZIP}" ] ; then
       echo " Downloading ${AUDIO_ZIPFILE}"
-      curl -sfL "${AUDIO_URL}" -o "${AUDIO_ZIP}"
+      download_with_fallbacks "${AUDIO_ZIP}" "${AUDIO_URLS[@]}" || {
+        echo " Failed to download ${AUDIO_ZIPFILE} from all configured sources." >&2
+        exit 1
+      }
     fi

     rm -rf "${AUDIO_DIR}" || :
     mkdir "${AUDIO_DIR}"
-    echo " Unzipping ${AUDIO_ZIPFILE}"
-    unzip -q -d "${AUDIO_DIR}" "${AUDIO_ZIP}"
+    echo " Extracting ${AUDIO_ZIPFILE}"
+    extract_zip_with_python "${AUDIO_ZIP}" "${AUDIO_DIR}"
   fi
   if "${CLEANUP_ARCHIVES}" && [ -f "${AUDIO_ZIP}" ] ; then
     echo " Cleaning up ${AUDIO_ZIPFILE}"
@@ -128,4 +180,3 @@ fi

 echo " FMA complete"
 exit 0
-
|
||||
@@ -242,29 +242,7 @@ if [ ! -s "${MODEL_FILE}.json" ] ; then
   curl -sfL "${MODEL_URL}.json" -o "${MODEL_FILE}.json"
 fi

-# --- Dutch ONNX voices (single-speaker, used with --language=nl) ---
-# Working Dutch voices: pim, ronnie (nl_NL) and nathalie (nl_BE).
-# nl_NL-mls-medium is intentionally excluded (known Piper issue: outputs gibberish).
-HF_VOICES="https://huggingface.co/rhasspy/piper-voices/resolve/main"
-declare -a NL_VOICES=(
-  "nl/nl_NL/pim/medium/nl_NL-pim-medium"
-  "nl/nl_NL/ronnie/medium/nl_NL-ronnie-medium"
-  "nl/nl_BE/nathalie/medium/nl_BE-nathalie-medium"
-)
-echo " ===== Checking Dutch Piper voices ====="
-for voice_path in "${NL_VOICES[@]}" ; do
-  voice_name="$(basename "${voice_path}")"
-  onnx_file="${VOICES_DIR}/${voice_name}.onnx"
-  json_file="${VOICES_DIR}/${voice_name}.onnx.json"
-  if [ ! -f "${onnx_file}" ] ; then
-    echo " Downloading ${voice_name}.onnx"
-    curl -sfL "${HF_VOICES}/${voice_path}.onnx?download=true" -o "${onnx_file}"
-  fi
-  if [ ! -f "${json_file}" ] ; then
-    echo " Downloading ${voice_name}.onnx.json"
-    curl -sfL "${HF_VOICES}/${voice_path}.onnx.json?download=true" -o "${json_file}"
-  fi
-done
+echo " Non-English Piper voices will be downloaded on demand for the selected language."

 ${GPU} && onnxgpu='-gpu[cuda]' || onnxgpu=""
 echo " ===== Installing onnxruntime${onnxgpu} ====="
|
||||
@@ -18,6 +18,8 @@ parser.add_argument("--output-dir", type=str, help="Wake word output dir. Defaul
 # Personal inputs/outputs (NEW)
 parser.add_argument("--personal-dir", type=str, help="Personal WAV dir. Default: <data-dir>/personal_samples", required=False)
 parser.add_argument("--personal-output-dir", type=str, help="Personal features output dir. Default: <data-dir>/work/personal_augmented_features", required=False)
+parser.add_argument("--negative-dir", type=str, help="Reviewed negative WAV dir. Default: <data-dir>/negative_samples", required=False)
+parser.add_argument("--negative-output-dir", type=str, help="Reviewed negative features output dir. Default: <data-dir>/work/reviewed_negative_features", required=False)

 # Dataset dirs
 parser.add_argument("--mit-rirs-16k-dir", type=str, help="MIT RIR input directory. Default: <data-dir>/training_datasets/mit_rirs_16k", required=False)
@@ -57,6 +59,17 @@ if not args.personal_output_dir:
 else:
     args.personal_output_dir = os.path.realpath(args.personal_output_dir)

+# Reviewed negative defaults
+if not args.negative_dir:
+    args.negative_dir = os.path.join(args.data_dir, "negative_samples")
+else:
+    args.negative_dir = os.path.realpath(args.negative_dir)
+
+if not args.negative_output_dir:
+    args.negative_output_dir = os.path.join(work_dir, "reviewed_negative_features")
+else:
+    args.negative_output_dir = os.path.realpath(args.negative_output_dir)
+
 # Dataset defaults
 if not args.mit_rirs_16k_dir:
     args.mit_rirs_16k_dir = os.path.join(args.data_dir, "training_datasets", "mit_rirs_16k")
@@ -205,7 +218,7 @@ def bind_wav_generator(clips_obj: Clips, wav_dir: str):

     clips_obj.audio_generator = types.MethodType(audio_generator_from_wavs, clips_obj)

-def generate_feature_set(input_wav_dir: str, out_root_dir: str, label: str):
+def generate_feature_set(input_wav_dir: str, out_root_dir: str, label: str, *, remove_silence: bool = True):
     files = glob.glob(os.path.join(input_wav_dir, "*.wav"))
     if not files:
         print(f"ℹ️ No WAVs found for {label} in: {input_wav_dir} (skipping)")
@@ -218,7 +231,7 @@ def generate_feature_set(input_wav_dir: str, out_root_dir: str, label: str):
         input_directory=input_wav_dir,
         file_pattern="*.wav",
         max_clip_duration_s=5,
-        remove_silence=True,
+        remove_silence=remove_silence,
         random_split_seed=10,
         split_count=0.1,
     )
@@ -263,9 +276,12 @@ def generate_feature_set(input_wav_dir: str, out_root_dir: str, label: str):
 # Wake word generated/TTS features (existing behavior)
 generate_feature_set(args.input_dir, args.output_dir, "generated")

-# Personal features (NEW)
+# Personal features
 generate_feature_set(args.personal_dir, args.personal_output_dir, "personal")

+# Reviewed false-positive / hard-negative features
+generate_feature_set(args.negative_dir, args.negative_output_dir, "reviewed negatives", remove_silence=False)
+
 END_TIME = datetime.now(timezone.utc).replace(microsecond=0)
 et = END_TIME - START_TIME
 print(f"\n{'=' * 80}")
|
||||
@@ -111,6 +111,16 @@ else
   echo "ℹ️ No personal features found at ${PERSONAL_FEATURES_DIR}/training (continuing without personal weighting)"
 fi

+# Reviewed false-positive features are optional hard negatives.
+REVIEWED_NEGATIVE_FEATURES_DIR="${WORK_DIR}/reviewed_negative_features"
+HAS_REVIEWED_NEGATIVE="false"
+if [ -d "${REVIEWED_NEGATIVE_FEATURES_DIR}/training" ] ; then
+  HAS_REVIEWED_NEGATIVE="true"
+  echo "✅ Found reviewed negative features: ${REVIEWED_NEGATIVE_FEATURES_DIR}/training (will weight as hard negatives)"
+else
+  echo "ℹ️ No reviewed negative features found at ${REVIEWED_NEGATIVE_FEATURES_DIR}/training (continuing with stock negatives)"
+fi
+
 cd "${WORK_DIR}"

 echo "===== Starting ${TRAINING_STEPS} training steps ====="
@@ -133,6 +143,7 @@ features:
     truth: true
     type: mmap
 __PERSONAL_FEATURE_MARKER__
+__REVIEWED_NEGATIVE_FEATURE_MARKER__
   - features_dir: __NEG_SPEECH__
     penalty_weight: 1.0
     sampling_weight: 12.0
@@ -208,6 +219,22 @@ else
   sed -i -e "/__PERSONAL_FEATURE_MARKER__/d" "${YAML_PATH}"
 fi

+# Insert/remove reviewed hard-negative block
+if [ "${HAS_REVIEWED_NEGATIVE}" = "true" ]; then
+  reviewed_negative_block="$(cat <<EOF
+  - features_dir: ${REVIEWED_NEGATIVE_FEATURES_DIR}
+    penalty_weight: 1.25
+    sampling_weight: 8.0
+    truncation_strategy: random
+    truth: false
+    type: mmap
+EOF
+)"
+  perl -0777 -i -pe "s#__REVIEWED_NEGATIVE_FEATURE_MARKER__#${reviewed_negative_block}#g" "${YAML_PATH}"
+else
+  sed -i -e "/__REVIEWED_NEGATIVE_FEATURE_MARKER__/d" "${YAML_PATH}"
+fi
+
 echo " Wrote training_parameters.yaml"
 rm -rf "${WORK_DIR}/trained_models/wakeword"
|
||||
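The marker handling above (replace `__REVIEWED_NEGATIVE_FEATURE_MARKER__` with a YAML block when reviewed negatives exist, otherwise delete the marker line) might be sketched in Python like this. `fill_marker` is a hypothetical helper, not part of the repository, and it assumes the marker sits alone on its own line, which is slightly stricter than the `perl` substitution in the script.

```python
def fill_marker(template, marker, block=None):
    """Replace the line holding `marker` with `block`, or drop it when block is None."""
    out = []
    for line in template.splitlines():
        if line.strip() == marker:
            if block is not None:
                out.extend(block.splitlines())
            continue  # the marker line itself never reaches the output
        out.append(line)
    return "\n".join(out) + "\n"


# Usage: the same template rendered with and without reviewed negatives.
TEMPLATE = "features:\n  - features_dir: a\n__MARKER__\n  - features_dir: b\n"
BLOCK = "  - features_dir: reviewed\n    truth: false"

with_block = fill_marker(TEMPLATE, "__MARKER__", BLOCK)
without_block = fill_marker(TEMPLATE, "__MARKER__")
```

Working line by line avoids escaping problems that can arise when a path containing the delimiter character is interpolated into a regular-expression replacement.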
@@ -317,6 +344,7 @@ fi

 TRAINING_DONE="false"

+echo "🏋️ Starting model training and TFLite export (this is the longest stage)…"
 if run_attempt "Attempt 1/3: GPU training (default runtime profile)" ; then
   echo "✅ Training complete (GPU path)."
   TRAINING_DONE="true"
@@ -386,12 +414,24 @@ if [ "${TRAINING_DONE}" != "true" ]; then
 fi

 source_path="${WORK_DIR}/trained_models/wakeword/tflite_stream_state_internal_quant/stream_state_internal_quant.tflite"
+calibration_path="${WORK_DIR}/trained_models/wakeword/tflite_stream_state_internal_quant/detection_calibration.json"

 if [ ! -f "${source_path}" ] ; then
   echo "Output model not found! Training didn't complete successfully. See ${TRAIN_LOG}"
   exit 1
 fi

+echo "🎯 Calibrating detector settings for on-device use…"
+if "${PYTHON_BIN:-python}" "${PROGDIR}/calibrate_detector.py" \
+  --training-config "${WORK_DIR}/trained_models/wakeword/training_config.yaml" \
+  --model "${source_path}" \
+  --output "${calibration_path}"; then
+  echo "✅ Detector calibration complete."
+else
+  echo "⚠️ Detector calibration failed; packaging with default detector settings."
+  rm -f "${calibration_path}" || :
+fi
+
 cp "${WORK_DIR}/trained_models/wakeword/model_summary.txt" "${OUTPUT_DIR}/logs/" || :
 cp -a "${WORK_DIR}/trained_models/wakeword/logs/train" "${OUTPUT_DIR}/logs/" || :
 cp -a "${WORK_DIR}/trained_models/wakeword/logs/validation" "${OUTPUT_DIR}/logs/" || :
@@ -404,24 +444,49 @@ tflite_path="${OUTPUT_DIR}/${tflite_filename}"
 cp "${source_path}" "${tflite_path}"

 json_path="${OUTPUT_DIR}/${wake_word_filename}.json"
-cat <<-EOF > "${json_path}"
-{
+export WAKE_WORD_TITLE LANGUAGE JSON_PATH="${json_path}" TFLITE_FILENAME="${tflite_filename}" CALIBRATION_PATH="${calibration_path}"
+echo "📦 Packaging final model artifacts…"
+"${PYTHON_BIN:-python}" - <<'PY'
+import json
+import os
+from pathlib import Path
+
+json_path = Path(os.environ["JSON_PATH"])
+calibration_path = Path(os.environ.get("CALIBRATION_PATH", ""))
+language = (os.environ.get("LANGUAGE", "en") or "en").strip().lower()
+probability_cutoff = 0.97
+sliding_window_size = 5
+
+if calibration_path.exists():
+    try:
+        calibration = json.loads(calibration_path.read_text(encoding="utf-8"))
+        probability_cutoff = float(calibration.get("probability_cutoff", probability_cutoff))
+        sliding_window_size = int(calibration.get("sliding_window_size", sliding_window_size))
+        print(
+            f"🎯 Using calibrated detector settings: "
+            f"cutoff={probability_cutoff:.2f}, window={sliding_window_size}"
+        )
+    except Exception as exc:
+        print(f"⚠️ Failed to read detector calibration ({exc}); using defaults.")
+
+meta = {
     "type": "micro",
-    "wake_word": "${WAKE_WORD_TITLE}",
+    "wake_word": os.environ["WAKE_WORD_TITLE"],
     "author": "Tater Totterson",
     "website": "https://github.com/TaterTotterson/microWakeWord-Trainer-Nvidia-Docker.git",
-    "model": "${tflite_filename}",
-    "trained_languages": ["en"],
+    "model": os.environ["TFLITE_FILENAME"],
+    "trained_languages": [language],
     "version": 2,
     "micro": {
-        "probability_cutoff": 0.97,
-        "sliding_window_size": 5,
+        "probability_cutoff": round(probability_cutoff, 2),
+        "sliding_window_size": sliding_window_size,
         "feature_step_size": 10,
         "tensor_arena_size": 30000,
-        "minimum_esphome_version": "2024.7.0"
-    }
+        "minimum_esphome_version": "2024.7.0",
+    },
 }
-EOF
+json_path.write_text(json.dumps(meta, indent=4) + "\n", encoding="utf-8")
+PY

 echo "Name: ${WAKE_WORD_TITLE}"
 echo "Model: ${tflite_path}"
|
||||
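The packaging step falls back to `probability_cutoff = 0.97` and `sliding_window_size = 5` whenever the calibration file is missing or unreadable. A minimal sketch of that fallback as a reusable function is shown below; `load_detector_settings` is a hypothetical name, and narrowing the caught exceptions is a choice made for this sketch (the script itself catches `Exception`).

```python
import json
import os
import tempfile
from pathlib import Path


def load_detector_settings(path, default_cutoff=0.97, default_window=5):
    """Return (probability_cutoff, sliding_window_size), using defaults on any problem."""
    try:
        calibration = json.loads(Path(path).read_text(encoding="utf-8"))
        cutoff = float(calibration.get("probability_cutoff", default_cutoff))
        window = int(calibration.get("sliding_window_size", default_window))
    except (OSError, ValueError, TypeError, AttributeError):
        # Missing file, invalid JSON, or values of an unexpected type.
        return default_cutoff, default_window
    return round(cutoff, 2), window


# Usage: a valid calibration file, a corrupt one, and a missing one.
with tempfile.TemporaryDirectory() as tmp:
    good = os.path.join(tmp, "detection_calibration.json")
    Path(good).write_text('{"probability_cutoff": 0.912, "sliding_window_size": 7}', encoding="utf-8")
    bad = os.path.join(tmp, "corrupt.json")
    Path(bad).write_text("{not json", encoding="utf-8")

    calibrated = load_detector_settings(good)
    from_corrupt = load_detector_settings(bad)
    from_missing = load_detector_settings(os.path.join(tmp, "absent.json"))
```

Returning both values together means a half-valid file never produces a mix of calibrated and default settings.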
@@ -6,7 +6,7 @@ ENV DEBIAN_FRONTEND=noninteractive
 # System deps
 RUN apt-get update && apt-get install -y --no-install-recommends \
     python3.12 python3.12-venv python3.12-dev python3-pip python-is-python3 \
-    git wget curl unzip ca-certificates nano less \
+    git wget curl unzip patch ca-certificates nano less \
     && rm -rf /var/lib/apt/lists/* \
     && mkdir -p /data

@@ -22,7 +22,7 @@ COPY --chown=root:root --chmod=0755 .bashrc /root/
 # Root-level entrypoints
 COPY --chown=root:root --chmod=0755 \
     train_wake_word \
-    run_recorder.sh \
+    run.sh \
     trainer_server.py \
     requirements.txt \
     /root/mww-scripts/
@@ -37,4 +37,4 @@ RUN chmod -R a+x /root/mww-scripts/cli
 COPY --chown=root:root --chmod=0644 static/index.html /root/mww-scripts/static/index.html

 # trainer server
-CMD ["/bin/bash", "-lc", "/root/mww-scripts/run_recorder.sh"]
+CMD ["/bin/bash", "-lc", "/root/mww-scripts/run.sh"]
|
||||
@@ -6,9 +6,9 @@ ROOTDIR="$(dirname "$(realpath "$0")")"
 # Training convention
 DATA_DIR="${DATA_DIR:-/data}"
 HOST="${REC_HOST:-0.0.0.0}"
-PORT="${REC_PORT:-8888}"
+PORT="${REC_PORT:-8789}"

-# Keep recorder deps separate from training venv
+# Keep trainer UI deps separate from the training venv
 VENV_DIR="${DATA_DIR}/.recorder-venv"
 PY="${VENV_DIR}/bin/python"
 PIP="${PY} -m pip"
@@ -17,6 +17,7 @@ PIN_FILE="${VENV_DIR}/.pinned_installed"
 FASTAPI_VERSION="${REC_FASTAPI_VERSION:-0.115.6}"
 UVICORN_VERSION="${REC_UVICORN_VERSION:-0.30.6}"
 PY_MULTIPART_VERSION="${REC_PY_MULTIPART_VERSION:-0.0.9}"
+ESPHOME_VERSION="${REC_ESPHOME_VERSION:-2026.4.0}"

 echo "microWakeWord Trainer UI (Docker)"
 echo "-> ROOTDIR: ${ROOTDIR}"
@@ -42,10 +43,27 @@ if [[ ! -f "${PIN_FILE}" ]]; then
   ${PIP} install \
     "fastapi==${FASTAPI_VERSION}" \
     "uvicorn[standard]==${UVICORN_VERSION}" \
-    "python-multipart==${PY_MULTIPART_VERSION}"
+    "python-multipart==${PY_MULTIPART_VERSION}" \
+    "esphome==${ESPHOME_VERSION}"
   touch "${PIN_FILE}"
 else
   echo "Reusing existing trainer UI venv (no upgrades)"
+  if ! "${PY}" - "${ESPHOME_VERSION}" <<'PY' >/dev/null 2>&1
+import importlib.metadata
+import sys
+
+expected = sys.argv[1]
+installed = importlib.metadata.version("esphome")
+raise SystemExit(0 if installed == expected else 1)
+PY
+  then
+    echo "Firmware tab dependencies missing or stale; installing ESPHome firmware dependencies"
+    ${PIP} install \
+      "fastapi==${FASTAPI_VERSION}" \
+      "uvicorn[standard]==${UVICORN_VERSION}" \
+      "python-multipart==${PY_MULTIPART_VERSION}" \
+      "esphome==${ESPHOME_VERSION}"
+  fi
 fi

 # -----------------------------
|
||||
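The inline check above treats both a missing `esphome` package (the lookup raises, so the interpreter exits non-zero) and a mismatched version as reasons to reinstall. A small sketch that makes those two cases explicit follows; `needs_install` and its `version_of` parameter are hypothetical names added here so the logic can be exercised without any particular package being installed.

```python
import importlib.metadata


def needs_install(package, expected_version, version_of=importlib.metadata.version):
    """Return True when `package` is absent or its version differs from the pin."""
    try:
        installed = version_of(package)
    except importlib.metadata.PackageNotFoundError:
        return True  # not installed at all
    return installed != expected_version


# Usage: stubbed lookups for the matching and stale cases, a real lookup for
# a package name that is assumed not to exist in the environment.
matching = needs_install("esphome", "2026.4.0", version_of=lambda name: "2026.4.0")
stale = needs_install("esphome", "2026.4.0", version_of=lambda name: "2025.1.0")
missing = needs_install("mww-sketch-package-that-does-not-exist", "1.0")
```

An exact string comparison matches the script's behaviour: any difference from the pinned version, newer or older, triggers a reinstall.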
@@ -54,6 +72,9 @@ fi
 export DATA_DIR="${DATA_DIR}"
 export STATIC_DIR="${ROOTDIR}/static"
 export PERSONAL_DIR="${DATA_DIR}/personal_samples"
+export CAPTURED_DIR="${DATA_DIR}/captured_audio"
+export NEGATIVE_DIR="${DATA_DIR}/negative_samples"
+export TRAINED_WAKE_WORDS_DIR="${DATA_DIR}/trained_wake_words"

 # IMPORTANT: leave training venv creation to /api/train inside trainer_server.py
 # but still set TRAIN_CMD so the server knows how to invoke training once ready
1959
static/index.html
File diff suppressed because it is too large
@@ -150,6 +150,7 @@ if ${CLEANUP_WORK_DIR} ; then
     "${DATA_DIR}/work/wake_word_samples" \
     "${DATA_DIR}/work/wake_word_samples_augmented" \
     "${DATA_DIR}/work/personal_augmented_features" \
+    "${DATA_DIR}/work/reviewed_negative_features" \
     "${DATA_DIR}/work/last_wake_word" || :
 fi

2019
trainer_server.py
File diff suppressed because it is too large