Files
microWakeWord-Trainer-Nvidi…/trainer_server.py
2026-04-14 22:55:49 -05:00

1055 lines
35 KiB
Python

# trainer_server.py
import asyncio
import io
import json
import os
import re
import shutil
import subprocess
import tempfile
import threading
import time
import wave
from datetime import datetime
from pathlib import Path
from typing import Dict, Any, List, Optional, Tuple
from urllib.request import Request, urlopen

from fastapi import FastAPI, UploadFile, File, Form
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
ROOT_DIR = Path(__file__).resolve().parent

# Persistent working directory (in the Docker CLI image this is /data).
DATA_DIR = Path(os.environ.get("DATA_DIR", "/data")).resolve()
# UI assets; by default the static/ folder next to this script.
STATIC_DIR = Path(os.environ.get("STATIC_DIR", str(ROOT_DIR / "static"))).resolve()
# Personal recordings; the CLI pipeline expects them in <DATA_DIR>/personal_samples.
PERSONAL_DIR = Path(os.environ.get("PERSONAL_DIR", str(DATA_DIR / "personal_samples"))).resolve()
# Repository cli/ folder (setup_python_venv, setup_training_datasets).
CLI_DIR = Path(os.environ.get("CLI_DIR", str(ROOT_DIR / "cli"))).resolve()

# Piper sample generator checkout and the folder holding its voice models.
PIPER_ROOT = DATA_DIR / "tools" / "piper-sample-generator"
PIPER_VOICES_DIR = PIPER_ROOT / "voices"
# Upstream voice catalog, and the base URL its relative file paths are resolved against.
PIPER_VOICES_INDEX_URL = os.environ.get(
    "PIPER_VOICES_INDEX_URL", "https://huggingface.co/rhasspy/piper-voices/raw/main/voices.json"
)
PIPER_VOICES_ROOT_URL = os.environ.get(
    "PIPER_VOICES_ROOT_URL", "https://huggingface.co/rhasspy/piper-voices/resolve/main"
)
# Seconds a fetched catalog is reused before it is fetched again.
PIPER_CATALOG_CACHE_TTL_SECONDS = int(os.environ.get("PIPER_CATALOG_CACHE_TTL_SECONDS", "900"))

# Spellings accepted as "on" for boolean environment flags.
_TRUTHY_ENV_VALUES = ("1", "true", "yes", "y")
DATASET_CLEANUP_ARCHIVES = (
    os.environ.get("REC_DATASET_CLEANUP_ARCHIVES", "false").lower() in _TRUTHY_ENV_VALUES
)
DATASET_CLEANUP_INTERMEDIATE = (
    os.environ.get("REC_DATASET_CLEANUP_INTERMEDIATE_FILES", "false").lower() in _TRUTHY_ENV_VALUES
)

# Shell command (executed with `bash -lc`) that activates the venv and launches training.
TRAIN_CMD = os.environ.get(
    "TRAIN_CMD",
    f"source '{DATA_DIR}/.venv/bin/activate' && "
    f"train_wake_word --data-dir '{DATA_DIR}'",
)

DEFAULT_LANGUAGE = os.environ.get("MWW_LANGUAGE", "en")
TAKES_PER_SPEAKER_DEFAULT = int(os.environ.get("REC_TAKES_PER_SPEAKER", "10"))
SPEAKERS_TOTAL_DEFAULT = int(os.environ.get("REC_SPEAKERS_TOTAL", "1"))

# Format every stored personal sample must have: 16 kHz, mono, 16-bit PCM.
TARGET_SAMPLE_RATE = 16000
TARGET_CHANNELS = 1
TARGET_SAMPLE_WIDTH_BYTES = 2

# Number of tail lines shown to the UI.
TRAIN_LOG_TAIL_LINES = int(os.environ.get("REC_TRAIN_LOG_TAIL_LINES", "400"))
# Safety cap (bytes) on how much of the log end is read per poll; default 512 KiB.
TRAIN_LOG_MAX_BYTES = int(os.environ.get("REC_TRAIN_LOG_MAX_BYTES", str(512 * 1024)))
# The static folder must exist before it is mounted, so create it up front.
STATIC_DIR.mkdir(parents=True, exist_ok=True)
# FastAPI application; UI assets are served under /static.
app = FastAPI(title="microWakeWord Personal Samples")
app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")
def safe_name(raw: str) -> str:
    """Convert a free-form wake phrase into a lowercase ``[a-z0-9_]`` identifier.

    Runs of whitespace become a single underscore, every other character
    outside ``[a-z0-9_]`` is dropped, and leading/trailing underscores are
    trimmed. Returns ``"wakeword"`` when nothing usable remains (also for
    ``None`` or an empty string).
    """
    text = (raw or "").strip().lower()
    for pattern, replacement in ((r"\s+", "_"), (r"[^a-z0-9_]+", "")):
        text = re.sub(pattern, replacement, text)
    text = text.strip("_")
    return text if text else "wakeword"
# Process-wide session and training state; code in this module reads and
# writes it while holding STATE_LOCK.
STATE: Dict[str, Any] = dict(
    raw_phrase=None,  # phrase exactly as the user typed it
    safe_word=None,  # safe_name(raw_phrase); falsy means "no active session"
    language=DEFAULT_LANGUAGE,
    speakers_total=SPEAKERS_TOTAL_DEFAULT,
    takes_per_speaker=TAKES_PER_SPEAKER_DEFAULT,
    takes_received=0,
    takes=[],
    training=dict(
        running=False,
        exit_code=None,
        log_lines=[],  # legacy in-memory tail (kept, but not relied on)
        log_path=None,  # path to recorder_training.log
        safe_word=None,
        # Bookkeeping so /api/train_status returns only lines the UI has not seen:
        last_sent_tail=[],  # last tail snapshot (list of lines)
        last_log_size=0,  # log size at that snapshot; a smaller size means truncation
    ),
)

# STATE_LOCK guards STATE; SAMPLES_LOCK serializes naming/writing of sample files.
STATE_LOCK = threading.Lock()
SAMPLES_LOCK = threading.Lock()

