30 Commits
v2 ... main

Author SHA1 Message Date
MasterPhooey
874f273d0b Bump ESPHome pin to 2026.5.1 2026-06-03 10:21:23 -05:00
MasterPhooey
04249f414d Add new ReSpeaker firmware flasher templates 2026-05-19 15:49:57 -05:00
MasterPhooey
6a0d60d569 Add live wake word URL card 2026-05-19 07:42:20 -05:00
MasterPhooey
8df17599c2 Update Tater repo logo 2026-05-16 09:44:35 -05:00
MasterPhooey
280e8f8de4 Update README logo 2026-05-16 07:59:22 -05:00
Tater Totterson
b582a6cade Update Docker image name in workflow 2026-05-16 01:03:11 -05:00
MasterPhooey
196ab8c0e7 Add VAD trimming and Docker publishing 2026-05-16 00:32:05 -05:00
MasterPhooey
134f607bef 2026.4.3 2026-05-03 09:31:02 -05:00
MasterPhooey
4a9e2f2cde 2026.4.3 2026-05-03 07:55:07 -05:00
MasterPhooey
7c246856df cache update 2026-05-02 09:27:11 -05:00
MasterPhooey
3705dabc09 sat1 cache fix 2026-05-01 21:34:17 -05:00
MasterPhooey
1dcf48209f wake sound 2026-05-01 18:31:13 -05:00
MasterPhooey
4f44bef8d5 build cache 2026-05-01 18:03:37 -05:00
MasterPhooey
98fa879db1 wake sound 2026-05-01 17:01:15 -05:00
MasterPhooey
dfac549430 wake sound 2026-05-01 16:49:57 -05:00
MasterPhooey
775a78326b firmware url fixes 2026-05-01 16:24:36 -05:00
MasterPhooey
429be4cc67 502 2026-04-25 12:48:06 -05:00
Tater Totterson
2e6179ec32 Enhance README with images and link
Added additional images and a link to the README for better presentation.
2026-04-25 10:05:57 -05:00
MasterPhooey
ee6ff6e9d5 micro-opus fix 2026-04-25 08:46:34 -05:00
MasterPhooey
251b0280b6 ui fix 2026-04-25 08:22:26 -05:00
MasterPhooey
dd2bdda431 update readme 2026-04-25 08:01:13 -05:00
MasterPhooey
9d8e0afe1b update readme 2026-04-25 07:54:57 -05:00
MasterPhooey
318a4ad3b5 Sat Samples 2026-04-25 07:29:23 -05:00
Tater Totterson
51cbf6fd90 Delete __pycache__ directory 2026-04-18 11:38:33 -05:00
MasterPhooey
6e7396455a Automatic Calibration 2026-04-18 09:02:05 -05:00
Tater Totterson
2da9f7a686 Add screenshot to README for better visualization 2026-04-14 23:13:43 -05:00
MasterPhooey
7b028e4420 update readme 2026-04-14 23:07:58 -05:00
Tater Totterson
b3d9f0e369 Remove 'Recorder' from project title in README 2026-04-14 23:04:46 -05:00
Tater Totterson
240ca7682e Adjust image width in README.md
Updated image dimensions in README for better display.
2026-04-14 23:04:25 -05:00
Tater Totterson
18e5fcd000 Change image source in README.md
Updated the image source in the README file.
2026-04-14 23:02:49 -05:00
15 changed files with 6335 additions and 435 deletions

48
.github/workflows/docker-publish.yml vendored Normal file
View File

@@ -0,0 +1,48 @@
name: Publish Docker Image
on:
push:
branches:
- main
workflow_dispatch:
permissions:
contents: read
packages: write
concurrency:
group: docker-publish-${{ github.ref }}
cancel-in-progress: true
env:
REGISTRY: ghcr.io
IMAGE_NAME: tatertotterson/microwakeword
jobs:
docker:
name: Docker image
runs-on: ubuntu-latest
steps:
- name: Check out repository
uses: actions/checkout@v4
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Log in to GHCR
uses: docker/login-action@v3
with:
registry: ${{ env.REGISTRY }}
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Build and push image
uses: docker/build-push-action@v6
with:
context: .
file: dockerfile
platforms: linux/amd64
push: true
tags: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:latest
cache-from: type=gha,scope=mww-trainer-nvidia-docker
cache-to: type=gha,mode=max,scope=mww-trainer-nvidia-docker

1
.gitignore vendored
View File

