diff --git a/cli/run_generator_with_progress.py b/cli/run_generator_with_progress.py new file mode 100644 index 0000000..61d0cf2 --- /dev/null +++ b/cli/run_generator_with_progress.py @@ -0,0 +1,112 @@ +#!/usr/bin/env python3 +import argparse +import queue +import subprocess +import sys +import threading +from pathlib import Path + + +def _model_args(generator_args): + values = [] + for idx, arg in enumerate(generator_args): + if arg == "--model" and idx + 1 < len(generator_args): + values.append(generator_args[idx + 1]) + return values + + +def _is_onnx_run(generator_args): + return any(str(value).endswith(".onnx") for value in _model_args(generator_args)) + + +def _format_line(line): + if line.startswith("DEBUG:piper.voice:"): + return None + for prefix in ("DEBUG:__main__:", "INFO:__main__:", "WARNING:__main__:", "ERROR:__main__:"): + if line.startswith(prefix): + return " " + line[len(prefix):].strip() + return line + + +def _reader(stdout, sink): + try: + for raw in stdout: + sink.put(raw.rstrip("\n")) + finally: + sink.put(None) + + +def _progress_step(max_samples): + if max_samples <= 20: + return 1 + if max_samples <= 100: + return 5 + return 10 + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--generator", required=True) + parser.add_argument("--output-dir", required=True) + parser.add_argument("--max-samples", required=True, type=int) + parser.add_argument("generator_args", nargs=argparse.REMAINDER) + args = parser.parse_args() + + generator_args = list(args.generator_args) + if generator_args and generator_args[0] == "--": + generator_args = generator_args[1:] + + cmd = [sys.executable, args.generator, *generator_args] + proc = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + bufsize=1, + ) + assert proc.stdout is not None + + line_queue = queue.Queue() + reader = threading.Thread(target=_reader, args=(proc.stdout, line_queue), daemon=True) + reader.start() + + output_dir = Path(args.output_dir) + use_sample_progress = _is_onnx_run(generator_args) + step = _progress_step(args.max_samples) + last_reported = 0 + stream_done = False + + while not stream_done or proc.poll() is None: + try: + line = line_queue.get(timeout=0.2) + except queue.Empty: + line = None + + if line is None: + if not stream_done and not line_queue.empty(): + continue + stream_done = proc.poll() is not None or stream_done + else: + formatted = _format_line(line) + if formatted: + print(formatted, flush=True) + + if use_sample_progress: + current = len(list(output_dir.glob("*.wav"))) + should_report = current > last_reported and ( + current >= args.max_samples + or current - last_reported >= step + ) + if should_report: + print(f" Generated {current}/{args.max_samples} samples...", flush=True) + last_reported = current + + rc = proc.wait() + final_count = len(list(output_dir.glob("*.wav"))) if use_sample_progress else 0 + if use_sample_progress and final_count > last_reported: + print(f" Generated {final_count}/{args.max_samples} samples...", flush=True) + return rc + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/cli/wake_word_sample_generator b/cli/wake_word_sample_generator index 4166e75..6dd4be0 100755 --- a/cli/wake_word_sample_generator +++ b/cli/wake_word_sample_generator @@ -138,11 +138,16 @@ export GRPC_VERBOSITY=ERROR echo " Generating samples" rm -rf "${SAMPLES_DIR}" || : mkdir -p "${SAMPLES_DIR}" || : -"${PSG}/generate_samples.py" "${WAKE_WORD}" \ +python "${PROGDIR}/run_generator_with_progress.py" \ + --generator "${PSG}/generate_samples.py" \ + --output-dir "${SAMPLES_DIR}" \ + --max-samples ${SAMPLES} \ + -- \ + "${WAKE_WORD}" \ "${MODEL_ARGS[@]}" \ --max-samples ${SAMPLES} \ --batch-size ${BATCH_SIZE} \ - --output-dir "${SAMPLES_DIR}" 2>&1 | sed -r -e "s/(DEBUG|INFO):__main__:/ /g" + --output-dir "${SAMPLES_DIR}" generated_files=$(find "${SAMPLES_DIR}" -name '*.wav' | wc -l) if [ "${generated_files}" -ne "${SAMPLES}" ] ; then diff --git a/dockerfile b/dockerfile index 5778ead..bf3bfeb 100644 --- a/dockerfile +++ b/dockerfile @@ -10,7 +10,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ && rm -rf /var/lib/apt/lists/* \ && mkdir -p /data -# Recorder port +# Trainer UI port EXPOSE 8789 # Script root @@ -23,7 +23,7 @@ COPY --chown=root:root --chmod=0755 .bashrc /root/ COPY --chown=root:root --chmod=0755 \ train_wake_word \ run_recorder.sh \ - recorder_server.py \ + trainer_server.py \ requirements.txt \ /root/mww-scripts/ @@ -33,8 +33,8 @@ COPY --chown=root:root cli/ /root/mww-scripts/cli/ # Make all CLI scripts executable (avoids "Permission denied") RUN chmod -R a+x /root/mww-scripts/cli -# Static UI for recorder +# Static UI for trainer COPY --chown=root:root --chmod=0644 static/index.html /root/mww-scripts/static/index.html -# recorder server +# trainer server CMD ["/bin/bash", "-lc", "/root/mww-scripts/run_recorder.sh"] diff --git a/run_recorder.sh b/run_recorder.sh index 9ac94c5..75d1ec9 100644 --- a/run_recorder.sh +++ b/run_recorder.sh @@ -18,7 +18,7 @@ FASTAPI_VERSION="${REC_FASTAPI_VERSION:-0.115.6}" UVICORN_VERSION="${REC_UVICORN_VERSION:-0.30.6}" PY_MULTIPART_VERSION="${REC_PY_MULTIPART_VERSION:-0.0.9}" -echo "microWakeWord Recorder (Docker)" +echo "microWakeWord Trainer UI (Docker)" echo "-> ROOTDIR: ${ROOTDIR}" echo "-> DATA_DIR: ${DATA_DIR}" echo "-> URL: http://localhost:${PORT}/" @@ -26,10 +26,10 @@ echo "-> URL: http://localhost:${PORT}/" mkdir -p "${DATA_DIR}" # ----------------------------- -# Recorder venv (separate) +# Trainer UI venv (separate) # ----------------------------- if [[ ! -x "${PY}" ]]; then - echo "Creating recorder venv: ${VENV_DIR}" + echo "Creating trainer UI venv: ${VENV_DIR}" python3 -m venv "${VENV_DIR}" fi @@ -37,7 +37,7 @@ fi source "${VENV_DIR}/bin/activate" if [[ ! -f "${PIN_FILE}" ]]; then - echo "Installing pinned recorder deps" + echo "Installing pinned trainer UI deps" ${PIP} install -U pip setuptools wheel ${PIP} install \ "fastapi==${FASTAPI_VERSION}" \ @@ -45,20 +45,20 @@ if [[ ! -f "${PIN_FILE}" ]]; then "python-multipart==${PY_MULTIPART_VERSION}" touch "${PIN_FILE}" else - echo "Reusing existing recorder venv (no upgrades)" + echo "Reusing existing trainer UI venv (no upgrades)" fi # ----------------------------- -# Recorder server env +# Trainer server env # ----------------------------- export DATA_DIR="${DATA_DIR}" export STATIC_DIR="${ROOTDIR}/static" export PERSONAL_DIR="${DATA_DIR}/personal_samples" -# IMPORTANT: leave training venv creation to /api/train inside recorder_server.py +# IMPORTANT: leave training venv creation to /api/train inside trainer_server.py # but still set TRAIN_CMD so the server knows how to invoke training once ready export TRAIN_CMD="source '${DATA_DIR}/.venv/bin/activate' && train_wake_word --data-dir='${DATA_DIR}'" echo "Launching uvicorn on ${HOST}:${PORT}" cd "${ROOTDIR}" -exec "${VENV_DIR}/bin/uvicorn" recorder_server:app --host "${HOST}" --port "${PORT}" \ No newline at end of file +exec "${VENV_DIR}/bin/uvicorn" trainer_server:app --host "${HOST}" --port "${PORT}" diff --git a/static/index.html b/static/index.html index 4b57012..d0a668a 100644 --- a/static/index.html +++ b/static/index.html @@ -3,71 +3,71 @@ - microWakeWord Recorder + microWakeWord Personal Samples @@ -163,656 +340,584 @@
-