# In-process cache of the upstream Piper voice catalog (see _load_piper_catalog).
PIPER_CATALOG_LOCK = threading.Lock()
PIPER_CATALOG_CACHE: Dict[str, Any] = dict(
    fetched_at=0.0,
    entries=None,
)
def _reset_personal_samples_dir():
    """Delete every ``*.wav`` file in PERSONAL_DIR, creating the directory if needed.

    Best effort: files that cannot be removed are skipped silently.
    """
    PERSONAL_DIR.mkdir(parents=True, exist_ok=True)
    for wav_path in PERSONAL_DIR.glob("*.wav"):
        try:
            wav_path.unlink()
        except Exception:
            continue
def _list_personal_samples() -> List[str]:
    """Return the sorted file names of all ``*.wav`` files in PERSONAL_DIR (created if missing)."""
    PERSONAL_DIR.mkdir(parents=True, exist_ok=True)
    names = [wav.name for wav in PERSONAL_DIR.glob("*.wav")]
    names.sort()
    return names
def _sync_personal_samples_state() -> List[str]:
    """Refresh STATE's take list and count from disk; return the current sample names."""
    current = _list_personal_samples()
    count = len(current)
    with STATE_LOCK:
        STATE.update(takes=current, takes_received=count)
    return current
def _registered_language_family(language: Dict[str, Any]) -> str:
family = str(language.get("family") or "").strip().lower()
if family:
return family
code = str(language.get("code") or "").strip()
return code.split("_", 1)[0].lower() if code else ""
def _register_language(
languages: Dict[str, Dict[str, Any]],
*,
family: str,
name: str,
region: str = "",
count: int = 1,
):
if not family:
return
entry = languages.setdefault(
family,
{
"code": family,
"label": f"{name} ({family})",
"name": name,
"voice_count": 0,
"regions": [],
},
)
entry["voice_count"] += count
if region and region not in entry["regions"]:
entry["regions"].append(region)
def _fetch_piper_catalog() -> Optional[Dict[str, Any]]:
    """Download and decode the upstream Piper ``voices.json`` catalog.

    Returns the parsed object if it is a JSON object, otherwise ``None``.
    Network, decoding and JSON errors propagate to the caller.
    """
    request = Request(PIPER_VOICES_INDEX_URL, headers={"User-Agent": "microWakeWord-Trainer/1.0"})
    with urlopen(request, timeout=15) as response:
        payload = response.read()
    catalog = json.loads(payload.decode("utf-8"))
    if isinstance(catalog, dict):
        return catalog
    return None
def _load_piper_catalog() -> Optional[Dict[str, Any]]:
    """Return the Piper voice catalog, fetching it at most once per TTL window.

    A cached catalog younger than PIPER_CATALOG_CACHE_TTL_SECONDS is returned
    as-is. Otherwise a fresh copy is fetched.
      * On success the fresh copy replaces the cache.
      * On failure the previous catalog is returned, or ``{}`` if none was
        ever fetched.

    The attempt time is recorded in both cases. Previously a failed refresh
    left ``fetched_at`` stale, so while the upstream was unreachable every
    request re-tried the fetch and could block for the whole network timeout.
    """
    now = time.time()
    with PIPER_CATALOG_LOCK:
        cached = PIPER_CATALOG_CACHE.get("entries")
        fetched_at = float(PIPER_CATALOG_CACHE.get("fetched_at") or 0.0)
        if cached is not None and (now - fetched_at) < PIPER_CATALOG_CACHE_TTL_SECONDS:
            return cached
    # Fetch outside the lock so a slow network does not serialize every caller.
    try:
        fresh = _fetch_piper_catalog()
    except Exception:
        fresh = None
    with PIPER_CATALOG_LOCK:
        if fresh is not None:
            PIPER_CATALOG_CACHE["entries"] = fresh
        elif PIPER_CATALOG_CACHE.get("entries") is None:
            PIPER_CATALOG_CACHE["entries"] = {}
        # Stamp successes and failures alike: stale data is served until the next TTL expiry.
        PIPER_CATALOG_CACHE["fetched_at"] = now
        return PIPER_CATALOG_CACHE.get("entries")
def _available_languages() -> List[Dict[str, Any]]:
    """List the training languages the UI can offer: English first, the rest by name.

    English is always present. Other language families come from:
      * installed Piper voice metadata (``*.onnx.json`` in PIPER_VOICES_DIR),
        each file counting as one voice;
      * the upstream catalog, registered with a zero count, so
        ``voice_count`` reflects installed voices only.
    """
    registry: Dict[str, Dict[str, Any]] = {
        "en": {
            "code": "en",
            "label": "English (en)",
            "name": "English",
            "voice_count": 1,
            "regions": [],
        }
    }

    def add_language(language: Dict[str, Any], voices: int) -> None:
        # Register one non-English language record; English is pre-seeded above.
        family = _registered_language_family(language)
        if family and family != "en":
            display_name = str(
                language.get("name_english") or language.get("name_native") or family.upper()
            ).strip()
            region_name = str(language.get("country_english") or language.get("region") or "").strip()
            _register_language(registry, family=family, name=display_name, region=region_name, count=voices)

    if PIPER_VOICES_DIR.exists():
        for meta_path in sorted(PIPER_VOICES_DIR.glob("*.onnx.json")):
            try:
                meta = json.loads(meta_path.read_text(encoding="utf-8"))
            except Exception:
                continue
            add_language(meta.get("language") or {}, 1)

    for entry in (_load_piper_catalog() or {}).values():
        if isinstance(entry, dict):
            add_language(entry.get("language") or {}, 0)

    others = sorted(
        (item for code, item in registry.items() if code != "en"),
        key=lambda item: (item["name"].lower(), item["code"]),
    )
    return [registry["en"], *others]
def _normalize_language(language: str | None) -> str:
    """Map a requested language code onto one the server can actually offer.

    The request is stripped and lowercased; empty or ``None`` means
    DEFAULT_LANGUAGE. If the result is not among _available_languages(),
    DEFAULT_LANGUAGE is used when available, and ``"en"`` otherwise.
    """
    requested = (language or DEFAULT_LANGUAGE).strip().lower()
    if not requested:
        requested = DEFAULT_LANGUAGE
    offered = {entry["code"] for entry in _available_languages()}
    for candidate in (requested, DEFAULT_LANGUAGE):
        if candidate in offered:
            return candidate
    return "en"