@@ -1,3 +1,4 @@
personal_samples/* personal_samples/*
data/ data/
trim_history/
.DS_Store .DS_Store

326
README.md
View File

@@ -1,25 +1,19 @@
<div align="center"> <div align="center">
<h1>🎙️ microWakeWord Nvidia Trainer & Recorder</h1> <a href="https://taterassistant.com">
<img width="1002" height="593" alt="Screenshot 2026-01-18 at 8 13 35AM" src="https://github.com/user-attachments/assets/e1411d8a-8638-4df8-992b-09a46c6e5ddc" /> <img src="images/tater-repo-logo.png" alt="microWakeWord Trainer" width="460"/>
</a>
</div> </div>
<h3 align="center">
<a href="https://taterassistant.com">taterassistant.com</a>
</h3>
Train **microWakeWord** detection models using a simple **web-based recorder + trainer UI**, packaged in a Docker container. Train custom microWakeWord models in Docker with NVIDIA/CUDA acceleration, generated Piper samples, device-captured samples, reviewed false-wake negatives, live training logs, and ESPHome firmware flashing.
No Jupyter notebooks required. No manual cell execution. Just record your voice (optional) and train. Real samples come from device-captured wake audio, close misses, or manual uploads. Every saved sample is normalized to `16 kHz / mono / 16-bit PCM WAV` before training.
--- ---
<img width="100" height="44" alt="unraid_logo_black-339076895" src="https://github.com/user-attachments/assets/87351bed-3321-4a43-924f-fecf2e4e700f" /> ## Docker Image
**microWakeWord_Trainer-Nvidia** is available in the **Unraid Community Apps** store.
Install directly from the Unraid App Store with a one-click template.
---
<img width="100" height="56" alt="unraid_logo_black-339076895" src="https://github.com/user-attachments/assets/bf959585-ae13-4b4d-ae62-4202a850d35a" />
### Pull the Docker Image
```bash ```bash
docker pull ghcr.io/tatertotterson/microwakeword:latest docker pull ghcr.io/tatertotterson/microwakeword:latest
@@ -27,143 +21,227 @@ docker pull ghcr.io/tatertotterson/microwakeword:latest
--- ---
### Run the Container ## Run The Container
```bash ```bash
docker run -d \ docker run -d \
--gpus all \ --gpus all \
-p 8888:8888 \ --network host \
-e REC_PORT=8789 \
-v $(pwd):/data \ -v $(pwd):/data \
ghcr.io/tatertotterson/microwakeword:latest ghcr.io/tatertotterson/microwakeword:latest
``` ```
**What these flags do:** The flags:
- `--gpus all` → Enables GPU acceleration
- `-p 8888:8888` → Exposes the Recorder + Trainer WebUI
- `-v $(pwd):/data` → Persists all models, datasets, and cache
--- - `--gpus all` enables GPU acceleration.
- `--network host` lets the container receive mDNS/zeroconf traffic for ESPHome auto-detect.
- `-e REC_PORT=8789` sets the trainer web UI and captured-audio port. Change this value if `8789` is already in use.
- `-v $(pwd):/data` persists models, downloaded voices, datasets, samples, and firmware caches.
### Open the Recorder WebUI Host networking is recommended for the Firmware tab's mDNS device discovery. Manual IP flashing and captured-audio uploads can still work without host networking if the trainer port is reachable, but auto-detect may not see devices from Docker bridge networking.
Open your browser and go to: Open:
👉 **http://localhost:8888**
Youll see the **microWakeWord Recorder & Trainer UI**.
---
## 🎤 Recording Voice Samples (Optional)
Personal voice recordings are **optional**.
- You may **record your own voice** for better accuracy
- Or simply **click “Train” without recording anything**
If no recordings are present, training will proceed using **synthetic TTS samples only**.
### Remote systems (important)
If you are running this on a **remote PC / server**, browser-based recording will not work unless:
- You use a **reverse proxy** (HTTPS + mic permissions), **or**
- You access the UI via **localhost** on the same machine
Training itself works fine remotely — only recording requires local microphone access.
---
### 🎙️ Recording Flow
1. Enter your wake word
2. Test pronunciation with **Test TTS**
3. Choose:
- Number of speakers (e.g. family members)
- Takes per speaker (default: 10)
4. Click **Begin recording**
5. Speak naturally — recording:
- Starts when you talk
- Stops automatically after silence
6. Repeat for each speaker
Files are saved automatically to:
```
personal_samples/
speaker01_take01.wav
speaker01_take02.wav
speaker02_take01.wav
...
```
---
## 🧠 Training Behavior (Important Notes)
### ⏬ First training run
The **first time you click Train**, the system will download **large training datasets** (background noise, speech corpora, etc.).
- This can take **several minutes**
- This happens **only once**
- Data is cached inside `/data`
You **will NOT need to download these again** unless you delete `/data`.
---
### 🔁 Re-training is safe and incremental
- You can train **multiple wake words** back-to-back
- You do **NOT** need to clear any folders between runs
- Old models are preserved in timestamped output directories
- All required cleanup and reuse logic is handled automatically
---
## 📦 Output Files
When training completes, youll get:
- `<wake_word>.tflite` quantized streaming model
- `<wake_word>.json` ESPHome-compatible metadata
Both are saved under:
```text ```text
/data/output/ http://localhost:8789
``` ```
Each run is placed in its own timestamped folder. If you change `REC_PORT`, open that port instead and use the same port in the ESPHome `Trainer App URL`.
--- ---
## 🎤 Optional: Personal Voice Samples (Advanced) ## What The UI Does
If you record personal samples: - `Trainer` starts a wake-word session, shows positive/negative sample counts, and launches training.
- They are automatically augmented - `Captured Audio` reviews clips sent by ESPHome sats, including wake hits, close misses, and false wakes.
- They are **up-weighted during training** - `Samples` plays, removes, clears, and manually imports personal or negative samples.
- This significantly improves real-world accuracy - `Firmware` builds the latest `microWakeWords` ESPHome YAMLs from GitHub and flashes VoicePE or Satellite1 over OTA.
- Popup consoles show colorized training and firmware logs while long-running jobs are active.
No configuration required — detection is automatic.
--- ---
## 🔄 Resetting Everything (Optional) ## Captured Audio Workflow
If you want a **completely clean slate**: To collect samples from a sat, flash it with the Tater firmware from [TaterTotterson/microWakeWords](https://github.com/TaterTotterson/microWakeWords). The `Firmware` tab can build and flash the VoicePE or Satellite1 YAMLs directly from that repo.
Delete the /data folder After flashing, the device exposes ESPHome entities for capture setup:
Then restart the container. - `Capture Wake Audio` toggles upload of wake-word triggers.
- `Capture Close Misses` toggles upload of near misses.
- `Trainer App URL` sets the trainer address, for example `http://<trainer-ip>:8789`.
⚠️ This will: ESPHome devices can send raw captured audio to:
- Remove cached datasets
- Require re-downloading training data ```text
- Delete trained models /api/upload_captured_audio_raw
```
Keep the training app running and reachable at the `Trainer App URL` while capture is enabled. The sats upload clips live; if the app is stopped or the URL is wrong, captured audio will not be saved.
In the `Captured Audio` tab:
- play each clip from the inbox
- mark good wake-word clips as `This is good`
- mark bad triggers as `False wake`
- discard clips that should not be used
Approved clips move into:
```text
/data/personal_samples/
```
False wakes move into:
```text
/data/negative_samples/
```
Captured audio is boosted for easier playback in the UI, then kept in the correct training format.
--- ---
## 🙌 Credits ## Samples
Built on top of the excellent The `Samples` tab is the sample library.
**https://github.com/kahrendt/microWakeWord**
Huge thanks to the original authors ❤️ - `Personal` samples are positive examples of the wake word.
- `Negative` samples are reviewed false wakes or hard negatives.
- Both can be played back and removed one at a time.
- Manual upload is available here as an optional seed path.
Accepted manual upload formats include:
- WAV
- MP3
- M4A
- FLAC
- OGG
- AAC
- OPUS
- WEBM
Uploads are validated or converted with `ffmpeg` into:
```text
16 kHz / mono / 16-bit PCM WAV
```
Starting a new session does not clear samples. Use the clear buttons in `Samples` if you want to remove saved personal or negative clips.
---
## Training Flow
1. Enter the wake phrase in `Trainer`.
2. Choose the language.
3. Optionally test pronunciation with `Test TTS`.
4. Review the positive and negative sample counts.
5. Click `Start training`.
6. Watch the popup training console.
Personal samples are optional. Training can run with zero personal samples after confirmation, using generated TTS samples and the stock negative datasets.
Reviewed negative samples are converted into `/data/work/reviewed_negative_features/` and inserted into the training YAML as a hard-negative feature set when present.
---
## Language Support
The language picker is dynamic.
- `en` is always available.
- English keeps the existing dedicated generator model path.
- Non-English languages are discovered from the Piper voices catalog and any local Piper voice metadata.
- When a non-English language is selected, the trainer downloads all voices for that selected language only.
- Already-downloaded voices are reused.
- It does not download every language up front.
If the upstream Piper catalog is unavailable, already-installed local voices are used when available.
---
## Dataset Behavior
The first training run downloads and prepares missing training assets into `/data`, including:
- Piper voices for the selected language
- negative datasets and background data
- the Python training environment
- generated samples and augmented feature caches
After those assets are prepared, later runs reuse the local copies unless the mounted `/data` contents are deleted.
---
## Firmware Flashing
The `Firmware` tab builds and flashes Tater firmware for supported ESPHome sats.
- Downloads the latest firmware YAML templates from `TaterTotterson/microWakeWords` on GitHub.
- Lets you choose `VoicePE` or `Satellite1`.
- Auto-detects ESPHome devices with mDNS when the container is running with host networking.
- Allows manual IP or hostname entry if discovery does not find the device.
- Saves firmware form values so you do not re-enter sounds and URLs every run.
- Lists locally trained wake words from `/data/trained_wake_words/` for easy model selection.
- Builds with ESPHome and flashes OTA.
- Streams ESPHome output in a colorized firmware console.
Firmware YAMLs are intentionally pulled from GitHub each time. There is no local fallback path in the trainer UI.
---
## Output Files
Successful runs produce timestamped training output folders such as:
```text
/data/output/<timestamp>-<wake_word>-<samples>-<steps>/<wake_word>.tflite
/data/output/<timestamp>-<wake_word>-<samples>-<steps>/<wake_word>.json
```
The trainer also syncs firmware-ready artifacts into:
```text
/data/trained_wake_words/<wake_word>.tflite
/data/trained_wake_words/<wake_word>.json
```
The firmware tab uses `/data/trained_wake_words/` to populate the wake-word dropdown.
---
## Resetting Everything
If you want a clean slate, stop the container and remove the contents of the mounted `/data` directory.
That removes:
- personal samples
- negative samples
- captured inbox clips
- downloaded Piper voices
- cached datasets
- training environments
- trained models
- firmware build caches
---
## Important Notes
- Personal samples are optional.
- Negative samples are optional but useful for reducing false wakes.
- The UI server is `trainer_server.py`.
- The launcher is `run.sh`.
- Firmware capture settings live on the ESPHome device and can be toggled from the device entities after flashing.
---
## Credits
Built on top of:
- [microWakeWord](https://github.com/kahrendt/microWakeWord)
- [piper-sample-generator](https://github.com/rhasspy/piper-sample-generator)

405
cli/calibrate_detector.py Normal file
View File

@@ -0,0 +1,405 @@
#!/usr/bin/env python3
"""Choose detector metadata that better matches the trained model."""
from __future__ import annotations
import argparse
import json
import math
import os
from datetime import datetime, timezone
from pathlib import Path
from typing import Iterable, Sequence
import numpy as np
import yaml
from microwakeword.data import FeatureHandler
from microwakeword.inference import Model
DEFAULT_WINDOW_SIZES = [3, 4, 5, 6, 7]
DEFAULT_TARGET_FAPH = float(os.environ.get("MWW_CALIBRATION_TARGET_FAPH", "1.0"))
DEFAULT_COOLDOWN_SLICES = int(os.environ.get("MWW_CALIBRATION_COOLDOWN_SLICES", "25"))
DEFAULT_POSITIVE_SKIP_SLICES = int(
os.environ.get("MWW_CALIBRATION_POSITIVE_SKIP_SLICES", "25")
)
DEFAULT_CUTOFF_STEP = float(os.environ.get("MWW_CALIBRATION_CUTOFF_STEP", "0.01"))
DEFAULT_CUTOFF_MIN = float(os.environ.get("MWW_CALIBRATION_CUTOFF_MIN", "0.00"))
DEFAULT_CUTOFF_MAX = float(os.environ.get("MWW_CALIBRATION_CUTOFF_MAX", "1.00"))
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Calibrate microWakeWord detector metadata from validation data."
)
parser.add_argument(
"--training-config",
default="trained_models/wakeword/training_config.yaml",
help="Path to the saved microWakeWord training_config.yaml file.",
)
parser.add_argument(
"--model",
default=(
"trained_models/wakeword/tflite_stream_state_internal_quant/"
"stream_state_internal_quant.tflite"
),
help="Path to the quantized streaming TFLite model.",
)
parser.add_argument(
"--output",
default=(
"trained_models/wakeword/tflite_stream_state_internal_quant/"
"detection_calibration.json"
),
help="Where to write the selected detector settings as JSON.",
)
parser.add_argument(
"--window-sizes",
default=",".join(str(value) for value in DEFAULT_WINDOW_SIZES),
help="Comma-separated sliding window sizes to evaluate.",
)
parser.add_argument(
"--target-faph",
type=float,
default=DEFAULT_TARGET_FAPH,
help="Target ambient false accepts per hour for the selected operating point.",
)
parser.add_argument(
"--cooldown-slices",
type=int,
default=DEFAULT_COOLDOWN_SLICES,
help="Cooldown slices to use when estimating false accepts per hour.",
)
parser.add_argument(
"--positive-skip-slices",
type=int,
default=DEFAULT_POSITIVE_SKIP_SLICES,
help="Initial streaming slices to ignore when scoring positive examples.",
)
parser.add_argument(
"--cutoff-step",
type=float,
default=DEFAULT_CUTOFF_STEP,
help="Cutoff increment to evaluate between cutoff-min and cutoff-max.",
)
parser.add_argument(
"--cutoff-min",
type=float,
default=DEFAULT_CUTOFF_MIN,
help="Minimum cutoff to evaluate.",
)
parser.add_argument(
"--cutoff-max",
type=float,
default=DEFAULT_CUTOFF_MAX,
help="Maximum cutoff to evaluate.",
)
return parser.parse_args()
def _parse_window_sizes(raw: str) -> list[int]:
values = []
for item in (raw or "").split(","):
item = item.strip()
if not item:
continue
value = int(item)
if value < 1:
raise ValueError("window sizes must be >= 1")
values.append(value)
if not values:
raise ValueError("at least one window size is required")
return sorted(set(values))
def _moving_average(values: Sequence[float], window_size: int) -> np.ndarray:
array = np.asarray(values, dtype=np.float32)
if array.size == 0:
return array
if window_size <= 1:
return array
if array.size < window_size:
return np.asarray([float(array.mean())], dtype=np.float32)
cumsum = np.cumsum(np.insert(array, 0, 0.0))
averaged = (cumsum[window_size:] - cumsum[:-window_size]) / float(window_size)
return averaged.astype(np.float32)
def _compute_false_accepts_per_hour(
probabilities_per_track: Iterable[np.ndarray],
cutoffs: np.ndarray,
cooldown_slices: int,
stride: int,
step_seconds: float,
) -> tuple[np.ndarray, float]:
cutoffs = np.asarray(cutoffs, dtype=np.float32)
false_accepts = np.zeros(cutoffs.shape[0], dtype=np.float64)
duration_hours = 0.0
for track_probabilities in probabilities_per_track:
if track_probabilities.size == 0:
continue
duration_hours += (
len(track_probabilities) * stride * step_seconds / 3600.0
)
cooldown = np.full(cutoffs.shape[0], cooldown_slices, dtype=np.int32)
for probability in track_probabilities:
cooldown = np.maximum(cooldown - 1, 0)
accepted = (cooldown == 0) & (probability > cutoffs)
false_accepts += accepted.astype(np.float64)
cooldown = np.where(accepted, cooldown_slices, cooldown)
if duration_hours <= 0:
return np.full(cutoffs.shape[0], math.inf, dtype=np.float64), 0.0
return false_accepts / duration_hours, duration_hours
def _select_best_candidate(
candidates: list[dict[str, float]],
target_faph: float,
) -> tuple[dict[str, float], float]:
fallback_limits = [
target_faph,
max(target_faph * 2.0, target_faph + 0.5),
max(target_faph * 4.0, 2.0),
]
def tier(candidate: dict[str, float]) -> int:
for index, limit in enumerate(fallback_limits):
if candidate["false_accepts_per_hour"] <= limit + 1e-9:
return index
return len(fallback_limits)
best = min(
candidates,
key=lambda candidate: (
tier(candidate),
-candidate["recall"],
candidate["false_accepts_per_hour"],
abs(candidate["sliding_window_size"] - 5),
-candidate["probability_cutoff"],
),
)
tier_index = tier(best)
if tier_index < len(fallback_limits):
return best, fallback_limits[tier_index]
return best, float("inf")
def _load_config(config_path: Path) -> dict:
with config_path.open("r", encoding="utf-8") as handle:
return yaml.load(handle.read(), Loader=yaml.Loader)
def _load_eval_sets(
handler: FeatureHandler,
config: dict,
) -> tuple[str, str, list[np.ndarray], list[np.ndarray]]:
for positive_mode, ambient_mode in (
("validation", "validation_ambient"),
("testing", "testing_ambient"),
):
positive_tracks, labels, _ = handler.get_data(
positive_mode,
batch_size=config["batch_size"],
features_length=config["spectrogram_length"],
truncation_strategy="none",
)
ambient_tracks, _, _ = handler.get_data(
ambient_mode,
batch_size=config["batch_size"],
features_length=config["spectrogram_length"],
truncation_strategy="none",
)
positives = [
np.asarray(track)
for track, label in zip(positive_tracks, labels)
if bool(label)
]
ambient = [np.asarray(track) for track in ambient_tracks]
if positives and ambient:
return positive_mode, ambient_mode, positives, ambient
raise RuntimeError(
"No suitable validation/testing data was found for detector calibration."
)
def _predict_tracks(
model: Model,
tracks: Sequence[np.ndarray],
label: str,
) -> list[np.ndarray]:
predictions: list[np.ndarray] = []
total = len(tracks)
print(f"→ Running streaming inference on {total} {label} track(s)")
for index, track in enumerate(tracks, start=1):
values = np.asarray(model.predict_spectrogram(track), dtype=np.float32)
predictions.append(values)
if index == total or index % 25 == 0:
print(f" {label}: {index}/{total}")
return predictions
def main() -> int:
args = parse_args()
window_sizes = _parse_window_sizes(args.window_sizes)
if args.cutoff_step <= 0:
raise ValueError("cutoff-step must be > 0")
if args.cutoff_max < args.cutoff_min:
raise ValueError("cutoff-max must be >= cutoff-min")
config_path = Path(args.training_config)
model_path = Path(args.model)
output_path = Path(args.output)
if not config_path.exists():
raise FileNotFoundError(f"Training config not found: {config_path}")
if not model_path.exists():
raise FileNotFoundError(f"Streaming TFLite model not found: {model_path}")
cutoffs = np.arange(
args.cutoff_min,
args.cutoff_max + (args.cutoff_step / 2.0),
args.cutoff_step,
dtype=np.float32,
)
cutoffs = np.clip(cutoffs, 0.0, 1.0)
cutoffs = np.unique(np.round(cutoffs, 4))
print("===== Detector Calibration =====")
print(f"→ Model: {model_path}")
print(f"→ Training config: {config_path}")
print(
f"→ Evaluating window sizes {window_sizes} with target <= "
f"{args.target_faph:.2f} false accepts/hour"
)
config = _load_config(config_path)
config["flags"] = config.get("flags", {})
handler = FeatureHandler(config)
positive_mode, ambient_mode, positive_tracks, ambient_tracks = _load_eval_sets(
handler, config
)
print(
f"→ Using {positive_mode} positives ({len(positive_tracks)}) and "
f"{ambient_mode} ambient tracks ({len(ambient_tracks)})"
)
model = Model(str(model_path), stride=config["stride"])
positive_predictions = _predict_tracks(model, positive_tracks, "positive")
ambient_predictions = _predict_tracks(model, ambient_tracks, "ambient")
candidates: list[dict[str, float]] = []
best_by_window: list[dict[str, float]] = []
step_seconds = config["window_step_ms"] / 1000.0
for window_size in window_sizes:
ambient_averages = [
_moving_average(track, window_size) for track in ambient_predictions
]
positive_maxima = []
for track in positive_predictions:
search = (
track[args.positive_skip_slices :]
if track.size > args.positive_skip_slices
else track
)
averaged = _moving_average(search, window_size)
if averaged.size == 0:
averaged = _moving_average(track, window_size)
positive_maxima.append(float(np.max(averaged)) if averaged.size else 0.0)
positive_maxima_array = np.asarray(positive_maxima, dtype=np.float32)
recall_by_cutoff = np.mean(
positive_maxima_array[None, :] > cutoffs[:, None], axis=1
)
faph_by_cutoff, ambient_hours = _compute_false_accepts_per_hour(
ambient_averages,
cutoffs,
args.cooldown_slices,
stride=config["stride"],
step_seconds=step_seconds,
)
window_candidates = []
for cutoff, recall, faph in zip(cutoffs, recall_by_cutoff, faph_by_cutoff):
candidate = {
"probability_cutoff": float(round(float(cutoff), 2)),
"sliding_window_size": int(window_size),
"recall": float(recall),
"false_accepts_per_hour": float(faph),
"ambient_hours": float(ambient_hours),
}
candidates.append(candidate)
window_candidates.append(candidate)
best_window, _ = _select_best_candidate(window_candidates, args.target_faph)
best_by_window.append(best_window)
print(
" window={window}: cutoff={cutoff:.2f}; recall={recall:.2%}; "
"ambient_faph={faph:.3f}".format(
window=window_size,
cutoff=best_window["probability_cutoff"],
recall=best_window["recall"],
faph=best_window["false_accepts_per_hour"],
)
)
best, selected_limit = _select_best_candidate(candidates, args.target_faph)
if best["false_accepts_per_hour"] > args.target_faph + 1e-9:
print(
"⚠️ No candidate met the target false accepts/hour budget; "
"using the best fallback operating point."
)
print(
"✓ Selected cutoff={cutoff:.2f}, window={window}, recall={recall:.2%}, "
"ambient_faph={faph:.3f}".format(
cutoff=best["probability_cutoff"],
window=best["sliding_window_size"],
recall=best["recall"],
faph=best["false_accepts_per_hour"],
)
)
output = {
"probability_cutoff": best["probability_cutoff"],
"sliding_window_size": best["sliding_window_size"],
"target_false_accepts_per_hour": float(args.target_faph),
"selected_false_accepts_per_hour_limit": (
None if math.isinf(selected_limit) else float(selected_limit)
),
"selected_metrics": {
"recall": round(best["recall"], 6),
"false_accepts_per_hour": round(best["false_accepts_per_hour"], 6),
"ambient_hours": round(best["ambient_hours"], 6),
},
"evaluation": {
"positive_dataset": positive_mode,
"ambient_dataset": ambient_mode,
"positive_tracks": len(positive_tracks),
"ambient_tracks": len(ambient_tracks),
"cooldown_slices": int(args.cooldown_slices),
"positive_skip_slices": int(args.positive_skip_slices),
"window_sizes": window_sizes,
"cutoff_min": round(float(cutoffs[0]), 4),
"cutoff_max": round(float(cutoffs[-1]), 4),
"cutoff_step": float(args.cutoff_step),
},
"per_window_best": best_by_window,
"generated_at": datetime.now(timezone.utc).isoformat(),
}
output_path.parent.mkdir(parents=True, exist_ok=True)
output_path.write_text(json.dumps(output, indent=2) + "\n", encoding="utf-8")
print(f"📝 Wrote calibration to {output_path}")
return 0
if __name__ == "__main__":
raise SystemExit(main())

View File

@@ -130,6 +130,73 @@ print(f" AudioSet complete ({ok} ok, {skipped} skipped, {len(audioset_bad)} fa
EOF EOF
} }
converter_from_dataset_api() {
# shellcheck source=/dev/null
source "${DATA_DIR}/.venv/bin/activate"
python - "${AUDIO16K_DIR}" <<-'EOF'
import sys
from pathlib import Path
import librosa
import numpy as np
import scipy.io.wavfile
from datasets import load_dataset
def write_wav(dst: Path, data: np.ndarray, sr: int):
dst.parent.mkdir(parents=True, exist_ok=True)
x = np.clip(data, -1.0, 1.0)
scipy.io.wavfile.write(dst, sr, (x * 32767).astype(np.int16))
audioset_out = Path(sys.argv[1])
print(" AudioSet FLAC tarballs are unavailable; using Hugging Face datasets API instead.")
dataset = load_dataset(
"agkphysics/AudioSet",
"balanced",
split="train",
streaming=True,
)
audioset_bad = []
ok = 0
skipped = 0
heartbeat_every = 250
for idx, sample in enumerate(dataset, start=1):
try:
video_id = str(sample.get("video_id") or f"audioset_{idx:06d}")
outfile = audioset_out / f"{video_id}.wav"
if outfile.exists():
skipped += 1
continue
audio = sample.get("audio") or {}
y = np.asarray(audio.get("array"))
sr = int(audio.get("sampling_rate") or 0)
if y.size == 0 or sr <= 0:
raise ValueError("missing decoded audio")
if y.ndim > 1:
y = np.mean(y, axis=-1)
if sr != 16000:
y = librosa.resample(y.astype(np.float32), orig_sr=sr, target_sr=16000)
if y.size == 0:
raise ValueError("empty audio")
write_wav(outfile, y, 16000)
ok += 1
except Exception as exc:
audioset_bad.append(f"{sample.get('video_id', idx)}:{exc}")
if idx == 1 or (idx % heartbeat_every) == 0:
print(f" AudioSet API progress: {idx} clips processed (ok={ok}, skipped={skipped}, failed={len(audioset_bad)})")
if audioset_bad:
(audioset_out / "audioset_corrupted_files.log").write_text("\n".join(audioset_bad))
print(f" AudioSet complete via datasets API ({ok} ok, {skipped} skipped, {len(audioset_bad)} failed)")
EOF
}
expected_filecount=$(get_total_filecount filecounts) expected_filecount=$(get_total_filecount filecounts)
actual_filecount=$(find "${AUDIO16K_DIR}" -name "*.wav" 2>/dev/null | wc -l) || : actual_filecount=$(find "${AUDIO16K_DIR}" -name "*.wav" 2>/dev/null | wc -l) || :
write_filecount=false write_filecount=false
@@ -139,40 +206,44 @@ if [ "${actual_filecount}" -ne 0 ] ; then
echo " Existing ${AUDIO16K_DIR} present (${actual_filecount} wav); skipping extract/convert" echo " Existing ${AUDIO16K_DIR} present (${actual_filecount} wav); skipping extract/convert"
else else
dl=$(find_rev) dl=$(find_rev)
[ -n "$dl" ] || { echo " Could not locate an AudioSet revision with FLAC tarballs still present on HF." ; exit 1 ; } if [ -z "$dl" ] ; then
rev=${dl%%,*} rm -rf "${AUDIO16K_DIR}/audioset_corrupted_files.log" || :
pattern=${dl##*,} converter_from_dataset_api
else
rev=${dl%%,*}
pattern=${dl##*,}
echo " Checking 10 tarballs" echo " Checking 10 tarballs"
for i in {0..9} ; do for i in {0..9} ; do
fname="downloads/bal_train0${i}.tar" fname="downloads/bal_train0${i}.tar"
if [ ! -f "${fname}" ] ; then if [ ! -f "${fname}" ] ; then
echo " Downloading bal_train0${i}.tar" echo " Downloading bal_train0${i}.tar"
url="${AUDIO_URL}/${rev}/${pattern}${i}.tar" url="${AUDIO_URL}/${rev}/${pattern}${i}.tar"
curl -L -s --fail "${url}" -o "${fname}" || { echo "Could not fetch ${fname} at rev ${rev}; continuing." ; continue ; } curl -L -s --fail "${url}" -o "${fname}" || { echo "Could not fetch ${fname} at rev ${rev}; continuing." ; continue ; }
fi
tarball_filecount=$(tar -tvf "${fname}" | wc -l )
filecounts["bal_train0${i}.tar"]=${tarball_filecount}
write_filecount=true
echo " Untarring bal_train0${i}.tar"
tar -xf "${fname}" -C "${AUDIO_DIR}"
if "${CLEANUP_ARCHIVES}" && [ -f "${fname}" ] ; then
echo " Cleaning up bal_train0${i}.tar"
rm -rf "${fname}"
fi
done
rm -rf "${AUDIO16K_DIR}/audioset_corrupted_files.log" || :
converter
# Recompute counts and warn (but do not fail)
expected_filecount=$(get_total_filecount filecounts)
actual_filecount=$(find "${AUDIO16K_DIR}" -name "*.wav" 2>/dev/null | wc -l) || :
if [ "${actual_filecount}" -ne "${expected_filecount}" ] ; then
echo " Converted file count(${actual_filecount}) != expected file count(${expected_filecount})" >&2
echo " WARNING: mismatch is expected if some AudioSet files are corrupted; continuing." >&2
fi fi
tarball_filecount=$(tar -tvf "${fname}" | wc -l )
filecounts["bal_train0${i}.tar"]=${tarball_filecount}
write_filecount=true
echo " Untarring bal_train0${i}.tar"
tar -xf "${fname}" -C "${AUDIO_DIR}"
if "${CLEANUP_ARCHIVES}" && [ -f "${fname}" ] ; then
echo " Cleaning up bal_train0${i}.tar"
rm -rf "${fname}"
fi
done
rm -rf "${AUDIO16K_DIR}/audioset_corrupted_files.log" || :
converter
# Recompute counts and warn (but do not fail)
expected_filecount=$(get_total_filecount filecounts)
actual_filecount=$(find "${AUDIO16K_DIR}" -name "*.wav" 2>/dev/null | wc -l) || :
if [ "${actual_filecount}" -ne "${expected_filecount}" ] ; then
echo " Converted file count(${actual_filecount}) != expected file count(${expected_filecount})" >&2
echo " WARNING: mismatch is expected if some AudioSet files are corrupted; continuing." >&2
fi fi
fi fi

View File

@@ -27,8 +27,11 @@ cd "${DATA_DIR}/training_datasets"
echo "***** Checking FMA *****" echo "***** Checking FMA *****"
AUDIO_URL="https://huggingface.co/datasets/mchl914/fma_xsmall/resolve/main/fma_xs.zip" AUDIO_URLS=(
AUDIO_ZIPFILE="fma_xs.zip" "https://os.unil.cloud.switch.ch/fma/fma_small.zip"
"https://huggingface.co/datasets/mchl914/fma_xsmall/resolve/main/fma_xs.zip"
)
AUDIO_ZIPFILE="fma_small.zip"
AUDIO_ZIP="./downloads/${AUDIO_ZIPFILE}" AUDIO_ZIP="./downloads/${AUDIO_ZIPFILE}"
AUDIO_DIR="fma" AUDIO_DIR="fma"
mkdir -p "${AUDIO_DIR}" || : mkdir -p "${AUDIO_DIR}" || :
@@ -81,6 +84,52 @@ EOF
} }
extract_zip_with_python() {
local zip_path="$1"
local dest_dir="$2"
"${DATA_DIR}/.venv/bin/python" - "${zip_path}" "${dest_dir}" <<-'EOF'
import sys
import zipfile
from pathlib import Path
from tqdm import tqdm
zip_path = Path(sys.argv[1])
dest_dir = Path(sys.argv[2])
if (not zip_path.exists()) or zip_path.stat().st_size == 0:
raise SystemExit(f"Archive missing or empty: {zip_path}")
with zipfile.ZipFile(zip_path, "r") as zf:
members = zf.infolist()
size_gb = zip_path.stat().st_size / (1024 ** 3)
print(f" Extracting {zip_path.name} ({len(members)} entries, {size_gb:.1f} GiB)...")
for member in tqdm(members, desc=" FMA zip extract", unit="file"):
zf.extract(member, dest_dir)
EOF
}
download_with_fallbacks() {
local output="$1"
shift
local urls=( "$@" )
local rc=1
for url in "${urls[@]}" ; do
for attempt in 1 2 3 4 ; do
curl -sfL "${url}" -o "${output}" && [ -s "${output}" ] && return 0
rc=$?
rm -f "${output}" || :
if [ "${attempt}" -lt 4 ] ; then
echo " Retry ${attempt}/3 after download failure"
sleep $(( attempt * 2 ))
fi
done
done
return "${rc}"
}
expected_filecount=${filecounts[${AUDIO_ZIPFILE}]} expected_filecount=${filecounts[${AUDIO_ZIPFILE}]}
actual_filecount=$(find ${AUDIO16K_DIR} -name '*.wav' 2>/dev/null | wc -l) || : actual_filecount=$(find ${AUDIO16K_DIR} -name '*.wav' 2>/dev/null | wc -l) || :
write_filecount=false write_filecount=false
@@ -92,13 +141,16 @@ else
if [ "${actual_filecount}" -eq 0 ] || [ "${actual_filecount}" -ne "${expected_filecount}" ] ; then if [ "${actual_filecount}" -eq 0 ] || [ "${actual_filecount}" -ne "${expected_filecount}" ] ; then
if [ ! -f "${AUDIO_ZIP}" ] ; then if [ ! -f "${AUDIO_ZIP}" ] ; then
echo " Downloading ${AUDIO_ZIPFILE}" echo " Downloading ${AUDIO_ZIPFILE}"
curl -sfL "${AUDIO_URL}" -o "${AUDIO_ZIP}" download_with_fallbacks "${AUDIO_ZIP}" "${AUDIO_URLS[@]}" || {
echo " Failed to download ${AUDIO_ZIPFILE} from all configured sources." >&2
exit 1
}
fi fi
rm -rf "${AUDIO_DIR}" || : rm -rf "${AUDIO_DIR}" || :
mkdir "${AUDIO_DIR}" mkdir "${AUDIO_DIR}"
echo " Unzipping ${AUDIO_ZIPFILE}" echo " Extracting ${AUDIO_ZIPFILE}"
unzip -q -d "${AUDIO_DIR}" "${AUDIO_ZIP}" extract_zip_with_python "${AUDIO_ZIP}" "${AUDIO_DIR}"
fi fi
if "${CLEANUP_ARCHIVES}" && [ -f "${AUDIO_ZIP}" ] ; then if "${CLEANUP_ARCHIVES}" && [ -f "${AUDIO_ZIP}" ] ; then
echo " Cleaning up ${AUDIO_ZIPFILE}" echo " Cleaning up ${AUDIO_ZIPFILE}"
@@ -128,4 +180,3 @@ fi
echo " FMA complete" echo " FMA complete"
exit 0 exit 0

View File

@@ -242,29 +242,7 @@ if [ ! -s "${MODEL_FILE}.json" ] ; then
curl -sfL "${MODEL_URL}.json" -o "${MODEL_FILE}.json" curl -sfL "${MODEL_URL}.json" -o "${MODEL_FILE}.json"
fi fi
# --- Dutch ONNX voices (single-speaker, used with --language=nl) --- echo " Non-English Piper voices will be downloaded on demand for the selected language."
# Working Dutch voices: pim, ronnie (nl_NL) and nathalie (nl_BE).
# nl_NL-mls-medium is intentionally excluded (known Piper issue: outputs gibberish).
HF_VOICES="https://huggingface.co/rhasspy/piper-voices/resolve/main"
declare -a NL_VOICES=(
"nl/nl_NL/pim/medium/nl_NL-pim-medium"
"nl/nl_NL/ronnie/medium/nl_NL-ronnie-medium"
"nl/nl_BE/nathalie/medium/nl_BE-nathalie-medium"
)
echo " ===== Checking Dutch Piper voices ====="
for voice_path in "${NL_VOICES[@]}" ; do
voice_name="$(basename "${voice_path}")"
onnx_file="${VOICES_DIR}/${voice_name}.onnx"
json_file="${VOICES_DIR}/${voice_name}.onnx.json"
if [ ! -f "${onnx_file}" ] ; then
echo " Downloading ${voice_name}.onnx"
curl -sfL "${HF_VOICES}/${voice_path}.onnx?download=true" -o "${onnx_file}"
fi
if [ ! -f "${json_file}" ] ; then
echo " Downloading ${voice_name}.onnx.json"
curl -sfL "${HF_VOICES}/${voice_path}.onnx.json?download=true" -o "${json_file}"
fi
done
${GPU} && onnxgpu='-gpu[cuda]' || onnxgpu="" ${GPU} && onnxgpu='-gpu[cuda]' || onnxgpu=""
echo " ===== Installing onnxruntime${onnxgpu} =====" echo " ===== Installing onnxruntime${onnxgpu} ====="

View File

@@ -18,6 +18,8 @@ parser.add_argument("--output-dir", type=str, help="Wake word output dir. Defaul
# Personal inputs/outputs (NEW) # Personal inputs/outputs (NEW)
parser.add_argument("--personal-dir", type=str, help="Personal WAV dir. Default: <data-dir>/personal_samples", required=False) parser.add_argument("--personal-dir", type=str, help="Personal WAV dir. Default: <data-dir>/personal_samples", required=False)
parser.add_argument("--personal-output-dir", type=str, help="Personal features output dir. Default: <data-dir>/work/personal_augmented_features", required=False) parser.add_argument("--personal-output-dir", type=str, help="Personal features output dir. Default: <data-dir>/work/personal_augmented_features", required=False)
parser.add_argument("--negative-dir", type=str, help="Reviewed negative WAV dir. Default: <data-dir>/negative_samples", required=False)
parser.add_argument("--negative-output-dir", type=str, help="Reviewed negative features output dir. Default: <data-dir>/work/reviewed_negative_features", required=False)
# Dataset dirs # Dataset dirs
parser.add_argument("--mit-rirs-16k-dir", type=str, help="MIT RIR input directory. Default: <data-dir>/training_datasets/mit_rirs_16k", required=False) parser.add_argument("--mit-rirs-16k-dir", type=str, help="MIT RIR input directory. Default: <data-dir>/training_datasets/mit_rirs_16k", required=False)
@@ -57,6 +59,17 @@ if not args.personal_output_dir:
else: else:
args.personal_output_dir = os.path.realpath(args.personal_output_dir) args.personal_output_dir = os.path.realpath(args.personal_output_dir)
# Reviewed negative defaults
if not args.negative_dir:
args.negative_dir = os.path.join(args.data_dir, "negative_samples")
else:
args.negative_dir = os.path.realpath(args.negative_dir)
if not args.negative_output_dir:
args.negative_output_dir = os.path.join(work_dir, "reviewed_negative_features")
else:
args.negative_output_dir = os.path.realpath(args.negative_output_dir)
# Dataset defaults # Dataset defaults
if not args.mit_rirs_16k_dir: if not args.mit_rirs_16k_dir:
args.mit_rirs_16k_dir = os.path.join(args.data_dir, "training_datasets", "mit_rirs_16k") args.mit_rirs_16k_dir = os.path.join(args.data_dir, "training_datasets", "mit_rirs_16k")
@@ -205,7 +218,7 @@ def bind_wav_generator(clips_obj: Clips, wav_dir: str):
clips_obj.audio_generator = types.MethodType(audio_generator_from_wavs, clips_obj) clips_obj.audio_generator = types.MethodType(audio_generator_from_wavs, clips_obj)
def generate_feature_set(input_wav_dir: str, out_root_dir: str, label: str): def generate_feature_set(input_wav_dir: str, out_root_dir: str, label: str, *, remove_silence: bool = True):
files = glob.glob(os.path.join(input_wav_dir, "*.wav")) files = glob.glob(os.path.join(input_wav_dir, "*.wav"))
if not files: if not files:
print(f" No WAVs found for {label} in: {input_wav_dir} (skipping)") print(f" No WAVs found for {label} in: {input_wav_dir} (skipping)")
@@ -218,7 +231,7 @@ def generate_feature_set(input_wav_dir: str, out_root_dir: str, label: str):
input_directory=input_wav_dir, input_directory=input_wav_dir,
file_pattern="*.wav", file_pattern="*.wav",
max_clip_duration_s=5, max_clip_duration_s=5,
remove_silence=True, remove_silence=remove_silence,
random_split_seed=10, random_split_seed=10,
split_count=0.1, split_count=0.1,
) )
@@ -263,9 +276,12 @@ def generate_feature_set(input_wav_dir: str, out_root_dir: str, label: str):
# Wake word generated/TTS features (existing behavior) # Wake word generated/TTS features (existing behavior)
generate_feature_set(args.input_dir, args.output_dir, "generated") generate_feature_set(args.input_dir, args.output_dir, "generated")
# Personal features (NEW) # Personal features
generate_feature_set(args.personal_dir, args.personal_output_dir, "personal") generate_feature_set(args.personal_dir, args.personal_output_dir, "personal")
# Reviewed false-positive / hard-negative features
generate_feature_set(args.negative_dir, args.negative_output_dir, "reviewed negatives", remove_silence=False)
END_TIME = datetime.now(timezone.utc).replace(microsecond=0) END_TIME = datetime.now(timezone.utc).replace(microsecond=0)
et = END_TIME - START_TIME et = END_TIME - START_TIME
print(f"\n{'=' * 80}") print(f"\n{'=' * 80}")

View File

@@ -111,6 +111,16 @@ else
echo " No personal features found at ${PERSONAL_FEATURES_DIR}/training (continuing without personal weighting)" echo " No personal features found at ${PERSONAL_FEATURES_DIR}/training (continuing without personal weighting)"
fi fi
# Reviewed false-positive features are optional hard negatives.
REVIEWED_NEGATIVE_FEATURES_DIR="${WORK_DIR}/reviewed_negative_features"
HAS_REVIEWED_NEGATIVE="false"
if [ -d "${REVIEWED_NEGATIVE_FEATURES_DIR}/training" ] ; then
HAS_REVIEWED_NEGATIVE="true"
echo "✅ Found reviewed negative features: ${REVIEWED_NEGATIVE_FEATURES_DIR}/training (will weight as hard negatives)"
else
echo " No reviewed negative features found at ${REVIEWED_NEGATIVE_FEATURES_DIR}/training (continuing with stock negatives)"
fi
cd "${WORK_DIR}" cd "${WORK_DIR}"
echo "===== Starting ${TRAINING_STEPS} training steps =====" echo "===== Starting ${TRAINING_STEPS} training steps ====="
@@ -133,6 +143,7 @@ features:
truth: true truth: true
type: mmap type: mmap
__PERSONAL_FEATURE_MARKER__ __PERSONAL_FEATURE_MARKER__
__REVIEWED_NEGATIVE_FEATURE_MARKER__
- features_dir: __NEG_SPEECH__ - features_dir: __NEG_SPEECH__
penalty_weight: 1.0 penalty_weight: 1.0
sampling_weight: 12.0 sampling_weight: 12.0
@@ -208,6 +219,22 @@ else
sed -i -e "/__PERSONAL_FEATURE_MARKER__/d" "${YAML_PATH}" sed -i -e "/__PERSONAL_FEATURE_MARKER__/d" "${YAML_PATH}"
fi fi
# Insert/remove reviewed hard-negative block
if [ "${HAS_REVIEWED_NEGATIVE}" = "true" ]; then
reviewed_negative_block="$(cat <<EOF
- features_dir: ${REVIEWED_NEGATIVE_FEATURES_DIR}
penalty_weight: 1.25
sampling_weight: 8.0
truncation_strategy: random
truth: false
type: mmap
EOF
)"
perl -0777 -i -pe "s#__REVIEWED_NEGATIVE_FEATURE_MARKER__#${reviewed_negative_block}#g" "${YAML_PATH}"
else
sed -i -e "/__REVIEWED_NEGATIVE_FEATURE_MARKER__/d" "${YAML_PATH}"
fi
echo " Wrote training_parameters.yaml" echo " Wrote training_parameters.yaml"
rm -rf "${WORK_DIR}/trained_models/wakeword" rm -rf "${WORK_DIR}/trained_models/wakeword"
@@ -317,6 +344,7 @@ fi
TRAINING_DONE="false" TRAINING_DONE="false"
echo "🏋️ Starting model training and TFLite export (this is the longest stage)…"
if run_attempt "Attempt 1/3: GPU training (default runtime profile)" ; then if run_attempt "Attempt 1/3: GPU training (default runtime profile)" ; then
echo "✅ Training complete (GPU path)." echo "✅ Training complete (GPU path)."
TRAINING_DONE="true" TRAINING_DONE="true"
@@ -386,12 +414,24 @@ if [ "${TRAINING_DONE}" != "true" ]; then
fi fi
source_path="${WORK_DIR}/trained_models/wakeword/tflite_stream_state_internal_quant/stream_state_internal_quant.tflite" source_path="${WORK_DIR}/trained_models/wakeword/tflite_stream_state_internal_quant/stream_state_internal_quant.tflite"
calibration_path="${WORK_DIR}/trained_models/wakeword/tflite_stream_state_internal_quant/detection_calibration.json"
if [ ! -f "${source_path}" ] ; then if [ ! -f "${source_path}" ] ; then
echo "Output model not found! Training didn't complete successfully. See ${TRAIN_LOG}" echo "Output model not found! Training didn't complete successfully. See ${TRAIN_LOG}"
exit 1 exit 1
fi fi
echo "🎯 Calibrating detector settings for on-device use…"
if "${PYTHON_BIN:-python}" "${PROGDIR}/calibrate_detector.py" \
--training-config "${WORK_DIR}/trained_models/wakeword/training_config.yaml" \
--model "${source_path}" \
--output "${calibration_path}"; then
echo "✅ Detector calibration complete."
else
echo "⚠️ Detector calibration failed; packaging with default detector settings."
rm -f "${calibration_path}" || :
fi
cp "${WORK_DIR}/trained_models/wakeword/model_summary.txt" "${OUTPUT_DIR}/logs/" || : cp "${WORK_DIR}/trained_models/wakeword/model_summary.txt" "${OUTPUT_DIR}/logs/" || :
cp -a "${WORK_DIR}/trained_models/wakeword/logs/train" "${OUTPUT_DIR}/logs/" || : cp -a "${WORK_DIR}/trained_models/wakeword/logs/train" "${OUTPUT_DIR}/logs/" || :
cp -a "${WORK_DIR}/trained_models/wakeword/logs/validation" "${OUTPUT_DIR}/logs/" || : cp -a "${WORK_DIR}/trained_models/wakeword/logs/validation" "${OUTPUT_DIR}/logs/" || :
@@ -404,24 +444,49 @@ tflite_path="${OUTPUT_DIR}/${tflite_filename}"
cp "${source_path}" "${tflite_path}" cp "${source_path}" "${tflite_path}"
json_path="${OUTPUT_DIR}/${wake_word_filename}.json" json_path="${OUTPUT_DIR}/${wake_word_filename}.json"
cat <<-EOF > "${json_path}" export WAKE_WORD_TITLE LANGUAGE JSON_PATH="${json_path}" TFLITE_FILENAME="${tflite_filename}" CALIBRATION_PATH="${calibration_path}"
{ echo "📦 Packaging final model artifacts…"
"${PYTHON_BIN:-python}" - <<'PY'
import json
import os
from pathlib import Path
json_path = Path(os.environ["JSON_PATH"])
calibration_path = Path(os.environ.get("CALIBRATION_PATH", ""))
language = (os.environ.get("LANGUAGE", "en") or "en").strip().lower()
probability_cutoff = 0.97
sliding_window_size = 5
if calibration_path.exists():
try:
calibration = json.loads(calibration_path.read_text(encoding="utf-8"))
probability_cutoff = float(calibration.get("probability_cutoff", probability_cutoff))
sliding_window_size = int(calibration.get("sliding_window_size", sliding_window_size))
print(
f"🎯 Using calibrated detector settings: "
f"cutoff={probability_cutoff:.2f}, window={sliding_window_size}"
)
except Exception as exc:
print(f"⚠️ Failed to read detector calibration ({exc}); using defaults.")
meta = {
"type": "micro", "type": "micro",
"wake_word": "${WAKE_WORD_TITLE}", "wake_word": os.environ["WAKE_WORD_TITLE"],
"author": "Tater Totterson", "author": "Tater Totterson",
"website": "https://github.com/TaterTotterson/microWakeWord-Trainer-Nvidia-Docker.git", "website": "https://github.com/TaterTotterson/microWakeWord-Trainer-Nvidia-Docker.git",
"model": "${tflite_filename}", "model": os.environ["TFLITE_FILENAME"],
"trained_languages": ["en"], "trained_languages": [language],
"version": 2, "version": 2,
"micro": { "micro": {
"probability_cutoff": 0.97, "probability_cutoff": round(probability_cutoff, 2),
"sliding_window_size": 5, "sliding_window_size": sliding_window_size,
"feature_step_size": 10, "feature_step_size": 10,
"tensor_arena_size": 30000, "tensor_arena_size": 30000,
"minimum_esphome_version": "2024.7.0" "minimum_esphome_version": "2024.7.0",
} },
} }
EOF json_path.write_text(json.dumps(meta, indent=4) + "\n", encoding="utf-8")
PY
echo "Name: ${WAKE_WORD_TITLE}" echo "Name: ${WAKE_WORD_TITLE}"
echo "Model: ${tflite_path}" echo "Model: ${tflite_path}"

View File

@@ -6,7 +6,7 @@ ENV DEBIAN_FRONTEND=noninteractive
# System deps # System deps
RUN apt-get update && apt-get install -y --no-install-recommends \ RUN apt-get update && apt-get install -y --no-install-recommends \
python3.12 python3.12-venv python3.12-dev python3-pip python-is-python3 \ python3.12 python3.12-venv python3.12-dev python3-pip python-is-python3 \
git wget curl unzip ca-certificates nano less \ git wget curl unzip patch ninja-build ca-certificates nano less \
&& rm -rf /var/lib/apt/lists/* \ && rm -rf /var/lib/apt/lists/* \
&& mkdir -p /data && mkdir -p /data
@@ -22,7 +22,7 @@ COPY --chown=root:root --chmod=0755 .bashrc /root/
# Root-level entrypoints # Root-level entrypoints
COPY --chown=root:root --chmod=0755 \ COPY --chown=root:root --chmod=0755 \
train_wake_word \ train_wake_word \
run_recorder.sh \ run.sh \
trainer_server.py \ trainer_server.py \
requirements.txt \ requirements.txt \
/root/mww-scripts/ /root/mww-scripts/
@@ -37,4 +37,4 @@ RUN chmod -R a+x /root/mww-scripts/cli
COPY --chown=root:root --chmod=0644 static/index.html /root/mww-scripts/static/index.html COPY --chown=root:root --chmod=0644 static/index.html /root/mww-scripts/static/index.html
# trainer server # trainer server
CMD ["/bin/bash", "-lc", "/root/mww-scripts/run_recorder.sh"] CMD ["/bin/bash", "-lc", "/root/mww-scripts/run.sh"]

BIN
images/tater-repo-logo.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 590 KiB

View File

@@ -6,9 +6,9 @@ ROOTDIR="$(dirname "$(realpath "$0")")"
# Training convention # Training convention
DATA_DIR="${DATA_DIR:-/data}" DATA_DIR="${DATA_DIR:-/data}"
HOST="${REC_HOST:-0.0.0.0}" HOST="${REC_HOST:-0.0.0.0}"
PORT="${REC_PORT:-8888}" PORT="${REC_PORT:-8789}"
# Keep recorder deps separate from training venv # Keep trainer UI deps separate from the training venv
VENV_DIR="${DATA_DIR}/.recorder-venv" VENV_DIR="${DATA_DIR}/.recorder-venv"
PY="${VENV_DIR}/bin/python" PY="${VENV_DIR}/bin/python"
PIP="${PY} -m pip" PIP="${PY} -m pip"
@@ -17,6 +17,7 @@ PIN_FILE="${VENV_DIR}/.pinned_installed"
FASTAPI_VERSION="${REC_FASTAPI_VERSION:-0.115.6}" FASTAPI_VERSION="${REC_FASTAPI_VERSION:-0.115.6}"
UVICORN_VERSION="${REC_UVICORN_VERSION:-0.30.6}" UVICORN_VERSION="${REC_UVICORN_VERSION:-0.30.6}"
PY_MULTIPART_VERSION="${REC_PY_MULTIPART_VERSION:-0.0.9}" PY_MULTIPART_VERSION="${REC_PY_MULTIPART_VERSION:-0.0.9}"
ESPHOME_VERSION="${REC_ESPHOME_VERSION:-2026.5.1}"
echo "microWakeWord Trainer UI (Docker)" echo "microWakeWord Trainer UI (Docker)"
echo "-> ROOTDIR: ${ROOTDIR}" echo "-> ROOTDIR: ${ROOTDIR}"
@@ -25,6 +26,16 @@ echo "-> URL: http://localhost:${PORT}/"
mkdir -p "${DATA_DIR}" mkdir -p "${DATA_DIR}"
install_ui_deps() {
${PIP} install \
"fastapi==${FASTAPI_VERSION}" \
"uvicorn[standard]==${UVICORN_VERSION}" \
"python-multipart==${PY_MULTIPART_VERSION}" \
"esphome==${ESPHOME_VERSION}" \
"silero-vad>=5.0.0" \
"numpy>=1.24.0"
}
# ----------------------------- # -----------------------------
# Trainer UI venv (separate) # Trainer UI venv (separate)
# ----------------------------- # -----------------------------
@@ -39,21 +50,63 @@ source "${VENV_DIR}/bin/activate"
if [[ ! -f "${PIN_FILE}" ]]; then if [[ ! -f "${PIN_FILE}" ]]; then
echo "Installing pinned trainer UI deps" echo "Installing pinned trainer UI deps"
${PIP} install -U pip setuptools wheel ${PIP} install -U pip setuptools wheel
${PIP} install \ install_ui_deps
"fastapi==${FASTAPI_VERSION}" \
"uvicorn[standard]==${UVICORN_VERSION}" \
"python-multipart==${PY_MULTIPART_VERSION}"
touch "${PIN_FILE}" touch "${PIN_FILE}"
else else
echo "Reusing existing trainer UI venv (no upgrades)" echo "Reusing existing trainer UI venv (no upgrades)"
fi if ! "${PY}" - "${FASTAPI_VERSION}" "${UVICORN_VERSION}" "${PY_MULTIPART_VERSION}" "${ESPHOME_VERSION}" <<'PY' >/dev/null 2>&1
import importlib.metadata as md
import sys
fastapi_version, uvicorn_version, multipart_version, esphome_version = sys.argv[1:5]
def version_tuple(value):
parts = []
for token in str(value).replace("-", ".").split("."):
if token.isdigit():
parts.append(int(token))
else:
digits = "".join(ch for ch in token if ch.isdigit())
if digits:
parts.append(int(digits))
break
return tuple(parts)
exact = {
"fastapi": fastapi_version,
"uvicorn": uvicorn_version,
"python-multipart": multipart_version,
"esphome": esphome_version,
}
minimum = {
"silero-vad": "5.0.0",
"numpy": "1.24.0",
}
present = ("torch", "zeroconf")
for package, expected in exact.items():
if md.version(package) != expected:
raise SystemExit(1)
for package, minimum_version in minimum.items():
if version_tuple(md.version(package)) < version_tuple(minimum_version):
raise SystemExit(1)
for package in present:
md.version(package)
PY
then
echo "UI dependencies missing or stale; installing recorder dependencies"
install_ui_deps
fi
fi
# ----------------------------- # -----------------------------
# Trainer server env # Trainer server env
# ----------------------------- # -----------------------------
export DATA_DIR="${DATA_DIR}" export DATA_DIR="${DATA_DIR}"
export STATIC_DIR="${ROOTDIR}/static" export STATIC_DIR="${ROOTDIR}/static"
export PERSONAL_DIR="${DATA_DIR}/personal_samples" export PERSONAL_DIR="${DATA_DIR}/personal_samples"
export CAPTURED_DIR="${DATA_DIR}/captured_audio"
export NEGATIVE_DIR="${DATA_DIR}/negative_samples"
export TRAINED_WAKE_WORDS_DIR="${DATA_DIR}/trained_wake_words"
# IMPORTANT: leave training venv creation to /api/train inside trainer_server.py # IMPORTANT: leave training venv creation to /api/train inside trainer_server.py
# but still set TRAIN_CMD so the server knows how to invoke training once ready # but still set TRAIN_CMD so the server knows how to invoke training once ready

File diff suppressed because it is too large Load Diff

View File

@@ -150,6 +150,7 @@ if ${CLEANUP_WORK_DIR} ; then
"${DATA_DIR}/work/wake_word_samples" \ "${DATA_DIR}/work/wake_word_samples" \
"${DATA_DIR}/work/wake_word_samples_augmented" \ "${DATA_DIR}/work/wake_word_samples_augmented" \
"${DATA_DIR}/work/personal_augmented_features" \ "${DATA_DIR}/work/personal_augmented_features" \
"${DATA_DIR}/work/reviewed_negative_features" \
"${DATA_DIR}/work/last_wake_word" || : "${DATA_DIR}/work/last_wake_word" || :
fi fi

File diff suppressed because it is too large Load Diff