🎙️ microWakeWord Personal Recorder

-

Enter a wake word, test TTS pronunciation, then record takes. Recording starts when you speak and stops after silence.

+

microWakeWord Personal Samples

+

Start a session, upload your own recorded voice samples, and the app will validate or convert them into the training format used by the existing pipeline.

-
+
- + No session
-
+
- - - Speaker: -
- -
- Advanced (if it’s too sensitive / not sensitive enough) -
- - - -
-
-
-
- - - +
+
+
+

Optional Personal Samples

+

Personal samples are optional. You can train with TTS only, or upload your own audio here and it will be saved into personal_samples/ as 16 kHz mono 16-bit PCM WAV.

+
Idle
-
-
- Mic level +
+
+ Select one or many files +

WAV, MP3, M4A, FLAC, OGG, AAC, OPUS, and WEBM are all fine when ffmpeg is available. Files already in the correct format are kept as-is.

+
+ + No files selected
-
+
+ + +
-

- Speaker: - / - - Waiting -

+
+
+ No upload in progress + 0% +
+
+
+
+
When you upload, each file is checked and converted only if needed before it is written into personal_samples/.
+
-

- Take: 0 / 10 - Not recording -

+
+
+ Uploaded + 0 +
+
+ Training Format + 16 kHz / mono / 16-bit WAV +
+
+
+ +
+
+ Not started + +
+
+
-
- -