def _catalog_voice_files(language_family: str) -> List[Tuple[str, str]]:
    """Return ``(file_name, download_url)`` pairs for a language's Piper voices.

    Collects every ``.onnx`` model and ``.onnx.json`` config that the upstream
    catalog lists for ``language_family``. Entries are keyed by bare file name
    (a later catalog entry wins on a duplicate name) and sorted by name. Empty
    or ``"en"`` families yield ``[]``.
    """
    if not language_family or language_family == "en":
        return []
    by_name: Dict[str, str] = {}
    for entry in (_load_piper_catalog() or {}).values():
        if not isinstance(entry, dict):
            continue
        if _registered_language_family(entry.get("language") or {}) != language_family:
            continue
        for rel_path in (entry.get("files") or {}).keys():
            if isinstance(rel_path, str) and rel_path.endswith((".onnx", ".onnx.json")):
                by_name[Path(rel_path).name] = f"{PIPER_VOICES_ROOT_URL}/{rel_path}?download=true"
    return sorted(by_name.items(), key=lambda pair: pair[0])
def _download_to_path(url: str, dest_path: Path):
dest_path.parent.mkdir(parents=True, exist_ok=True)
tmp_path = dest_path.with_suffix(dest_path.suffix + ".tmp")
req = Request(url, headers={"User-Agent": "microWakeWord-Trainer/1.0"})
with urlopen(req, timeout=60) as resp, open(tmp_path, "wb") as out:
shutil.copyfileobj(resp, out)
tmp_path.replace(dest_path)
def _ensure_non_english_language_voices(language_family: str, log) -> Dict[str, int]:
    """Make sure Piper voices for ``language_family`` are available in PIPER_VOICES_DIR.

    Every ``.onnx``/``.onnx.json`` file that the upstream catalog lists for the
    family is downloaded, unless a non-empty copy already exists.

    If the catalog lists nothing for the family (for example because it could
    not be fetched), already-installed ``<family>_*.onnx`` voices are used
    instead. With neither available, RuntimeError is raised. Progress lines
    are passed to the ``log`` callable.

    Returns a dict with ``downloaded_files``, ``existing_files`` and ``voices``.
    """
    downloads = _catalog_voice_files(language_family)
    if PIPER_VOICES_DIR.exists():
        installed = sorted(PIPER_VOICES_DIR.glob(f"{language_family}_*.onnx"))
    else:
        installed = []
    header = f"===== Piper Voices ({language_family}) ====="

    if not downloads:
        if not installed:
            raise RuntimeError(
                f"No Piper ONNX voices found for language '{language_family}' in the upstream catalog."
            )
        log(header)
        log(f"→ Using {len(installed)} installed voice(s) for language '{language_family}'")
        return {
            "downloaded_files": 0,
            "existing_files": len(installed),
            "voices": len(installed),
        }

    PIPER_VOICES_DIR.mkdir(parents=True, exist_ok=True)
    model_names = sorted(name for name, _url in downloads if name.endswith(".onnx"))
    log(header)
    log(f"→ Ensuring {len(model_names)} voice(s) for language '{language_family}'")

    fetched = 0
    present = 0
    for file_name, url in downloads:
        target = PIPER_VOICES_DIR / file_name
        if target.exists() and target.stat().st_size > 0:
            present += 1
        else:
            log(f"→ Downloading {file_name}")
            _download_to_path(url, target)
            fetched += 1

    log(
        f"✓ Piper voices ready for '{language_family}' "
        f"({fetched} file(s) downloaded, {present} already present)"
    )
    return {
        "downloaded_files": fetched,
        "existing_files": present,
        "voices": len(model_names),
    }
def _find_ffmpeg() -> Optional[str]:
candidates = [
shutil.which("ffmpeg"),
"/usr/bin/ffmpeg",
"/usr/local/bin/ffmpeg",
"/opt/homebrew/bin/ffmpeg",
"/opt/homebrew/opt/ffmpeg@7/bin/ffmpeg",
"/opt/homebrew/opt/ffmpeg/bin/ffmpeg",
]
for candidate in candidates:
if candidate and Path(candidate).exists():
return candidate
return None
def _inspect_wav_bytes(data: bytes) -> Optional[Dict[str, Any]]:
try:
with wave.open(io.BytesIO(data), "rb") as wf:
frames = wf.getnframes()
rate = wf.getframerate()
duration = (frames / rate) if rate else 0.0
return {
"container": "wav",
"sample_rate": rate,
"channels": wf.getnchannels(),
"sample_width_bits": wf.getsampwidth() * 8,
"compression": wf.getcomptype(),
"frames": frames,
"duration_s": round(duration, 3),
}
except Exception:
return None
def _is_target_wav(info: Optional[Dict[str, Any]]) -> bool:
    """Return True if ``info`` (from _inspect_wav_bytes) describes the target format.

    The target is an uncompressed WAV matching TARGET_SAMPLE_RATE,
    TARGET_CHANNELS and TARGET_SAMPLE_WIDTH_BYTES, with at least one frame.
    """
    if not info:
        return False
    expected = (
        ("container", "wav"),
        ("sample_rate", TARGET_SAMPLE_RATE),
        ("channels", TARGET_CHANNELS),
        ("sample_width_bits", TARGET_SAMPLE_WIDTH_BYTES * 8),
        ("compression", "NONE"),
    )
    if any(info.get(key) != value for key, value in expected):
        return False
    return info.get("frames", 0) > 0
def _next_personal_sample_name(original_name: str) -> str:
    """Pick the next free ``sample_NNNN[_stem].wav`` name in PERSONAL_DIR.

    The index is one past the highest ``sample_<digits>`` prefix on disk (1 if
    there are none), zero-padded to at least four digits. The upload's stem is
    sanitized with safe_name, cut to 32 characters and appended, unless it
    sanitizes to the ``"wakeword"`` fallback.

    The whole digit run is parsed. The previous ``\\d{4}`` read
    ``sample_10000`` as 1000, so after 9999 samples the next upload reused and
    overwrote ``sample_10000``. The caller (_save_personal_sample) holds
    SAMPLES_LOCK, so concurrent uploads do not pick the same index.
    """
    next_index = 1
    for name in _list_personal_samples():
        match = re.match(r"sample_(\d+)", name)
        if match:
            next_index = max(next_index, int(match.group(1)) + 1)
    stem = safe_name(Path(original_name or "sample").stem)
    suffix = f"_{stem[:32]}" if stem and stem != "wakeword" else ""
    return f"sample_{next_index:04d}{suffix}.wav"
