mirror of
https://github.com/TaterTotterson/microWakeWord-Trainer-Nvidia-Docker.git
synced 2026-06-12 20:10:19 -06:00
Compare commits
7 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
51cbf6fd90 | ||
|
|
6e7396455a | ||
|
|
2da9f7a686 | ||
|
|
7b028e4420 | ||
|
|
b3d9f0e369 | ||
|
|
240ca7682e | ||
|
|
18e5fcd000 |
251
README.md
251
README.md
@@ -1,25 +1,20 @@
|
||||
<div align="center">
|
||||
<h1>🎙️ microWakeWord Nvidia Trainer & Recorder</h1>
|
||||
<img width="1002" height="593" alt="Screenshot 2026-01-18 at 8 13 35 AM" src="https://github.com/user-attachments/assets/e1411d8a-8638-4df8-992b-09a46c6e5ddc" />
|
||||
<h1>microWakeWord NVIDIA Docker Trainer UI</h1>
|
||||
<img width="800" alt="Screenshot 2026-04-14 at 11 02 06 PM" src="https://github.com/user-attachments/assets/694f4cb7-e4d8-4e2b-80ec-b40fb41cbfff" />
|
||||
</div>
|
||||
|
||||
Train **microWakeWord** detection models using a simple **web-based recorder + trainer UI**, packaged in a Docker container.
|
||||
Train custom microWakeWord models in Docker with:
|
||||
|
||||
No Jupyter notebooks required. No manual cell execution. Just record your voice (optional) and train.
|
||||
- uploaded personal voice samples
|
||||
- automatically generated Piper TTS samples
|
||||
- a browser-based trainer UI
|
||||
- live training logs in a popup console
|
||||
|
||||
This project no longer records audio in the browser. The UI is now upload-first: users add their own audio files, the app validates or converts them, and training runs from the same page.
|
||||
|
||||
---
|
||||
|
||||
<img width="100" height="44" alt="unraid_logo_black-339076895" src="https://github.com/user-attachments/assets/87351bed-3321-4a43-924f-fecf2e4e700f" />
|
||||
|
||||
**microWakeWord_Trainer-Nvidia** is available in the **Unraid Community Apps** store.
|
||||
Install directly from the Unraid App Store with a one-click template.
|
||||
|
||||
---
|
||||
|
||||
<img width="100" height="56" alt="unraid_logo_black-339076895" src="https://github.com/user-attachments/assets/bf959585-ae13-4b4d-ae62-4202a850d35a" />
|
||||
|
||||
|
||||
### Pull the Docker Image
|
||||
## Docker Image
|
||||
|
||||
```bash
|
||||
docker pull ghcr.io/tatertotterson/microwakeword:latest
|
||||
@@ -27,7 +22,7 @@ docker pull ghcr.io/tatertotterson/microwakeword:latest
|
||||
|
||||
---
|
||||
|
||||
### Run the Container
|
||||
## Run The Container
|
||||
|
||||
```bash
|
||||
docker run -d \
|
||||
@@ -37,133 +32,143 @@ docker run -d \
|
||||
ghcr.io/tatertotterson/microwakeword:latest
|
||||
```
|
||||
|
||||
**What these flags do:**
|
||||
- `--gpus all` → Enables GPU acceleration
|
||||
- `-p 8888:8888` → Exposes the Recorder + Trainer WebUI
|
||||
- `-v $(pwd):/data` → Persists all models, datasets, and cache
|
||||
What these flags do:
|
||||
|
||||
---
|
||||
- `--gpus all` enables GPU acceleration
|
||||
- `-p 8888:8888` exposes the trainer UI
|
||||
- `-v $(pwd):/data` persists models, downloaded voices, datasets, and personal samples
|
||||
|
||||
### Open the Recorder WebUI
|
||||
|
||||
Open your browser and go to:
|
||||
|
||||
👉 **http://localhost:8888**
|
||||
|
||||
You’ll see the **microWakeWord Recorder & Trainer UI**.
|
||||
|
||||
---
|
||||
|
||||
## 🎤 Recording Voice Samples (Optional)
|
||||
|
||||
Personal voice recordings are **optional**.
|
||||
|
||||
- You may **record your own voice** for better accuracy
|
||||
- Or simply **click “Train” without recording anything**
|
||||
|
||||
If no recordings are present, training will proceed using **synthetic TTS samples only**.
|
||||
|
||||
### Remote systems (important)
|
||||
If you are running this on a **remote PC / server**, browser-based recording will not work unless:
|
||||
- You use a **reverse proxy** (HTTPS + mic permissions), **or**
|
||||
- You access the UI via **localhost** on the same machine
|
||||
|
||||
Training itself works fine remotely — only recording requires local microphone access.
|
||||
|
||||
---
|
||||
|
||||
### 🎙️ Recording Flow
|
||||
|
||||
1. Enter your wake word
|
||||
2. Test pronunciation with **Test TTS**
|
||||
3. Choose:
|
||||
- Number of speakers (e.g. family members)
|
||||
- Takes per speaker (default: 10)
|
||||
4. Click **Begin recording**
|
||||
5. Speak naturally — recording:
|
||||
- Starts when you talk
|
||||
- Stops automatically after silence
|
||||
6. Repeat for each speaker
|
||||
|
||||
Files are saved automatically to:
|
||||
|
||||
```
|
||||
personal_samples/
|
||||
speaker01_take01.wav
|
||||
speaker01_take02.wav
|
||||
speaker02_take01.wav
|
||||
...
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 🧠 Training Behavior (Important Notes)
|
||||
|
||||
### ⏬ First training run
|
||||
The **first time you click Train**, the system will download **large training datasets** (background noise, speech corpora, etc.).
|
||||
|
||||
- This can take **several minutes**
|
||||
- This happens **only once**
|
||||
- Data is cached inside `/data`
|
||||
|
||||
You **will NOT need to download these again** unless you delete `/data`.
|
||||
|
||||
---
|
||||
|
||||
### 🔁 Re-training is safe and incremental
|
||||
|
||||
- You can train **multiple wake words** back-to-back
|
||||
- You do **NOT** need to clear any folders between runs
|
||||
- Old models are preserved in timestamped output directories
|
||||
- All required cleanup and reuse logic is handled automatically
|
||||
|
||||
---
|
||||
|
||||
## 📦 Output Files
|
||||
|
||||
When training completes, you’ll get:
|
||||
- `<wake_word>.tflite` – quantized streaming model
|
||||
- `<wake_word>.json` – ESPHome-compatible metadata
|
||||
|
||||
Both are saved under:
|
||||
Then open:
|
||||
|
||||
```text
|
||||
/data/output/
|
||||
http://localhost:8888
|
||||
```
|
||||
|
||||
Each run is placed in its own timestamped folder.
|
||||
---
|
||||
|
||||
## What The UI Does
|
||||
|
||||
- Start a wake word session
|
||||
- Test TTS pronunciation
|
||||
- Upload one or many personal samples
|
||||
- Normalize uploads to `16 kHz / mono / 16-bit PCM WAV`
|
||||
- Train with or without personal samples
|
||||
- Show a popup console with live progress and logs
|
||||
|
||||
Personal samples are optional. If none are uploaded, the trainer can still proceed with TTS-only data after confirmation.
|
||||
|
||||
---
|
||||
|
||||
## 🎤 Optional: Personal Voice Samples (Advanced)
|
||||
## Personal Samples
|
||||
|
||||
If you record personal samples:
|
||||
- They are automatically augmented
|
||||
- They are **up-weighted during training**
|
||||
- This significantly improves real-world accuracy
|
||||
Accepted upload formats include:
|
||||
|
||||
No configuration required — detection is automatic.
|
||||
- WAV
|
||||
- MP3
|
||||
- M4A
|
||||
- FLAC
|
||||
- OGG
|
||||
- AAC
|
||||
- OPUS
|
||||
- WEBM
|
||||
|
||||
The backend validates or converts uploads with `ffmpeg` and stores the normalized files in:
|
||||
|
||||
```text
|
||||
/data/personal_samples/
|
||||
```
|
||||
|
||||
Notes:
|
||||
|
||||
- starting a new session does not clear personal samples
|
||||
- use the `Clear personal samples` button if you want to wipe them
|
||||
- any uploaded personal samples are automatically included in training
|
||||
|
||||
---
|
||||
|
||||
## 🔄 Resetting Everything (Optional)
|
||||
## Language Support
|
||||
|
||||
If you want a **completely clean slate**:
|
||||
The language selector is dynamic.
|
||||
|
||||
Delete the /data folder
|
||||
- `en` is always available
|
||||
- non-English languages are populated from Piper voice metadata
|
||||
- when you train with a non-English language, the backend downloads all Piper ONNX voices for that selected language only
|
||||
- it does not pre-download every language
|
||||
- already-downloaded voices are reused on later runs
|
||||
|
||||
Then restart the container.
|
||||
English stays on its existing dedicated generator model path. Non-English languages use the selected language's ONNX Piper voices.
|
||||
|
||||
⚠️ This will:
|
||||
- Remove cached datasets
|
||||
- Require re-downloading training data
|
||||
- Delete trained models
|
||||
If the Piper catalog is unavailable, already-installed local voices can still be used.
|
||||
|
||||
---
|
||||
|
||||
## 🙌 Credits
|
||||
## Training Behavior
|
||||
|
||||
Built on top of the excellent
|
||||
**https://github.com/kahrendt/microWakeWord**
|
||||
1. Enter the wake word
|
||||
2. Optionally test pronunciation
|
||||
3. Optionally upload personal samples
|
||||
4. Click `Start training`
|
||||
5. Watch the popup console for:
|
||||
- selected-language voice downloads when needed
|
||||
- sample generation progress
|
||||
- dataset setup
|
||||
- training progress and completion
|
||||
|
||||
Huge thanks to the original authors ❤️
|
||||
The `Open console` button lets you reopen the log window after closing it.
|
||||
|
||||
---
|
||||
|
||||
## First Run Notes
|
||||
|
||||
The first real training run may download large training assets into `/data`, such as:
|
||||
|
||||
- Piper voices for the selected language
|
||||
- training datasets and background data
|
||||
- Python training environment dependencies
|
||||
|
||||
These are reused later unless you delete `/data`.
|
||||
|
||||
---
|
||||
|
||||
## Output Files
|
||||
|
||||
Successful runs produce:
|
||||
|
||||
```text
|
||||
/data/output/<wake_word>.tflite
|
||||
/data/output/<wake_word>.json
|
||||
```
|
||||
|
||||
If those files already exist, the trainer creates timestamped backups before replacing them.
|
||||
|
||||
---
|
||||
|
||||
## Resetting Everything
|
||||
|
||||
If you want a clean slate, stop the container and remove the contents of your mounted `/data` directory.
|
||||
|
||||
That will remove:
|
||||
|
||||
- personal samples
|
||||
- downloaded Piper voices
|
||||
- cached datasets
|
||||
- training environments
|
||||
- trained models
|
||||
|
||||
---
|
||||
|
||||
## Notes
|
||||
|
||||
- browser microphone recording has been removed
|
||||
- personal samples are optional
|
||||
- the server module is now `trainer_server.py`
|
||||
- the launcher script is now `run.sh`
|
||||
|
||||
---
|
||||
|
||||
## Credits
|
||||
|
||||
Built on top of:
|
||||
|
||||
- [microWakeWord](https://github.com/kahrendt/microWakeWord)
|
||||
- [piper-sample-generator](https://github.com/rhasspy/piper-sample-generator)
|
||||
|
||||
405
cli/calibrate_detector.py
Normal file
405
cli/calibrate_detector.py
Normal file
@@ -0,0 +1,405 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Choose detector metadata that better matches the trained model."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Sequence
|
||||
|
||||
import numpy as np
|
||||
import yaml
|
||||
|
||||
from microwakeword.data import FeatureHandler
|
||||
from microwakeword.inference import Model
|
||||
|
||||
|
||||
DEFAULT_WINDOW_SIZES = [3, 4, 5, 6, 7]
|
||||
DEFAULT_TARGET_FAPH = float(os.environ.get("MWW_CALIBRATION_TARGET_FAPH", "1.0"))
|
||||
DEFAULT_COOLDOWN_SLICES = int(os.environ.get("MWW_CALIBRATION_COOLDOWN_SLICES", "25"))
|
||||
DEFAULT_POSITIVE_SKIP_SLICES = int(
|
||||
os.environ.get("MWW_CALIBRATION_POSITIVE_SKIP_SLICES", "25")
|
||||
)
|
||||
DEFAULT_CUTOFF_STEP = float(os.environ.get("MWW_CALIBRATION_CUTOFF_STEP", "0.01"))
|
||||
DEFAULT_CUTOFF_MIN = float(os.environ.get("MWW_CALIBRATION_CUTOFF_MIN", "0.00"))
|
||||
DEFAULT_CUTOFF_MAX = float(os.environ.get("MWW_CALIBRATION_CUTOFF_MAX", "1.00"))
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Calibrate microWakeWord detector metadata from validation data."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--training-config",
|
||||
default="trained_models/wakeword/training_config.yaml",
|
||||
help="Path to the saved microWakeWord training_config.yaml file.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
default=(
|
||||
"trained_models/wakeword/tflite_stream_state_internal_quant/"
|
||||
"stream_state_internal_quant.tflite"
|
||||
),
|
||||
help="Path to the quantized streaming TFLite model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
default=(
|
||||
"trained_models/wakeword/tflite_stream_state_internal_quant/"
|
||||
"detection_calibration.json"
|
||||
),
|
||||
help="Where to write the selected detector settings as JSON.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--window-sizes",
|
||||
default=",".join(str(value) for value in DEFAULT_WINDOW_SIZES),
|
||||
help="Comma-separated sliding window sizes to evaluate.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--target-faph",
|
||||
type=float,
|
||||
default=DEFAULT_TARGET_FAPH,
|
||||
help="Target ambient false accepts per hour for the selected operating point.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cooldown-slices",
|
||||
type=int,
|
||||
default=DEFAULT_COOLDOWN_SLICES,
|
||||
help="Cooldown slices to use when estimating false accepts per hour.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--positive-skip-slices",
|
||||
type=int,
|
||||
default=DEFAULT_POSITIVE_SKIP_SLICES,
|
||||
help="Initial streaming slices to ignore when scoring positive examples.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cutoff-step",
|
||||
type=float,
|
||||
default=DEFAULT_CUTOFF_STEP,
|
||||
help="Cutoff increment to evaluate between cutoff-min and cutoff-max.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cutoff-min",
|
||||
type=float,
|
||||
default=DEFAULT_CUTOFF_MIN,
|
||||
help="Minimum cutoff to evaluate.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cutoff-max",
|
||||
type=float,
|
||||
default=DEFAULT_CUTOFF_MAX,
|
||||
help="Maximum cutoff to evaluate.",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def _parse_window_sizes(raw: str) -> list[int]:
|
||||
values = []
|
||||
for item in (raw or "").split(","):
|
||||
item = item.strip()
|
||||
if not item:
|
||||
continue
|
||||
value = int(item)
|
||||
if value < 1:
|
||||
raise ValueError("window sizes must be >= 1")
|
||||
values.append(value)
|
||||
if not values:
|
||||
raise ValueError("at least one window size is required")
|
||||
return sorted(set(values))
|
||||
|
||||
|
||||
def _moving_average(values: Sequence[float], window_size: int) -> np.ndarray:
|
||||
array = np.asarray(values, dtype=np.float32)
|
||||
if array.size == 0:
|
||||
return array
|
||||
if window_size <= 1:
|
||||
return array
|
||||
if array.size < window_size:
|
||||
return np.asarray([float(array.mean())], dtype=np.float32)
|
||||
cumsum = np.cumsum(np.insert(array, 0, 0.0))
|
||||
averaged = (cumsum[window_size:] - cumsum[:-window_size]) / float(window_size)
|
||||
return averaged.astype(np.float32)
|
||||
|
||||
|
||||
def _compute_false_accepts_per_hour(
|
||||
probabilities_per_track: Iterable[np.ndarray],
|
||||
cutoffs: np.ndarray,
|
||||
cooldown_slices: int,
|
||||
stride: int,
|
||||
step_seconds: float,
|
||||
) -> tuple[np.ndarray, float]:
|
||||
cutoffs = np.asarray(cutoffs, dtype=np.float32)
|
||||
false_accepts = np.zeros(cutoffs.shape[0], dtype=np.float64)
|
||||
duration_hours = 0.0
|
||||
|
||||
for track_probabilities in probabilities_per_track:
|
||||
if track_probabilities.size == 0:
|
||||
continue
|
||||
duration_hours += (
|
||||
len(track_probabilities) * stride * step_seconds / 3600.0
|
||||
)
|
||||
cooldown = np.full(cutoffs.shape[0], cooldown_slices, dtype=np.int32)
|
||||
for probability in track_probabilities:
|
||||
cooldown = np.maximum(cooldown - 1, 0)
|
||||
accepted = (cooldown == 0) & (probability > cutoffs)
|
||||
false_accepts += accepted.astype(np.float64)
|
||||
cooldown = np.where(accepted, cooldown_slices, cooldown)
|
||||
|
||||
if duration_hours <= 0:
|
||||
return np.full(cutoffs.shape[0], math.inf, dtype=np.float64), 0.0
|
||||
|
||||
return false_accepts / duration_hours, duration_hours
|
||||
|
||||
|
||||
def _select_best_candidate(
|
||||
candidates: list[dict[str, float]],
|
||||
target_faph: float,
|
||||
) -> tuple[dict[str, float], float]:
|
||||
fallback_limits = [
|
||||
target_faph,
|
||||
max(target_faph * 2.0, target_faph + 0.5),
|
||||
max(target_faph * 4.0, 2.0),
|
||||
]
|
||||
|
||||
def tier(candidate: dict[str, float]) -> int:
|
||||
for index, limit in enumerate(fallback_limits):
|
||||
if candidate["false_accepts_per_hour"] <= limit + 1e-9:
|
||||
return index
|
||||
return len(fallback_limits)
|
||||
|
||||
best = min(
|
||||
candidates,
|
||||
key=lambda candidate: (
|
||||
tier(candidate),
|
||||
-candidate["recall"],
|
||||
candidate["false_accepts_per_hour"],
|
||||
abs(candidate["sliding_window_size"] - 5),
|
||||
-candidate["probability_cutoff"],
|
||||
),
|
||||
)
|
||||
|
||||
tier_index = tier(best)
|
||||
if tier_index < len(fallback_limits):
|
||||
return best, fallback_limits[tier_index]
|
||||
return best, float("inf")
|
||||
|
||||
|
||||
def _load_config(config_path: Path) -> dict:
|
||||
with config_path.open("r", encoding="utf-8") as handle:
|
||||
return yaml.load(handle.read(), Loader=yaml.Loader)
|
||||
|
||||
|
||||
def _load_eval_sets(
|
||||
handler: FeatureHandler,
|
||||
config: dict,
|
||||
) -> tuple[str, str, list[np.ndarray], list[np.ndarray]]:
|
||||
for positive_mode, ambient_mode in (
|
||||
("validation", "validation_ambient"),
|
||||
("testing", "testing_ambient"),
|
||||
):
|
||||
positive_tracks, labels, _ = handler.get_data(
|
||||
positive_mode,
|
||||
batch_size=config["batch_size"],
|
||||
features_length=config["spectrogram_length"],
|
||||
truncation_strategy="none",
|
||||
)
|
||||
ambient_tracks, _, _ = handler.get_data(
|
||||
ambient_mode,
|
||||
batch_size=config["batch_size"],
|
||||
features_length=config["spectrogram_length"],
|
||||
truncation_strategy="none",
|
||||
)
|
||||
positives = [
|
||||
np.asarray(track)
|
||||
for track, label in zip(positive_tracks, labels)
|
||||
if bool(label)
|
||||
]
|
||||
ambient = [np.asarray(track) for track in ambient_tracks]
|
||||
if positives and ambient:
|
||||
return positive_mode, ambient_mode, positives, ambient
|
||||
raise RuntimeError(
|
||||
"No suitable validation/testing data was found for detector calibration."
|
||||
)
|
||||
|
||||
|
||||
def _predict_tracks(
|
||||
model: Model,
|
||||
tracks: Sequence[np.ndarray],
|
||||
label: str,
|
||||
) -> list[np.ndarray]:
|
||||
predictions: list[np.ndarray] = []
|
||||
total = len(tracks)
|
||||
print(f"→ Running streaming inference on {total} {label} track(s)")
|
||||
for index, track in enumerate(tracks, start=1):
|
||||
values = np.asarray(model.predict_spectrogram(track), dtype=np.float32)
|
||||
predictions.append(values)
|
||||
if index == total or index % 25 == 0:
|
||||
print(f" {label}: {index}/{total}")
|
||||
return predictions
|
||||
|
||||
|
||||
def main() -> int:
|
||||
args = parse_args()
|
||||
window_sizes = _parse_window_sizes(args.window_sizes)
|
||||
if args.cutoff_step <= 0:
|
||||
raise ValueError("cutoff-step must be > 0")
|
||||
if args.cutoff_max < args.cutoff_min:
|
||||
raise ValueError("cutoff-max must be >= cutoff-min")
|
||||
|
||||
config_path = Path(args.training_config)
|
||||
model_path = Path(args.model)
|
||||
output_path = Path(args.output)
|
||||
|
||||
if not config_path.exists():
|
||||
raise FileNotFoundError(f"Training config not found: {config_path}")
|
||||
if not model_path.exists():
|
||||
raise FileNotFoundError(f"Streaming TFLite model not found: {model_path}")
|
||||
|
||||
cutoffs = np.arange(
|
||||
args.cutoff_min,
|
||||
args.cutoff_max + (args.cutoff_step / 2.0),
|
||||
args.cutoff_step,
|
||||
dtype=np.float32,
|
||||
)
|
||||
cutoffs = np.clip(cutoffs, 0.0, 1.0)
|
||||
cutoffs = np.unique(np.round(cutoffs, 4))
|
||||
|
||||
print("===== Detector Calibration =====")
|
||||
print(f"→ Model: {model_path}")
|
||||
print(f"→ Training config: {config_path}")
|
||||
print(
|
||||
f"→ Evaluating window sizes {window_sizes} with target <= "
|
||||
f"{args.target_faph:.2f} false accepts/hour"
|
||||
)
|
||||
|
||||
config = _load_config(config_path)
|
||||
config["flags"] = config.get("flags", {})
|
||||
handler = FeatureHandler(config)
|
||||
|
||||
positive_mode, ambient_mode, positive_tracks, ambient_tracks = _load_eval_sets(
|
||||
handler, config
|
||||
)
|
||||
|
||||
print(
|
||||
f"→ Using {positive_mode} positives ({len(positive_tracks)}) and "
|
||||
f"{ambient_mode} ambient tracks ({len(ambient_tracks)})"
|
||||
)
|
||||
|
||||
model = Model(str(model_path), stride=config["stride"])
|
||||
positive_predictions = _predict_tracks(model, positive_tracks, "positive")
|
||||
ambient_predictions = _predict_tracks(model, ambient_tracks, "ambient")
|
||||
|
||||
candidates: list[dict[str, float]] = []
|
||||
best_by_window: list[dict[str, float]] = []
|
||||
step_seconds = config["window_step_ms"] / 1000.0
|
||||
|
||||
for window_size in window_sizes:
|
||||
ambient_averages = [
|
||||
_moving_average(track, window_size) for track in ambient_predictions
|
||||
]
|
||||
positive_maxima = []
|
||||
for track in positive_predictions:
|
||||
search = (
|
||||
track[args.positive_skip_slices :]
|
||||
if track.size > args.positive_skip_slices
|
||||
else track
|
||||
)
|
||||
averaged = _moving_average(search, window_size)
|
||||
if averaged.size == 0:
|
||||
averaged = _moving_average(track, window_size)
|
||||
positive_maxima.append(float(np.max(averaged)) if averaged.size else 0.0)
|
||||
|
||||
positive_maxima_array = np.asarray(positive_maxima, dtype=np.float32)
|
||||
recall_by_cutoff = np.mean(
|
||||
positive_maxima_array[None, :] > cutoffs[:, None], axis=1
|
||||
)
|
||||
faph_by_cutoff, ambient_hours = _compute_false_accepts_per_hour(
|
||||
ambient_averages,
|
||||
cutoffs,
|
||||
args.cooldown_slices,
|
||||
stride=config["stride"],
|
||||
step_seconds=step_seconds,
|
||||
)
|
||||
|
||||
window_candidates = []
|
||||
for cutoff, recall, faph in zip(cutoffs, recall_by_cutoff, faph_by_cutoff):
|
||||
candidate = {
|
||||
"probability_cutoff": float(round(float(cutoff), 2)),
|
||||
"sliding_window_size": int(window_size),
|
||||
"recall": float(recall),
|
||||
"false_accepts_per_hour": float(faph),
|
||||
"ambient_hours": float(ambient_hours),
|
||||
}
|
||||
candidates.append(candidate)
|
||||
window_candidates.append(candidate)
|
||||
|
||||
best_window, _ = _select_best_candidate(window_candidates, args.target_faph)
|
||||
best_by_window.append(best_window)
|
||||
print(
|
||||
" window={window}: cutoff={cutoff:.2f}; recall={recall:.2%}; "
|
||||
"ambient_faph={faph:.3f}".format(
|
||||
window=window_size,
|
||||
cutoff=best_window["probability_cutoff"],
|
||||
recall=best_window["recall"],
|
||||
faph=best_window["false_accepts_per_hour"],
|
||||
)
|
||||
)
|
||||
|
||||
best, selected_limit = _select_best_candidate(candidates, args.target_faph)
|
||||
if best["false_accepts_per_hour"] > args.target_faph + 1e-9:
|
||||
print(
|
||||
"⚠️ No candidate met the target false accepts/hour budget; "
|
||||
"using the best fallback operating point."
|
||||
)
|
||||
|
||||
print(
|
||||
"✓ Selected cutoff={cutoff:.2f}, window={window}, recall={recall:.2%}, "
|
||||
"ambient_faph={faph:.3f}".format(
|
||||
cutoff=best["probability_cutoff"],
|
||||
window=best["sliding_window_size"],
|
||||
recall=best["recall"],
|
||||
faph=best["false_accepts_per_hour"],
|
||||
)
|
||||
)
|
||||
|
||||
output = {
|
||||
"probability_cutoff": best["probability_cutoff"],
|
||||
"sliding_window_size": best["sliding_window_size"],
|
||||
"target_false_accepts_per_hour": float(args.target_faph),
|
||||
"selected_false_accepts_per_hour_limit": (
|
||||
None if math.isinf(selected_limit) else float(selected_limit)
|
||||
),
|
||||
"selected_metrics": {
|
||||
"recall": round(best["recall"], 6),
|
||||
"false_accepts_per_hour": round(best["false_accepts_per_hour"], 6),
|
||||
"ambient_hours": round(best["ambient_hours"], 6),
|
||||
},
|
||||
"evaluation": {
|
||||
"positive_dataset": positive_mode,
|
||||
"ambient_dataset": ambient_mode,
|
||||
"positive_tracks": len(positive_tracks),
|
||||
"ambient_tracks": len(ambient_tracks),
|
||||
"cooldown_slices": int(args.cooldown_slices),
|
||||
"positive_skip_slices": int(args.positive_skip_slices),
|
||||
"window_sizes": window_sizes,
|
||||
"cutoff_min": round(float(cutoffs[0]), 4),
|
||||
"cutoff_max": round(float(cutoffs[-1]), 4),
|
||||
"cutoff_step": float(args.cutoff_step),
|
||||
},
|
||||
"per_window_best": best_by_window,
|
||||
"generated_at": datetime.now(timezone.utc).isoformat(),
|
||||
}
|
||||
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
output_path.write_text(json.dumps(output, indent=2) + "\n", encoding="utf-8")
|
||||
print(f"📝 Wrote calibration to {output_path}")
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
@@ -130,6 +130,73 @@ print(f" AudioSet complete ({ok} ok, {skipped} skipped, {len(audioset_bad)} fa
|
||||
EOF
|
||||
}
|
||||
|
||||
converter_from_dataset_api() {
|
||||
# shellcheck source=/dev/null
|
||||
source "${DATA_DIR}/.venv/bin/activate"
|
||||
|
||||
python - "${AUDIO16K_DIR}" <<-'EOF'
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import librosa
|
||||
import numpy as np
|
||||
import scipy.io.wavfile
|
||||
from datasets import load_dataset
|
||||
|
||||
def write_wav(dst: Path, data: np.ndarray, sr: int):
|
||||
dst.parent.mkdir(parents=True, exist_ok=True)
|
||||
x = np.clip(data, -1.0, 1.0)
|
||||
scipy.io.wavfile.write(dst, sr, (x * 32767).astype(np.int16))
|
||||
|
||||
audioset_out = Path(sys.argv[1])
|
||||
|
||||
print(" AudioSet FLAC tarballs are unavailable; using Hugging Face datasets API instead.")
|
||||
dataset = load_dataset(
|
||||
"agkphysics/AudioSet",
|
||||
"balanced",
|
||||
split="train",
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
audioset_bad = []
|
||||
ok = 0
|
||||
skipped = 0
|
||||
heartbeat_every = 250
|
||||
|
||||
for idx, sample in enumerate(dataset, start=1):
|
||||
try:
|
||||
video_id = str(sample.get("video_id") or f"audioset_{idx:06d}")
|
||||
outfile = audioset_out / f"{video_id}.wav"
|
||||
if outfile.exists():
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
audio = sample.get("audio") or {}
|
||||
y = np.asarray(audio.get("array"))
|
||||
sr = int(audio.get("sampling_rate") or 0)
|
||||
if y.size == 0 or sr <= 0:
|
||||
raise ValueError("missing decoded audio")
|
||||
if y.ndim > 1:
|
||||
y = np.mean(y, axis=-1)
|
||||
if sr != 16000:
|
||||
y = librosa.resample(y.astype(np.float32), orig_sr=sr, target_sr=16000)
|
||||
if y.size == 0:
|
||||
raise ValueError("empty audio")
|
||||
write_wav(outfile, y, 16000)
|
||||
ok += 1
|
||||
except Exception as exc:
|
||||
audioset_bad.append(f"{sample.get('video_id', idx)}:{exc}")
|
||||
|
||||
if idx == 1 or (idx % heartbeat_every) == 0:
|
||||
print(f" AudioSet API progress: {idx} clips processed (ok={ok}, skipped={skipped}, failed={len(audioset_bad)})")
|
||||
|
||||
if audioset_bad:
|
||||
(audioset_out / "audioset_corrupted_files.log").write_text("\n".join(audioset_bad))
|
||||
|
||||
print(f" AudioSet complete via datasets API ({ok} ok, {skipped} skipped, {len(audioset_bad)} failed)")
|
||||
EOF
|
||||
}
|
||||
|
||||
expected_filecount=$(get_total_filecount filecounts)
|
||||
actual_filecount=$(find "${AUDIO16K_DIR}" -name "*.wav" 2>/dev/null | wc -l) || :
|
||||
write_filecount=false
|
||||
@@ -139,7 +206,10 @@ if [ "${actual_filecount}" -ne 0 ] ; then
|
||||
echo " Existing ${AUDIO16K_DIR} present (${actual_filecount} wav); skipping extract/convert"
|
||||
else
|
||||
dl=$(find_rev)
|
||||
[ -n "$dl" ] || { echo " Could not locate an AudioSet revision with FLAC tarballs still present on HF." ; exit 1 ; }
|
||||
if [ -z "$dl" ] ; then
|
||||
rm -rf "${AUDIO16K_DIR}/audioset_corrupted_files.log" || :
|
||||
converter_from_dataset_api
|
||||
else
|
||||
rev=${dl%%,*}
|
||||
pattern=${dl##*,}
|
||||
|
||||
@@ -175,6 +245,7 @@ else
|
||||
echo " WARNING: mismatch is expected if some AudioSet files are corrupted; continuing." >&2
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
|
||||
if ${write_filecount} ; then
|
||||
write_filecounts filecounts "${AUDIO_FILECOUNT}"
|
||||
|
||||
@@ -27,8 +27,11 @@ cd "${DATA_DIR}/training_datasets"
|
||||
|
||||
echo "***** Checking FMA *****"
|
||||
|
||||
AUDIO_URL="https://huggingface.co/datasets/mchl914/fma_xsmall/resolve/main/fma_xs.zip"
|
||||
AUDIO_ZIPFILE="fma_xs.zip"
|
||||
AUDIO_URLS=(
|
||||
"https://os.unil.cloud.switch.ch/fma/fma_small.zip"
|
||||
"https://huggingface.co/datasets/mchl914/fma_xsmall/resolve/main/fma_xs.zip"
|
||||
)
|
||||
AUDIO_ZIPFILE="fma_small.zip"
|
||||
AUDIO_ZIP="./downloads/${AUDIO_ZIPFILE}"
|
||||
AUDIO_DIR="fma"
|
||||
mkdir -p "${AUDIO_DIR}" || :
|
||||
@@ -81,6 +84,52 @@ EOF
|
||||
|
||||
}
|
||||
|
||||
extract_zip_with_python() {
|
||||
local zip_path="$1"
|
||||
local dest_dir="$2"
|
||||
|
||||
"${DATA_DIR}/.venv/bin/python" - "${zip_path}" "${dest_dir}" <<-'EOF'
|
||||
import sys
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
|
||||
zip_path = Path(sys.argv[1])
|
||||
dest_dir = Path(sys.argv[2])
|
||||
|
||||
if (not zip_path.exists()) or zip_path.stat().st_size == 0:
|
||||
raise SystemExit(f"Archive missing or empty: {zip_path}")
|
||||
|
||||
with zipfile.ZipFile(zip_path, "r") as zf:
|
||||
members = zf.infolist()
|
||||
size_gb = zip_path.stat().st_size / (1024 ** 3)
|
||||
print(f" Extracting {zip_path.name} ({len(members)} entries, {size_gb:.1f} GiB)...")
|
||||
for member in tqdm(members, desc=" FMA zip extract", unit="file"):
|
||||
zf.extract(member, dest_dir)
|
||||
EOF
|
||||
}
|
||||
|
||||
download_with_fallbacks() {
|
||||
local output="$1"
|
||||
shift
|
||||
local urls=( "$@" )
|
||||
local rc=1
|
||||
|
||||
for url in "${urls[@]}" ; do
|
||||
for attempt in 1 2 3 4 ; do
|
||||
curl -sfL "${url}" -o "${output}" && [ -s "${output}" ] && return 0
|
||||
rc=$?
|
||||
rm -f "${output}" || :
|
||||
if [ "${attempt}" -lt 4 ] ; then
|
||||
echo " Retry ${attempt}/3 after download failure"
|
||||
sleep $(( attempt * 2 ))
|
||||
fi
|
||||
done
|
||||
done
|
||||
|
||||
return "${rc}"
|
||||
}
|
||||
|
||||
expected_filecount=${filecounts[${AUDIO_ZIPFILE}]}
|
||||
actual_filecount=$(find ${AUDIO16K_DIR} -name '*.wav' 2>/dev/null | wc -l) || :
|
||||
write_filecount=false
|
||||
@@ -92,13 +141,16 @@ else
|
||||
if [ "${actual_filecount}" -eq 0 ] || [ "${actual_filecount}" -ne "${expected_filecount}" ] ; then
|
||||
if [ ! -f "${AUDIO_ZIP}" ] ; then
|
||||
echo " Downloading ${AUDIO_ZIPFILE}"
|
||||
curl -sfL "${AUDIO_URL}" -o "${AUDIO_ZIP}"
|
||||
download_with_fallbacks "${AUDIO_ZIP}" "${AUDIO_URLS[@]}" || {
|
||||
echo " Failed to download ${AUDIO_ZIPFILE} from all configured sources." >&2
|
||||
exit 1
|
||||
}
|
||||
fi
|
||||
|
||||
rm -rf "${AUDIO_DIR}" || :
|
||||
mkdir "${AUDIO_DIR}"
|
||||
echo " Unzipping ${AUDIO_ZIPFILE}"
|
||||
unzip -q -d "${AUDIO_DIR}" "${AUDIO_ZIP}"
|
||||
echo " Extracting ${AUDIO_ZIPFILE}"
|
||||
extract_zip_with_python "${AUDIO_ZIP}" "${AUDIO_DIR}"
|
||||
fi
|
||||
if "${CLEANUP_ARCHIVES}" && [ -f "${AUDIO_ZIP}" ] ; then
|
||||
echo " Cleaning up ${AUDIO_ZIPFILE}"
|
||||
@@ -128,4 +180,3 @@ fi
|
||||
|
||||
echo " FMA complete"
|
||||
exit 0
|
||||
|
||||
|
||||
@@ -242,29 +242,7 @@ if [ ! -s "${MODEL_FILE}.json" ] ; then
|
||||
curl -sfL "${MODEL_URL}.json" -o "${MODEL_FILE}.json"
|
||||
fi
|
||||
|
||||
# --- Dutch ONNX voices (single-speaker, used with --language=nl) ---
|
||||
# Working Dutch voices: pim, ronnie (nl_NL) and nathalie (nl_BE).
|
||||
# nl_NL-mls-medium is intentionally excluded (known Piper issue: outputs gibberish).
|
||||
HF_VOICES="https://huggingface.co/rhasspy/piper-voices/resolve/main"
|
||||
declare -a NL_VOICES=(
|
||||
"nl/nl_NL/pim/medium/nl_NL-pim-medium"
|
||||
"nl/nl_NL/ronnie/medium/nl_NL-ronnie-medium"
|
||||
"nl/nl_BE/nathalie/medium/nl_BE-nathalie-medium"
|
||||
)
|
||||
echo " ===== Checking Dutch Piper voices ====="
|
||||
for voice_path in "${NL_VOICES[@]}" ; do
|
||||
voice_name="$(basename "${voice_path}")"
|
||||
onnx_file="${VOICES_DIR}/${voice_name}.onnx"
|
||||
json_file="${VOICES_DIR}/${voice_name}.onnx.json"
|
||||
if [ ! -f "${onnx_file}" ] ; then
|
||||
echo " Downloading ${voice_name}.onnx"
|
||||
curl -sfL "${HF_VOICES}/${voice_path}.onnx?download=true" -o "${onnx_file}"
|
||||
fi
|
||||
if [ ! -f "${json_file}" ] ; then
|
||||
echo " Downloading ${voice_name}.onnx.json"
|
||||
curl -sfL "${HF_VOICES}/${voice_path}.onnx.json?download=true" -o "${json_file}"
|
||||
fi
|
||||
done
|
||||
echo " Non-English Piper voices will be downloaded on demand for the selected language."
|
||||
|
||||
${GPU} && onnxgpu='-gpu[cuda]' || onnxgpu=""
|
||||
echo " ===== Installing onnxruntime${onnxgpu} ====="
|
||||
|
||||
@@ -317,6 +317,7 @@ fi
|
||||
|
||||
TRAINING_DONE="false"
|
||||
|
||||
echo "🏋️ Starting model training and TFLite export (this is the longest stage)…"
|
||||
if run_attempt "Attempt 1/3: GPU training (default runtime profile)" ; then
|
||||
echo "✅ Training complete (GPU path)."
|
||||
TRAINING_DONE="true"
|
||||
@@ -386,12 +387,24 @@ if [ "${TRAINING_DONE}" != "true" ]; then
|
||||
fi
|
||||
|
||||
source_path="${WORK_DIR}/trained_models/wakeword/tflite_stream_state_internal_quant/stream_state_internal_quant.tflite"
|
||||
calibration_path="${WORK_DIR}/trained_models/wakeword/tflite_stream_state_internal_quant/detection_calibration.json"
|
||||
|
||||
if [ ! -f "${source_path}" ] ; then
|
||||
echo "Output model not found! Training didn't complete successfully. See ${TRAIN_LOG}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "🎯 Calibrating detector settings for on-device use…"
|
||||
if "${PYTHON_BIN:-python}" "${PROGDIR}/calibrate_detector.py" \
|
||||
--training-config "${WORK_DIR}/trained_models/wakeword/training_config.yaml" \
|
||||
--model "${source_path}" \
|
||||
--output "${calibration_path}"; then
|
||||
echo "✅ Detector calibration complete."
|
||||
else
|
||||
echo "⚠️ Detector calibration failed; packaging with default detector settings."
|
||||
rm -f "${calibration_path}" || :
|
||||
fi
|
||||
|
||||
cp "${WORK_DIR}/trained_models/wakeword/model_summary.txt" "${OUTPUT_DIR}/logs/" || :
|
||||
cp -a "${WORK_DIR}/trained_models/wakeword/logs/train" "${OUTPUT_DIR}/logs/" || :
|
||||
cp -a "${WORK_DIR}/trained_models/wakeword/logs/validation" "${OUTPUT_DIR}/logs/" || :
|
||||
@@ -404,24 +417,49 @@ tflite_path="${OUTPUT_DIR}/${tflite_filename}"
|
||||
cp "${source_path}" "${tflite_path}"
|
||||
|
||||
json_path="${OUTPUT_DIR}/${wake_word_filename}.json"
|
||||
cat <<-EOF > "${json_path}"
|
||||
{
|
||||
export WAKE_WORD_TITLE LANGUAGE JSON_PATH="${json_path}" TFLITE_FILENAME="${tflite_filename}" CALIBRATION_PATH="${calibration_path}"
|
||||
echo "📦 Packaging final model artifacts…"
|
||||
"${PYTHON_BIN:-python}" - <<'PY'
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
json_path = Path(os.environ["JSON_PATH"])
|
||||
calibration_path = Path(os.environ.get("CALIBRATION_PATH", ""))
|
||||
language = (os.environ.get("LANGUAGE", "en") or "en").strip().lower()
|
||||
probability_cutoff = 0.97
|
||||
sliding_window_size = 5
|
||||
|
||||
if calibration_path.exists():
|
||||
try:
|
||||
calibration = json.loads(calibration_path.read_text(encoding="utf-8"))
|
||||
probability_cutoff = float(calibration.get("probability_cutoff", probability_cutoff))
|
||||
sliding_window_size = int(calibration.get("sliding_window_size", sliding_window_size))
|
||||
print(
|
||||
f"🎯 Using calibrated detector settings: "
|
||||
f"cutoff={probability_cutoff:.2f}, window={sliding_window_size}"
|
||||
)
|
||||
except Exception as exc:
|
||||
print(f"⚠️ Failed to read detector calibration ({exc}); using defaults.")
|
||||
|
||||
meta = {
|
||||
"type": "micro",
|
||||
"wake_word": "${WAKE_WORD_TITLE}",
|
||||
"wake_word": os.environ["WAKE_WORD_TITLE"],
|
||||
"author": "Tater Totterson",
|
||||
"website": "https://github.com/TaterTotterson/microWakeWord-Trainer-Nvidia-Docker.git",
|
||||
"model": "${tflite_filename}",
|
||||
"trained_languages": ["en"],
|
||||
"model": os.environ["TFLITE_FILENAME"],
|
||||
"trained_languages": [language],
|
||||
"version": 2,
|
||||
"micro": {
|
||||
"probability_cutoff": 0.97,
|
||||
"sliding_window_size": 5,
|
||||
"probability_cutoff": round(probability_cutoff, 2),
|
||||
"sliding_window_size": sliding_window_size,
|
||||
"feature_step_size": 10,
|
||||
"tensor_arena_size": 30000,
|
||||
"minimum_esphome_version": "2024.7.0"
|
||||
"minimum_esphome_version": "2024.7.0",
|
||||
},
|
||||
}
|
||||
}
|
||||
EOF
|
||||
json_path.write_text(json.dumps(meta, indent=4) + "\n", encoding="utf-8")
|
||||
PY
|
||||
|
||||
echo "Name: ${WAKE_WORD_TITLE}"
|
||||
echo "Model: ${tflite_path}"
|
||||
|
||||
@@ -22,7 +22,7 @@ COPY --chown=root:root --chmod=0755 .bashrc /root/
|
||||
# Root-level entrypoints
|
||||
COPY --chown=root:root --chmod=0755 \
|
||||
train_wake_word \
|
||||
run_recorder.sh \
|
||||
run.sh \
|
||||
trainer_server.py \
|
||||
requirements.txt \
|
||||
/root/mww-scripts/
|
||||
@@ -37,4 +37,4 @@ RUN chmod -R a+x /root/mww-scripts/cli
|
||||
COPY --chown=root:root --chmod=0644 static/index.html /root/mww-scripts/static/index.html
|
||||
|
||||
# trainer server
|
||||
CMD ["/bin/bash", "-lc", "/root/mww-scripts/run_recorder.sh"]
|
||||
CMD ["/bin/bash", "-lc", "/root/mww-scripts/run.sh"]
|
||||
|
||||
@@ -8,7 +8,7 @@ DATA_DIR="${DATA_DIR:-/data}"
|
||||
HOST="${REC_HOST:-0.0.0.0}"
|
||||
PORT="${REC_PORT:-8888}"
|
||||
|
||||
# Keep recorder deps separate from training venv
|
||||
# Keep trainer UI deps separate from the training venv
|
||||
VENV_DIR="${DATA_DIR}/.recorder-venv"
|
||||
PY="${VENV_DIR}/bin/python"
|
||||
PIP="${PY} -m pip"
|
||||
@@ -42,6 +42,12 @@ PIPER_VOICES_ROOT_URL = os.environ.get(
|
||||
"https://huggingface.co/rhasspy/piper-voices/resolve/main",
|
||||
)
|
||||
PIPER_CATALOG_CACHE_TTL_SECONDS = int(os.environ.get("PIPER_CATALOG_CACHE_TTL_SECONDS", "900"))
|
||||
PIPER_CATALOG_CACHE_FILE = Path(
|
||||
os.environ.get(
|
||||
"PIPER_CATALOG_CACHE_FILE",
|
||||
str(ROOT_DIR / ".cache" / "piper_voices_catalog.json"),
|
||||
)
|
||||
).resolve()
|
||||
|
||||
DATASET_CLEANUP_ARCHIVES = os.environ.get("REC_DATASET_CLEANUP_ARCHIVES", "false").lower() in ("1", "true", "yes", "y")
|
||||
DATASET_CLEANUP_INTERMEDIATE = os.environ.get("REC_DATASET_CLEANUP_INTERMEDIATE_FILES", "false").lower() in ("1", "true", "yes", "y")
|
||||
@@ -177,6 +183,27 @@ def _fetch_piper_catalog() -> Optional[Dict[str, Any]]:
|
||||
return data if isinstance(data, dict) else None
|
||||
|
||||
|
||||
def _read_cached_piper_catalog_file() -> Optional[Dict[str, Any]]:
|
||||
try:
|
||||
if not PIPER_CATALOG_CACHE_FILE.exists():
|
||||
return None
|
||||
data = json.loads(PIPER_CATALOG_CACHE_FILE.read_text(encoding="utf-8"))
|
||||
return data if isinstance(data, dict) else None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _write_cached_piper_catalog_file(data: Dict[str, Any]):
|
||||
try:
|
||||
PIPER_CATALOG_CACHE_FILE.parent.mkdir(parents=True, exist_ok=True)
|
||||
PIPER_CATALOG_CACHE_FILE.write_text(
|
||||
json.dumps(data, ensure_ascii=True),
|
||||
encoding="utf-8",
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _load_piper_catalog() -> Optional[Dict[str, Any]]:
|
||||
now = time.time()
|
||||
with PIPER_CATALOG_LOCK:
|
||||
@@ -185,6 +212,8 @@ def _load_piper_catalog() -> Optional[Dict[str, Any]]:
|
||||
if cached is not None and (now - fetched_at) < PIPER_CATALOG_CACHE_TTL_SECONDS:
|
||||
return cached
|
||||
|
||||
disk_cached = _read_cached_piper_catalog_file()
|
||||
|
||||
try:
|
||||
fresh = _fetch_piper_catalog()
|
||||
except Exception:
|
||||
@@ -194,8 +223,14 @@ def _load_piper_catalog() -> Optional[Dict[str, Any]]:
|
||||
if fresh is not None:
|
||||
PIPER_CATALOG_CACHE["entries"] = fresh
|
||||
PIPER_CATALOG_CACHE["fetched_at"] = now
|
||||
_write_cached_piper_catalog_file(fresh)
|
||||
return fresh
|
||||
if PIPER_CATALOG_CACHE.get("entries") is None:
|
||||
if PIPER_CATALOG_CACHE.get("entries") is not None:
|
||||
return PIPER_CATALOG_CACHE.get("entries")
|
||||
if disk_cached is not None:
|
||||
PIPER_CATALOG_CACHE["entries"] = disk_cached
|
||||
PIPER_CATALOG_CACHE["fetched_at"] = now
|
||||
return disk_cached
|
||||
PIPER_CATALOG_CACHE["entries"] = {}
|
||||
PIPER_CATALOG_CACHE["fetched_at"] = now
|
||||
return PIPER_CATALOG_CACHE.get("entries")
|
||||
|
||||
Reference in New Issue
Block a user