Training log

-
(no training started)
+ - \ No newline at end of file + diff --git a/recorder_server.py b/trainer_server.py similarity index 57% rename from recorder_server.py rename to trainer_server.py index ffc6f91..4e9eb70 100644 --- a/recorder_server.py +++ b/trainer_server.py @@ -1,13 +1,18 @@ -# recorder_server.py +# trainer_server.py +import io import os import re import json import shutil import subprocess +import tempfile import threading +import time +import wave from datetime import datetime from pathlib import Path from typing import Dict, Any, List, Optional, Tuple +from urllib.request import Request, urlopen from fastapi import FastAPI, UploadFile, File, Form from fastapi.responses import HTMLResponse, JSONResponse @@ -26,6 +31,17 @@ PERSONAL_DIR = Path(os.environ.get("PERSONAL_DIR", str(DATA_DIR / "personal_samp # CLI folder inside repo CLI_DIR = Path(os.environ.get("CLI_DIR", str(ROOT_DIR / "cli"))).resolve() +PIPER_ROOT = DATA_DIR / "tools" / "piper-sample-generator" +PIPER_VOICES_DIR = PIPER_ROOT / "voices" +PIPER_VOICES_INDEX_URL = os.environ.get( + "PIPER_VOICES_INDEX_URL", + "https://huggingface.co/rhasspy/piper-voices/raw/main/voices.json", +) +PIPER_VOICES_ROOT_URL = os.environ.get( + "PIPER_VOICES_ROOT_URL", + "https://huggingface.co/rhasspy/piper-voices/resolve/main", +) +PIPER_CATALOG_CACHE_TTL_SECONDS = int(os.environ.get("PIPER_CATALOG_CACHE_TTL_SECONDS", "900")) DATASET_CLEANUP_ARCHIVES = os.environ.get("REC_DATASET_CLEANUP_ARCHIVES", "false").lower() in ("1", "true", "yes", "y") DATASET_CLEANUP_INTERMEDIATE = os.environ.get("REC_DATASET_CLEANUP_INTERMEDIATE_FILES", "false").lower() in ("1", "true", "yes", "y") @@ -39,13 +55,16 @@ DEFAULT_LANGUAGE = os.environ.get("MWW_LANGUAGE", "en") TAKES_PER_SPEAKER_DEFAULT = int(os.environ.get("REC_TAKES_PER_SPEAKER", "10")) SPEAKERS_TOTAL_DEFAULT = int(os.environ.get("REC_SPEAKERS_TOTAL", "1")) +TARGET_SAMPLE_RATE = 16000 +TARGET_CHANNELS = 1 +TARGET_SAMPLE_WIDTH_BYTES = 2 # Tail lines shown to UI TRAIN_LOG_TAIL_LINES = int(os.environ.get("REC_TRAIN_LOG_TAIL_LINES", "400")) # Safety cap for reads (bytes) to avoid giant file reads TRAIN_LOG_MAX_BYTES = int(os.environ.get("REC_TRAIN_LOG_MAX_BYTES", str(512 * 1024))) # 512KB -app = FastAPI(title="microWakeWord Personal Recorder") +app = FastAPI(title="microWakeWord Personal Samples") STATIC_DIR.mkdir(parents=True, exist_ok=True) app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static") @@ -84,6 +103,12 @@ STATE: Dict[str, Any] = { } STATE_LOCK = threading.Lock() +SAMPLES_LOCK = threading.Lock() +PIPER_CATALOG_LOCK = threading.Lock() +PIPER_CATALOG_CACHE: Dict[str, Any] = { + "fetched_at": 0.0, + "entries": None, +} def _reset_personal_samples_dir(): @@ -95,6 +120,362 @@ def _reset_personal_samples_dir(): pass + +def _list_personal_samples() -> List[str]: + PERSONAL_DIR.mkdir(parents=True, exist_ok=True) + return sorted(p.name for p in PERSONAL_DIR.glob("*.wav")) + + +def _sync_personal_samples_state() -> List[str]: + takes = _list_personal_samples() + with STATE_LOCK: + STATE["takes"] = takes + STATE["takes_received"] = len(takes) + return takes + + +def _registered_language_family(language: Dict[str, Any]) -> str: + family = str(language.get("family") or "").strip().lower() + if family: + return family + code = str(language.get("code") or "").strip() + return code.split("_", 1)[0].lower() if code else "" + + +def _register_language( + languages: Dict[str, Dict[str, Any]], + *, + family: str, + name: str, + region: str = "", + count: int = 1, +): + if not family: + return + entry = languages.setdefault( + family, + { + "code": family, + "label": f"{name} ({family})", + "name": name, + "voice_count": 0, + "regions": [], + }, + ) + entry["voice_count"] += count + if region and region not in entry["regions"]: + entry["regions"].append(region) + + +def _fetch_piper_catalog() -> Optional[Dict[str, Any]]: + req = Request( + PIPER_VOICES_INDEX_URL, + headers={"User-Agent": "microWakeWord-Trainer/1.0"}, + ) + with urlopen(req, timeout=15) as resp: + data = json.loads(resp.read().decode("utf-8")) + return data if isinstance(data, dict) else None + + +def _load_piper_catalog() -> Optional[Dict[str, Any]]: + now = time.time() + with PIPER_CATALOG_LOCK: + cached = PIPER_CATALOG_CACHE.get("entries") + fetched_at = float(PIPER_CATALOG_CACHE.get("fetched_at") or 0.0) + if cached is not None and (now - fetched_at) < PIPER_CATALOG_CACHE_TTL_SECONDS: + return cached + + try: + fresh = _fetch_piper_catalog() + except Exception: + fresh = None + + with PIPER_CATALOG_LOCK: + if fresh is not None: + PIPER_CATALOG_CACHE["entries"] = fresh + PIPER_CATALOG_CACHE["fetched_at"] = now + return fresh + if PIPER_CATALOG_CACHE.get("entries") is None: + PIPER_CATALOG_CACHE["entries"] = {} + PIPER_CATALOG_CACHE["fetched_at"] = now + return PIPER_CATALOG_CACHE.get("entries") + + +def _available_languages() -> List[Dict[str, Any]]: + languages: Dict[str, Dict[str, Any]] = { + "en": { + "code": "en", + "label": "English (en)", + "name": "English", + "voice_count": 1, + "regions": [], + } + } + + if PIPER_VOICES_DIR.exists(): + for meta_path in sorted(PIPER_VOICES_DIR.glob("*.onnx.json")): + try: + data = json.loads(meta_path.read_text(encoding="utf-8")) + except Exception: + continue + + language = data.get("language") or {} + family = _registered_language_family(language) + if not family or family == "en": + continue + + name = str(language.get("name_english") or language.get("name_native") or family.upper()).strip() + region = str(language.get("country_english") or language.get("region") or "").strip() + _register_language(languages, family=family, name=name, region=region, count=1) + + catalog = _load_piper_catalog() or {} + for entry in catalog.values(): + if not isinstance(entry, dict): + continue + language = entry.get("language") or {} + family = _registered_language_family(language) + if not family or family == "en": + continue + name = str(language.get("name_english") or language.get("name_native") or family.upper()).strip() + region = str(language.get("country_english") or language.get("region") or "").strip() + _register_language(languages, family=family, name=name, region=region, count=0) + + ordered = [languages["en"]] + ordered.extend( + sorted( + (entry for code, entry in languages.items() if code != "en"), + key=lambda entry: (entry["name"].lower(), entry["code"]), + ) + ) + return ordered + + +def _normalize_language(language: str | None) -> str: + requested = (language or DEFAULT_LANGUAGE).strip().lower() or DEFAULT_LANGUAGE + available_codes = {item["code"] for item in _available_languages()} + if requested in available_codes: + return requested + if DEFAULT_LANGUAGE in available_codes: + return DEFAULT_LANGUAGE + return "en" + + +def _catalog_voice_files(language_family: str) -> List[Tuple[str, str]]: + if not language_family or language_family == "en": + return [] + + downloads: Dict[str, str] = {} + catalog = _load_piper_catalog() or {} + for entry in catalog.values(): + if not isinstance(entry, dict): + continue + language = entry.get("language") or {} + family = _registered_language_family(language) + if family != language_family: + continue + files = entry.get("files") or {} + for rel_path in files.keys(): + if not isinstance(rel_path, str): + continue + if not (rel_path.endswith(".onnx") or rel_path.endswith(".onnx.json")): + continue + downloads[Path(rel_path).name] = f"{PIPER_VOICES_ROOT_URL}/{rel_path}?download=true" + + return sorted(downloads.items(), key=lambda item: item[0]) + + +def _download_to_path(url: str, dest_path: Path): + dest_path.parent.mkdir(parents=True, exist_ok=True) + tmp_path = dest_path.with_suffix(dest_path.suffix + ".tmp") + req = Request(url, headers={"User-Agent": "microWakeWord-Trainer/1.0"}) + with urlopen(req, timeout=60) as resp, open(tmp_path, "wb") as out: + shutil.copyfileobj(resp, out) + tmp_path.replace(dest_path) + + +def _ensure_non_english_language_voices(language_family: str, log) -> Dict[str, int]: + downloads = _catalog_voice_files(language_family) + local_voices = sorted(PIPER_VOICES_DIR.glob(f"{language_family}_*.onnx")) if PIPER_VOICES_DIR.exists() else [] + if not downloads: + if local_voices: + log(f"===== Piper Voices ({language_family}) =====") + log(f"→ Using {len(local_voices)} installed voice(s) for language '{language_family}'") + return { + "downloaded_files": 0, + "existing_files": len(local_voices), + "voices": len(local_voices), + } + raise RuntimeError( + f"No Piper ONNX voices found for language '{language_family}' in the upstream catalog." + ) + + PIPER_VOICES_DIR.mkdir(parents=True, exist_ok=True) + + downloaded_files = 0 + existing_files = 0 + voice_names = sorted(name for name, _ in downloads if name.endswith(".onnx")) + + log(f"===== Piper Voices ({language_family}) =====") + log(f"→ Ensuring {len(voice_names)} voice(s) for language '{language_family}'") + + for file_name, url in downloads: + dest_path = PIPER_VOICES_DIR / file_name + if dest_path.exists() and dest_path.stat().st_size > 0: + existing_files += 1 + continue + log(f"→ Downloading {file_name}") + _download_to_path(url, dest_path) + downloaded_files += 1 + + log( + f"✓ Piper voices ready for '{language_family}' " + f"({downloaded_files} file(s) downloaded, {existing_files} already present)" + ) + return { + "downloaded_files": downloaded_files, + "existing_files": existing_files, + "voices": len(voice_names), + } + + +def _find_ffmpeg() -> Optional[str]: + candidates = [ + shutil.which("ffmpeg"), + "/usr/bin/ffmpeg", + "/usr/local/bin/ffmpeg", + "/opt/homebrew/bin/ffmpeg", + "/opt/homebrew/opt/ffmpeg@7/bin/ffmpeg", + "/opt/homebrew/opt/ffmpeg/bin/ffmpeg", + ] + for candidate in candidates: + if candidate and Path(candidate).exists(): + return candidate + return None + + +def _inspect_wav_bytes(data: bytes) -> Optional[Dict[str, Any]]: + try: + with wave.open(io.BytesIO(data), "rb") as wf: + frames = wf.getnframes() + rate = wf.getframerate() + duration = (frames / rate) if rate else 0.0 + return { + "container": "wav", + "sample_rate": rate, + "channels": wf.getnchannels(), + "sample_width_bits": wf.getsampwidth() * 8, + "compression": wf.getcomptype(), + "frames": frames, + "duration_s": round(duration, 3), + } + except Exception: + return None + + +def _is_target_wav(info: Optional[Dict[str, Any]]) -> bool: + return bool( + info + and info.get("container") == "wav" + and info.get("sample_rate") == TARGET_SAMPLE_RATE + and info.get("channels") == TARGET_CHANNELS + and info.get("sample_width_bits") == TARGET_SAMPLE_WIDTH_BYTES * 8 + and info.get("compression") == "NONE" + and info.get("frames", 0) > 0 + ) + + +def _next_personal_sample_name(original_name: str) -> str: + current = _list_personal_samples() + next_index = 1 + for name in current: + match = re.match(r"sample_(\d{4})", name) + if match: + next_index = max(next_index, int(match.group(1)) + 1) + + stem = safe_name(Path(original_name or "sample").stem) + suffix = f"_{stem[:32]}" if stem and stem != "wakeword" else "" + return f"sample_{next_index:04d}{suffix}.wav" + + +def _format_hint_from_filename(original_name: str) -> Dict[str, Any]: + suffix = (Path(original_name or "").suffix or "").lower().lstrip(".") + return { + "container": suffix or "unknown", + "sample_rate": None, + "channels": None, + "sample_width_bits": None, + "compression": None, + "frames": None, + "duration_s": None, + } + + +def _normalize_audio_to_target_wav(data: bytes, original_name: str) -> bytes: + ffmpeg = _find_ffmpeg() + if not ffmpeg: + raise RuntimeError( + "ffmpeg is required to convert uploads that are not already 16 kHz mono 16-bit PCM WAV." + ) + + suffix = (Path(original_name or "").suffix or ".audio") + with tempfile.TemporaryDirectory(prefix="mww_upload_") as tmpdir: + src_path = Path(tmpdir) / f"source{suffix}" + dst_path = Path(tmpdir) / "normalized.wav" + src_path.write_bytes(data) + + cmd = [ + ffmpeg, + "-y", + "-i", + str(src_path), + "-vn", + "-ac", + str(TARGET_CHANNELS), + "-ar", + str(TARGET_SAMPLE_RATE), + "-c:a", + "pcm_s16le", + str(dst_path), + ] + proc = subprocess.run(cmd, capture_output=True, text=True) + if proc.returncode != 0 or not dst_path.exists(): + err = (proc.stderr or proc.stdout or "ffmpeg conversion failed").strip() + raise RuntimeError(err.splitlines()[-1] if err else "ffmpeg conversion failed") + + return dst_path.read_bytes() + + +def _save_personal_sample(data: bytes, original_name: str, out_name: Optional[str] = None) -> Dict[str, Any]: + if not data: + raise ValueError("Empty or invalid audio file.") + + original_info = _inspect_wav_bytes(data) or _format_hint_from_filename(original_name) + normalized = _is_target_wav(original_info) + final_bytes = data if normalized else _normalize_audio_to_target_wav(data, original_name) + final_info = _inspect_wav_bytes(final_bytes) + + if not _is_target_wav(final_info): + raise ValueError("Uploaded audio could not be normalized to 16 kHz mono 16-bit PCM WAV.") + + with SAMPLES_LOCK: + PERSONAL_DIR.mkdir(parents=True, exist_ok=True) + final_name = out_name or _next_personal_sample_name(original_name) + out_path = PERSONAL_DIR / final_name + out_path.write_bytes(final_bytes) + + return { + "saved_as": final_name, + "converted": not normalized, + "original_name": original_name or final_name, + "detected_format": original_info, + "final_format": final_info, + "message": ( + "Converted to 16 kHz mono 16-bit PCM WAV" + if not normalized + else "Already in the correct 16 kHz mono 16-bit PCM WAV format" + ), + } + def _clear_training_log(): """ Truncate recorder_training.log for a fresh session. @@ -405,6 +786,8 @@ def _run_training_background(safe_word: str, allow_no_personal: bool): try: _ensure_training_venv(log_path) _ensure_training_datasets(log_path) + if language != "en": + _ensure_non_english_language_voices(language, _append_train_log) if wake_word_title: cmd_str = f"{TRAIN_CMD} --language='{language}' '{safe_word}' '{wake_word_title}'" @@ -474,7 +857,8 @@ def start_session(payload: Dict[str, Any]): speakers_total = int(payload.get("speakers_total") or SPEAKERS_TOTAL_DEFAULT) takes_per_speaker = int(payload.get("takes_per_speaker") or TAKES_PER_SPEAKER_DEFAULT) - language = (payload.get("language") or DEFAULT_LANGUAGE).strip().lower() + language = _normalize_language(payload.get("language")) + available_languages = _available_languages() speakers_total = max(1, min(10, speakers_total)) takes_per_speaker = max(1, min(50, takes_per_speaker)) @@ -485,10 +869,8 @@ def start_session(payload: Dict[str, Any]): STATE["language"] = language STATE["speakers_total"] = speakers_total STATE["takes_per_speaker"] = takes_per_speaker - STATE["takes_received"] = 0 - STATE["takes"] = [] - _reset_personal_samples_dir() + takes = _sync_personal_samples_state() # Always wipe log on start_session (even if same wakeword) _clear_training_log() @@ -501,6 +883,9 @@ def start_session(payload: Dict[str, Any]): "speakers_total": speakers_total, "takes_per_speaker": takes_per_speaker, "takes_total": speakers_total * takes_per_speaker, + "takes_received": len(takes), + "takes": takes, + "available_languages": available_languages, "personal_dir": str(PERSONAL_DIR), "data_dir": str(DATA_DIR), } @@ -508,17 +893,22 @@ def start_session(payload: Dict[str, Any]): @app.get("/api/session") def get_session(): + takes = _sync_personal_samples_state() + available_languages = _available_languages() with STATE_LOCK: + current_language = _normalize_language(STATE["language"]) + STATE["language"] = current_language return { "ok": True, "raw_phrase": STATE["raw_phrase"], "safe_word": STATE["safe_word"], - "language": STATE["language"], + "language": current_language, "speakers_total": STATE["speakers_total"], "takes_per_speaker": STATE["takes_per_speaker"], - "takes_received": STATE["takes_received"], - "takes": list(STATE["takes"]), + "takes_received": len(takes), + "takes": list(takes), "training": dict(STATE["training"]), + "available_languages": available_languages, } @@ -542,23 +932,34 @@ async def upload_take( if take_index < 1 or take_index > takes_per_speaker: return JSONResponse({"ok": False, "error": f"take_index must be 1..{takes_per_speaker}"}, status_code=400) - PERSONAL_DIR.mkdir(parents=True, exist_ok=True) - out_name = f"speaker{speaker_index:02d}_take{take_index:02d}.wav" - out_path = PERSONAL_DIR / out_name data = await file.read() - if not data or len(data) < 44: - return JSONResponse({"ok": False, "error": "Empty/invalid file"}, status_code=400) + try: + result = _save_personal_sample(data, file.filename or out_name, out_name=out_name) + except Exception as e: + return JSONResponse({"ok": False, "error": str(e)}, status_code=400) - out_path.write_bytes(data) + takes = _sync_personal_samples_state() + return {"ok": True, **result, "takes_received": len(takes)} + +@app.post("/api/upload_personal_sample") +async def upload_personal_sample(file: UploadFile = File(...)): with STATE_LOCK: - if out_name not in STATE["takes"]: - STATE["takes"].append(out_name) - STATE["takes_received"] = len(STATE["takes"]) + safe_word = STATE["safe_word"] - return {"ok": True, "saved_as": out_name, "takes_received": STATE["takes_received"]} + if not safe_word: + return JSONResponse({"ok": False, "error": "No active session. Call /api/start_session first."}, status_code=400) + + data = await file.read() + try: + result = _save_personal_sample(data, file.filename or "sample") + except Exception as e: + return JSONResponse({"ok": False, "error": str(e)}, status_code=400) + + takes = _sync_personal_samples_state() + return {"ok": True, **result, "takes_received": len(takes)} @app.post("/api/train") @@ -581,27 +982,13 @@ def train_now(payload: Dict[str, Any] = None): if not safe_word: return JSONResponse({"ok": False, "error": "No active session"}, status_code=400) - min_required = max(1, min(3, takes_total)) - if takes_received == 0 and not allow_no_personal: return JSONResponse( { "ok": False, - "error": f"No personal voice samples recorded (0/{takes_total}).", + "error": "No personal voice samples uploaded yet.", "code": "NO_PERSONAL_SAMPLES", - "message": "You can train without personal voices, or record samples first.", - "takes_total": takes_total, - }, - status_code=400, - ) - - if 0 < takes_received < min_required: - return JSONResponse( - { - "ok": False, - "error": f"Not enough takes yet ({takes_received}/{takes_total}).", - "code": "NOT_ENOUGH_TAKES", - "min_required": min_required, + "message": "You can train without personal voices, or upload samples first.", "takes_total": takes_total, }, status_code=400, @@ -614,7 +1001,7 @@ def train_now(payload: Dict[str, Any] = None): "ok": True, "started": True, "safe_word": safe_word, - "personal_samples_used": takes_received >= min_required, + "personal_samples_used": takes_received > 0, "allow_no_personal": allow_no_personal, } @@ -656,13 +1043,12 @@ def train_status(): tr["log_text"] = "\n".join(new_lines) # ONLY new lines tr["log_tail_preview"] = "\n".join(full_tail) # optional: handy for debugging + tr["log_lines"] = full_tail return {"ok": True, "training": tr} @app.post("/api/reset_recordings") def reset_recordings(): _reset_personal_samples_dir() - with STATE_LOCK: - STATE["takes_received"] = 0 - STATE["takes"] = [] - return {"ok": True} \ No newline at end of file + takes = _sync_personal_samples_state() + return {"ok": True, "takes_received": len(takes), "takes": takes}