def _format_hint_from_filename(original_name: str) -> Dict[str, Any]:
suffix = (Path(original_name or "").suffix or "").lower().lstrip(".")
return {
"container": suffix or "unknown",
"sample_rate": None,
"channels": None,
"sample_width_bits": None,
"compression": None,
"frames": None,
"duration_s": None,
}
def _normalize_audio_to_target_wav(data: bytes, original_name: str, timeout: float = 120.0) -> bytes:
    """Convert arbitrary audio bytes to the target mono 16-bit PCM WAV using ffmpeg.

    Args:
        data: raw uploaded file contents.
        original_name: client file name. Only its extension is reused, as a
            format hint for ffmpeg.
        timeout: seconds ffmpeg may run before the conversion is abandoned.

    Returns:
        The converted WAV file as bytes.

    Raises:
        RuntimeError: ffmpeg is missing, fails, produces no output, or exceeds
            ``timeout``.

    Changes from the original:
      * ffmpeg gets ``-nostdin`` and ``stdin=DEVNULL``, so it no longer reads
        the server's stdin.
      * The run is bounded by ``timeout``.
      * stderr is decoded with ``errors="replace"``, so non-UTF-8 metadata no
        longer raises UnicodeDecodeError.
    """
    ffmpeg = _find_ffmpeg()
    if not ffmpeg:
        raise RuntimeError(
            "ffmpeg is required to convert uploads that are not already 16 kHz mono 16-bit PCM WAV."
        )
    suffix = Path(original_name or "").suffix or ".audio"
    with tempfile.TemporaryDirectory(prefix="mww_upload_") as tmpdir:
        src_path = Path(tmpdir) / f"source{suffix}"
        dst_path = Path(tmpdir) / "normalized.wav"
        src_path.write_bytes(data)
        cmd = [
            ffmpeg,
            "-nostdin",
            "-hide_banner",
            "-y",
            "-i", str(src_path),
            "-vn",
            "-ac", str(TARGET_CHANNELS),
            "-ar", str(TARGET_SAMPLE_RATE),
            "-c:a", "pcm_s16le",
            str(dst_path),
        ]
        try:
            proc = subprocess.run(
                cmd,
                stdin=subprocess.DEVNULL,
                capture_output=True,
                text=True,
                errors="replace",
                timeout=timeout,
            )
        except subprocess.TimeoutExpired:
            raise RuntimeError(f"ffmpeg conversion timed out after {timeout:g}s") from None
        if proc.returncode != 0 or not dst_path.exists():
            err = (proc.stderr or proc.stdout or "ffmpeg conversion failed").strip()
            # ffmpeg's final stderr line usually names the actual failure.
            raise RuntimeError(err.splitlines()[-1] if err else "ffmpeg conversion failed")
        return dst_path.read_bytes()
def _save_personal_sample(data: bytes, original_name: str, out_name: Optional[str] = None) -> Dict[str, Any]:
    """Store an uploaded recording in PERSONAL_DIR in the target WAV format.

    Uploads that already match the target format are written unchanged;
    anything else is converted with ffmpeg first. ``out_name`` fixes the file
    name; otherwise the next ``sample_NNNN`` name is chosen while holding
    SAMPLES_LOCK.

    Returns a summary for the API response: saved name, whether it was
    converted, detected and final format info, and a human-readable message.

    Raises:
        ValueError: the upload is empty, or the result is not in the target format.
        RuntimeError: propagated from the ffmpeg conversion.
    """
    if not data:
        raise ValueError("Empty or invalid audio file.")
    detected = _inspect_wav_bytes(data) or _format_hint_from_filename(original_name)
    already_target = _is_target_wav(detected)
    if already_target:
        payload = data
    else:
        payload = _normalize_audio_to_target_wav(data, original_name)
    final_info = _inspect_wav_bytes(payload)
    if not _is_target_wav(final_info):
        raise ValueError("Uploaded audio could not be normalized to 16 kHz mono 16-bit PCM WAV.")
    with SAMPLES_LOCK:
        PERSONAL_DIR.mkdir(parents=True, exist_ok=True)
        saved_as = out_name or _next_personal_sample_name(original_name)
        (PERSONAL_DIR / saved_as).write_bytes(payload)
    if already_target:
        message = "Already in the correct 16 kHz mono 16-bit PCM WAV format"
    else:
        message = "Converted to 16 kHz mono 16-bit PCM WAV"
    return {
        "saved_as": saved_as,
        "converted": not already_target,
        "original_name": original_name or saved_as,
        "detected_format": detected,
        "final_format": final_info,
        "message": message,
    }
def _clear_training_log():
    """
    Truncate recorder_training.log for a fresh session.

    The file at DATA_DIR/recorder_training.log is overwritten with a
    three-line "new session" banner (its directory is created if needed).
    STATE then points at that file, and both the in-memory tail and the
    /api/train_status snapshot bookkeeping are reset.
    """
    log_path = DATA_DIR / "recorder_training.log"
    log_path.parent.mkdir(parents=True, exist_ok=True)
    banner_lines = (
        "================================================================================\n",
        "===== New recorder session started =====\n",
        "================================================================================\n",
    )
    with open(log_path, "w", encoding="utf-8") as lf:
        lf.writelines(banner_lines)
        lf.flush()
    with STATE_LOCK:
        training = STATE["training"]
        training["log_path"] = str(log_path)
        training["log_lines"] = []
        training["last_sent_tail"] = []
        training["last_log_size"] = 0
def _append_train_log(line: str):
    """Append one line (trailing newlines removed) to the in-memory training log.

    Only the most recent 250 lines are kept, in STATE["training"]["log_lines"].
    This does not write to recorder_training.log.
    """
    entry = (line or "").rstrip("\n")
    with STATE_LOCK:
        buffer: List[str] = STATE["training"]["log_lines"]
        buffer.append(entry)
        overflow = len(buffer) - 250
        if overflow > 0:
            del buffer[:overflow]
def _title_from_phrase(raw_phrase: str) -> str:
s = re.sub(r"[^a-zA-Z0-9 ]+", " ", raw_phrase or "").strip()
s = re.sub(r"\s+", " ", s)
return s.title() if s else ""
def _run_streamed(
    cmd: List[str],
    cwd: Path,
    log_path: Path,
    header: Optional[str] = None,
    env: Optional[Dict[str, str]] = None,
) -> int:
    """Run ``cmd``, copying its combined stdout/stderr into ``log_path`` and the in-memory log.

    A separator line, the optional ``header`` and the command line are written
    first. Output is then copied line by line, with a flush after every line
    so pollers of the log file see progress. Returns the process exit code.
    """
    command_line = " ".join(cmd)
    if header:
        _append_train_log(header)
    _append_train_log(command_line)
    with open(log_path, "a", encoding="utf-8") as log_file:
        preamble = "\n" + ("=" * 80) + "\n"
        if header:
            preamble += header + "\n"
        preamble += command_line + "\n"
        log_file.write(preamble)
        log_file.flush()
        proc = subprocess.Popen(
            cmd,
            cwd=str(cwd),
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            bufsize=1,
            env=env,
        )
        assert proc.stdout is not None
        for output_line in proc.stdout:
            log_file.write(output_line)
            log_file.flush()
            _append_train_log(output_line)
        return proc.wait()
def _ensure_training_venv(log_path: Path) -> None:
    """Make sure the training virtualenv exists at DATA_DIR/.venv.

    If ``.venv/bin/activate`` already exists, nothing is run. Otherwise
    ``cli/setup_python_venv`` runs through bash, with its output streamed into
    ``log_path``.

    The "venv found" notice is now also written to ``log_path``.
    /api/train_status only reports the log file, so a notice written to the
    in-memory buffer alone never reached the UI.

    Raises:
        RuntimeError: the setup script is missing, exits non-zero, or leaves
            no activate script behind.
    """
    activate = DATA_DIR / ".venv" / "bin" / "activate"
    if activate.exists():
        message = "✅ Training venv found (skipping setup_python_venv)"
        _append_train_log(message)
        try:
            with open(log_path, "a", encoding="utf-8") as lf:
                lf.write(message + "\n")
        except OSError:
            pass  # best effort: a logging failure must not block training
        return
    setup = CLI_DIR / "setup_python_venv"
    if not setup.exists():
        raise RuntimeError(f"Missing setup_python_venv at: {setup}")
    rc = _run_streamed(
        ["bash", "-lc", f"cd '{DATA_DIR}' && '{setup}' --data-dir='{DATA_DIR}'"],
        cwd=DATA_DIR,
        log_path=log_path,
        header="===== Ensuring Python venv (/data/.venv) =====",
    )
    if rc != 0:
        raise RuntimeError(f"setup_python_venv failed (exit_code={rc})")
    if not activate.exists():
        raise RuntimeError(f"setup_python_venv finished, but {activate} is still missing")
def _ensure_training_datasets(log_path: Path) -> None:
    """Run ``cli/setup_training_datasets`` for DATA_DIR, streaming its output to ``log_path``.

    The DATASET_CLEANUP_ARCHIVES / DATASET_CLEANUP_INTERMEDIATE settings are
    passed on as ``--cleanup-archives`` / ``--cleanup-intermediate-files``.

    Raises:
        RuntimeError: the script is missing or exits non-zero.
    """
    script = CLI_DIR / "setup_training_datasets"
    if not script.exists():
        raise RuntimeError(f"Missing setup_training_datasets at: {script}")

    def as_flag(enabled: bool) -> str:
        return "true" if enabled else "false"

    shell_line = " ".join(
        (
            f"cd '{DATA_DIR}' &&",
            f"'{script}'",
            f"--cleanup-archives='{as_flag(DATASET_CLEANUP_ARCHIVES)}'",
            f"--cleanup-intermediate-files='{as_flag(DATASET_CLEANUP_INTERMEDIATE)}'",
            f"--data-dir='{DATA_DIR}'",
        )
    )
    exit_code = _run_streamed(
        ["bash", "-lc", shell_line],
        cwd=DATA_DIR,
        log_path=log_path,
        header="===== Ensuring training datasets (setup_training_datasets) =====",
    )
    if exit_code != 0:
        raise RuntimeError(f"setup_training_datasets failed (exit_code={exit_code})")
def _read_tail_lines(log_path: Path, max_lines: int) -> List[str]:
    """
    Read the last N lines, bounded by TRAIN_LOG_MAX_BYTES.
    Returns list of lines (no trailing newlines).

    Only the final TRAIN_LOG_MAX_BYTES bytes are read, decoded as UTF-8 with
    replacement characters. When that byte cap cuts mid-line, the first
    returned line is partial. A missing file or any read error yields ``[]``.
    """
    if not log_path.exists():
        return []
    try:
        offset = max(0, log_path.stat().st_size - TRAIN_LOG_MAX_BYTES)
        with open(log_path, "rb") as handle:
            handle.seek(offset)
            raw = handle.read()
        lines = raw.decode("utf-8", errors="replace").splitlines()
    except Exception:
        return []
    return lines if len(lines) <= max_lines else lines[-max_lines:]
def _compute_new_lines(prev_tail: List[str], new_tail: List[str]) -> List[str]:
"""
Given previous and current tail snapshots, return only the newly-added lines.
Works even if the tail window shifts.
"""
if not prev_tail:
return new_tail
# Try to find the largest suffix of prev_tail that matches a prefix of new_tail
max_k = min(len(prev_tail), len(new_tail))
for k in range(max_k, 0, -1):
if prev_tail[-k:] == new_tail[:k]:
return new_tail[k:]
# If no overlap, just return full new_tail (probably truncation or big jump)
return new_tail
# -------------------- output artifact normalization --------------------
def _find_latest_output_pair(output_dir: Path) -> Tuple[Optional[Path], Optional[Path]]:
"""
Find the most recently modified .tflite and its matching .json (same basename)
in output_dir. Falls back to newest .json if an exact match doesn't exist.
Returns (tflite_path, json_path) or (None, None).
"""
if not output_dir.exists():
return (None, None)
tflites = sorted(output_dir.glob("*.tflite"), key=lambda p: p.stat().st_mtime, reverse=True)
if not tflites:
return (None, None)
tfl = tflites[0]
js = tfl.with_suffix(".json")
if js.exists():
return (tfl, js)
jsons = sorted(output_dir.glob("*.json"), key=lambda p: p.stat().st_mtime, reverse=True)
return (tfl, jsons[0] if jsons else None)
def _deep_replace_strings(obj: Any, old: str, new: str) -> Any:
"""
Recursively replace occurrences of old in any string values with new.
"""
if isinstance(obj, str):
return obj.replace(old, new)
if isinstance(obj, list):
return [_deep_replace_strings(x, old, new) for x in obj]
if isinstance(obj, dict):
return {k: _deep_replace_strings(v, old, new) for k, v in obj.items()}
return obj
def _normalize_output_artifacts(safe_word: str, log_path: Path) -> None:
    """
    Rename output artifacts to <safe_word>.tflite / <safe_word>.json
    and patch the JSON so it references the renamed tflite.
    Handles weird trainer names like ____r_.tflite by normalizing post-run.

    Works on the newest .tflite in DATA_DIR/output, plus its metadata JSON
    (see _find_latest_output_pair). Files already occupying the target names
    are first moved aside to ``<safe_word>.<timestamp>.bak.tflite/.json``.

    Progress messages go to the in-memory log and are appended to
    ``log_path``, the file /api/train_status reads. Previously ``log_path``
    was ignored, so none of these messages reached the UI. Rename messages
    now also separate the old and new names with an arrow.
    """
    out_dir = DATA_DIR / "output"

    def log(message: str) -> None:
        _append_train_log(message)
        try:
            with open(log_path, "a", encoding="utf-8") as lf:
                lf.write(message + "\n")
        except OSError:
            pass  # best effort: logging must not fail an already finished run

    tfl, js = _find_latest_output_pair(out_dir)
    if not tfl:
        log(f"⚠️ No .tflite found in {out_dir}")
        return
    new_tfl = out_dir / f"{safe_word}.tflite"
    new_js = out_dir / f"{safe_word}.json"
    old_tfl_name = tfl.name
    if tfl.name == new_tfl.name and js is not None and js.name == new_js.name:
        log(f"✅ Output names already normalized: {new_tfl.name}")
        return
    ts = datetime.now().strftime("%Y%m%d_%H%M%S")

    def backup_if_exists(p: Path, suffix: str) -> None:
        # Move an existing target aside instead of silently overwriting it.
        if p.exists():
            bk = out_dir / f"{safe_word}.{ts}.bak{suffix}"
            shutil.move(str(p), str(bk))
            log(f"↪️ Backed up existing {p.name} → {bk.name}")

    if new_tfl.exists() and new_tfl.resolve() != tfl.resolve():
        backup_if_exists(new_tfl, ".tflite")
    if new_js.exists() and (js is None or new_js.resolve() != js.resolve()):
        backup_if_exists(new_js, ".json")

    if tfl.resolve() != new_tfl.resolve():
        shutil.move(str(tfl), str(new_tfl))
        log(f"✅ Renamed model: {old_tfl_name} → {new_tfl.name}")

    if js is None or not js.exists():
        log("⚠️ No .json found to patch (model renamed only)")
        return

    # Parse before moving so the original file name can be reported on failure.
    try:
        data = json.loads(js.read_text(encoding="utf-8"))
    except (OSError, ValueError) as e:
        data = None
        log(f"⚠️ Could not parse {js.name} ({e}); renaming it without patching")
    if js.resolve() != new_js.resolve():
        old_js_name = js.name
        shutil.move(str(js), str(new_js))
        log(f"✅ Renamed metadata: {old_js_name} → {new_js.name}")
    if data is None:
        return
    patched = _deep_replace_strings(data, old_tfl_name, new_tfl.name)
    # Also force common "model file" keys to the new name, even when the old name didn't appear.
    if isinstance(patched, dict):
        for key in ("model", "model_file", "model_filename", "tflite", "tflite_file", "tflite_filename"):
            if isinstance(patched.get(key), str):
                patched[key] = new_tfl.name
    new_js.write_text(json.dumps(patched, indent=2, ensure_ascii=False) + "\n", encoding="utf-8")
    log(f"✅ Patched JSON to reference: {new_tfl.name}")
# -------------------- training worker --------------------
def _run_training_background(safe_word: str, allow_no_personal: bool):
    """Training worker (runs in a daemon thread started by /api/train).

    Steps:
      1. Ensure the Python venv and datasets exist.
      2. For a non-English language, download its Piper voices.
      3. Run TRAIN_CMD through ``bash -lc``, copying its output into
         recorder_training.log.
      4. On success, rename the produced model artifacts.

    STATE["training"] tracks ``running`` and ``exit_code``; the exit code is
    999 when any step raises.

    Status lines (banner, "Running", finish/crash, voice downloads) now go to
    the log file as well as the in-memory buffer. Previously most of them
    only reached the buffer, but /api/train_status reports just the file, so
    a crash message never appeared in the UI.

    Args:
        safe_word: sanitized wake word, passed to the CLI and used for output names.
        allow_no_personal: exported to the trainer as MWW_ALLOW_NO_PERSONAL.
    """
    with STATE_LOCK:
        raw_phrase = STATE.get("raw_phrase") or ""
        language = STATE.get("language") or DEFAULT_LANGUAGE
    wake_word_title = _title_from_phrase(raw_phrase)
    log_path = DATA_DIR / "recorder_training.log"
    with STATE_LOCK:
        training = STATE["training"]
        training["running"] = True
        training["exit_code"] = None
        training["log_lines"] = []
        training["safe_word"] = safe_word
        training["last_sent_tail"] = []
        training["last_log_size"] = 0
        training["log_path"] = str(log_path)

    def log(message: str) -> None:
        # Mirror a status line into the buffer AND the log file the UI polls.
        _append_train_log(message)
        try:
            with open(log_path, "a", encoding="utf-8") as lf:
                lf.write(message + "\n")
        except OSError:
            pass  # best effort: logging must never kill the worker

    rule = "=" * 80
    for banner_line in (rule, "===== Recorder Training Run =====", rule):
        _append_train_log(banner_line)
    try:
        with open(log_path, "a", encoding="utf-8") as lf:
            lf.write("\n" + rule + "\n")
            lf.write("===== Recorder Training Run =====\n")
            lf.write(rule + "\n")
            lf.flush()
    except Exception:
        pass

    try:
        _ensure_training_venv(log_path)
        _ensure_training_datasets(log_path)
        if language != "en":
            _ensure_non_english_language_voices(language, log)
        # safe_word and the title are restricted to shell-safe characters, so single quotes suffice.
        cmd_str = f"{TRAIN_CMD} --language='{language}' '{safe_word}'"
        if wake_word_title:
            cmd_str += f" '{wake_word_title}'"
        env = os.environ.copy()
        env["MWW_ALLOW_NO_PERSONAL"] = "true" if allow_no_personal else "false"
        log("===== Training (train_wake_word) =====")
        log(f"→ Running: {cmd_str}")
        with open(log_path, "a", encoding="utf-8") as lf:
            proc = subprocess.Popen(
                ["bash", "-lc", cmd_str],
                cwd=str(DATA_DIR),
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                text=True,
                bufsize=1,
                env=env,
            )
            assert proc.stdout is not None
            for line in proc.stdout:
                lf.write(line)
                lf.flush()
                _append_train_log(line)
            rc = proc.wait()
        log(f"✓ Training finished (exit_code={rc})")
        with STATE_LOCK:
            STATE["training"]["exit_code"] = rc
        # Normalize output artifact names on success.
        if rc == 0:
            _normalize_output_artifacts(safe_word, log_path)
    except Exception as e:
        log(f"✗ Training crashed: {e!r}")
        with STATE_LOCK:
            STATE["training"]["exit_code"] = 999
    finally:
        with STATE_LOCK:
            STATE["training"]["running"] = False
@app.get("/", response_class=HTMLResponse)
def index():
    """Serve the single-page UI from STATIC_DIR/index.html, or a 500 hint page if it is missing."""
    page = STATIC_DIR / "index.html"
    if page.exists():
        return HTMLResponse(page.read_text(encoding="utf-8"))
    return HTMLResponse(
        "<h3>Missing UI</h3><p>Create <code>static/index.html</code>.</p>",
        status_code=500,
    )
@app.post("/api/start_session")
def start_session(payload: Dict[str, Any]):
    """Begin (or restart) a recording session for a wake phrase.

    Body fields:
      * ``phrase`` (required)
      * ``speakers_total`` (optional, clamped to 1..10)
      * ``takes_per_speaker`` (optional, clamped to 1..50)
      * ``language`` (optional, normalized against the available languages)

    Samples already on disk are kept and re-counted; the training log is
    always wiped.

    Invalid input now yields 400 instead of an unhandled 500. This covers a
    non-integer count, and a non-string ``phrase`` or ``language``.
    """
    phrase = payload.get("phrase")
    raw = phrase.strip() if isinstance(phrase, str) else str(phrase or "").strip()
    if not raw:
        return JSONResponse({"ok": False, "error": "phrase is required"}, status_code=400)
    try:
        speakers_total = int(payload.get("speakers_total") or SPEAKERS_TOTAL_DEFAULT)
        takes_per_speaker = int(payload.get("takes_per_speaker") or TAKES_PER_SPEAKER_DEFAULT)
    except (TypeError, ValueError):
        return JSONResponse(
            {"ok": False, "error": "speakers_total and takes_per_speaker must be integers"},
            status_code=400,
        )
    safe = safe_name(raw)
    requested_language = payload.get("language")
    language = _normalize_language(requested_language if isinstance(requested_language, str) else None)
    available_languages = _available_languages()
    speakers_total = max(1, min(10, speakers_total))
    takes_per_speaker = max(1, min(50, takes_per_speaker))
    with STATE_LOCK:
        STATE["raw_phrase"] = raw
        STATE["safe_word"] = safe
        STATE["language"] = language
        STATE["speakers_total"] = speakers_total
        STATE["takes_per_speaker"] = takes_per_speaker
    takes = _sync_personal_samples_state()
    # Always wipe the log on start_session (even for the same wake word).
    _clear_training_log()
    return {
        "ok": True,
        "raw_phrase": raw,
        "safe_word": safe,
        "language": language,
        "speakers_total": speakers_total,
        "takes_per_speaker": takes_per_speaker,
        "takes_total": speakers_total * takes_per_speaker,
        "takes_received": len(takes),
        "takes": takes,
        "available_languages": available_languages,
        "personal_dir": str(PERSONAL_DIR),
        "data_dir": str(DATA_DIR),
    }
@app.get("/api/session")
def get_session():
    """Return the current session, sample list, training state and language options.

    The stored language is re-normalized, because a catalog change can make
    it unavailable. Normalization may fetch the Piper catalog over the
    network, so it now runs *outside* STATE_LOCK. Holding the lock during
    that fetch stalled every other lock user, including the training thread's
    per-line log appends.

    The training section is returned with copied lists. The training thread
    keeps mutating the live lists while FastAPI serializes the response.
    """
    takes = _sync_personal_samples_state()
    available_languages = _available_languages()
    with STATE_LOCK:
        stored_language = STATE["language"]
    current_language = _normalize_language(stored_language)
    with STATE_LOCK:
        if STATE["language"] == stored_language:
            STATE["language"] = current_language
        else:
            # A concurrent start_session already stored a freshly normalized value.
            current_language = STATE["language"]
        training = {
            key: list(value) if isinstance(value, list) else value
            for key, value in STATE["training"].items()
        }
        return {
            "ok": True,
            "raw_phrase": STATE["raw_phrase"],
            "safe_word": STATE["safe_word"],
            "language": current_language,
            "speakers_total": STATE["speakers_total"],
            "takes_per_speaker": STATE["takes_per_speaker"],
            "takes_received": len(takes),
            "takes": list(takes),
            "training": training,
            "available_languages": available_languages,
        }
@app.post("/api/upload_take")
async def upload_take(
    speaker_index: int = Form(...),
    take_index: int = Form(...),
    file: UploadFile = File(...),
):
    """Store one guided recording as ``speakerNN_takeNN.wav`` in PERSONAL_DIR.

    Requires an active session. ``speaker_index`` and ``take_index`` are
    1-based and must fit the session's speaker and take counts. The audio is
    normalized to the target WAV format (see _save_personal_sample).

    Saving may run ffmpeg, and the directory rescan is filesystem I/O. Both
    are blocking, so they now run in a worker thread via asyncio.to_thread.
    Called directly inside this ``async def``, they froze the event loop, and
    every other request, for the duration of the conversion.
    """
    with STATE_LOCK:
        safe_word = STATE["safe_word"]
        speakers_total = int(STATE["speakers_total"])
        takes_per_speaker = int(STATE["takes_per_speaker"])
    if not safe_word:
        return JSONResponse({"ok": False, "error": "No active session. Call /api/start_session first."}, status_code=400)
    if speaker_index < 1 or speaker_index > speakers_total:
        return JSONResponse({"ok": False, "error": f"speaker_index must be 1..{speakers_total}"}, status_code=400)
    if take_index < 1 or take_index > takes_per_speaker:
        return JSONResponse({"ok": False, "error": f"take_index must be 1..{takes_per_speaker}"}, status_code=400)
    out_name = f"speaker{speaker_index:02d}_take{take_index:02d}.wav"
    data = await file.read()
    try:
        result = await asyncio.to_thread(
            _save_personal_sample, data, file.filename or out_name, out_name=out_name
        )
    except Exception as e:
        return JSONResponse({"ok": False, "error": str(e)}, status_code=400)
    takes = await asyncio.to_thread(_sync_personal_samples_state)
    return {"ok": True, **result, "takes_received": len(takes)}
@app.post("/api/upload_personal_sample")
async def upload_personal_sample(file: UploadFile = File(...)):
    """Store a free-form personal sample under the next ``sample_NNNN`` name.

    Requires an active session. The audio is normalized to the target WAV
    format (see _save_personal_sample).

    The blocking save (possibly an ffmpeg run) and the directory rescan now
    run in a worker thread via asyncio.to_thread. Previously they blocked the
    event loop inside this ``async def``.
    """
    with STATE_LOCK:
        safe_word = STATE["safe_word"]
    if not safe_word:
        return JSONResponse({"ok": False, "error": "No active session. Call /api/start_session first."}, status_code=400)
    data = await file.read()
    try:
        result = await asyncio.to_thread(_save_personal_sample, data, file.filename or "sample")
    except Exception as e:
        return JSONResponse({"ok": False, "error": str(e)}, status_code=400)
    takes = await asyncio.to_thread(_sync_personal_samples_state)
    return {"ok": True, **result, "takes_received": len(takes)}
@app.post("/api/train")
def train_now(payload: Dict[str, Any] = None):
    """Start a background training run for the active session.

    Optional body: ``{"allow_no_personal": bool}``, which lets training
    proceed with zero personal samples. String values such as ``"false"`` or
    ``"0"`` are parsed as booleans. Previously ``bool("false")`` was True.

    Fixes:
      * The "already running" check and setting ``running`` now happen in one
        STATE_LOCK section. Previously ``running`` was set only once the
        worker thread began, so two quick requests could both start a run.
      * The sample count is re-read from disk first, so a stale
        ``takes_received`` cannot bypass the no-personal-samples guard.
    """
    payload = payload or {}
    flag = payload.get("allow_no_personal", False)
    if isinstance(flag, str):
        allow_no_personal = flag.strip().lower() in ("1", "true", "yes", "y")
    else:
        allow_no_personal = bool(flag)
    _sync_personal_samples_state()
    with STATE_LOCK:
        safe_word = STATE["safe_word"]
        takes_received = int(STATE["takes_received"])
        speakers_total = int(STATE["speakers_total"])
        takes_per_speaker = int(STATE["takes_per_speaker"])
        training_running = bool(STATE["training"]["running"])
        can_start = (
            not training_running
            and bool(safe_word)
            and (takes_received > 0 or allow_no_personal)
        )
        if can_start:
            # Claim the run now so a concurrent request sees it as running.
            STATE["training"]["running"] = True
            STATE["training"]["exit_code"] = None
    takes_total = speakers_total * takes_per_speaker
    if training_running:
        return JSONResponse({"ok": False, "error": "Training already running"}, status_code=400)
    if not safe_word:
        return JSONResponse({"ok": False, "error": "No active session"}, status_code=400)
    if takes_received == 0 and not allow_no_personal:
        return JSONResponse(
            {
                "ok": False,
                "error": "No personal voice samples uploaded yet.",
                "code": "NO_PERSONAL_SAMPLES",
                "message": "You can train without personal voices, or upload samples first.",
                "takes_total": takes_total,
            },
            status_code=400,
        )
    t = threading.Thread(target=_run_training_background, args=(safe_word, allow_no_personal), daemon=True)
    try:
        t.start()
    except RuntimeError:
        # Thread could not be started: release the claim so training can be retried.
        with STATE_LOCK:
            STATE["training"]["running"] = False
        raise
    return {
        "ok": True,
        "started": True,
        "safe_word": safe_word,
        "personal_samples_used": takes_received > 0,
        "allow_no_personal": allow_no_personal,
    }
@app.get("/api/train_status")
def train_status():
    """
    Return only NEW lines since last poll (prevents UI duplication spam even if UI appends).

    ``log_text`` holds the lines added since the previous poll.
    ``log_tail_preview`` and ``log_lines`` hold the current tail, up to
    TRAIN_LOG_TAIL_LINES lines. The diff snapshot is stored once in STATE,
    not per client, so concurrent pollers share it.
    """
    with STATE_LOCK:
        training = STATE["training"]
        tr = dict(training)
        log_path_str = tr.get("log_path")
        prev_tail = list(training.get("last_sent_tail") or [])
        prev_size = int(training.get("last_log_size") or 0)

    full_tail: List[str] = []
    new_lines: List[str] = []
    size_now = 0
    log_file = Path(log_path_str) if log_path_str else None
    if log_file is not None and log_file.exists():
        try:
            size_now = int(log_file.stat().st_size)
        except Exception:
            size_now = 0
        if size_now < prev_size:
            prev_tail = []  # the file was truncated/cleared: diff from scratch
        full_tail = _read_tail_lines(log_file, TRAIN_LOG_TAIL_LINES)
        new_lines = _compute_new_lines(prev_tail, full_tail)

    # Remember what was sent so the next poll only returns newer lines.
    with STATE_LOCK:
        STATE["training"]["last_sent_tail"] = full_tail
        STATE["training"]["last_log_size"] = size_now

    tr.update(
        log_text="\n".join(new_lines),
        log_tail_preview="\n".join(full_tail),
        log_lines=full_tail,
    )
    return {"ok": True, "training": tr}
@app.post("/api/reset_recordings")
def reset_recordings():
    """Delete the personal sample WAVs (best effort) and return the refreshed take list."""
    _reset_personal_samples_dir()
    remaining = _sync_personal_samples_state()
    return {"ok": True, "takes_received": len(remaining), "takes": remaining}