remove recorder/upload samples

This commit is contained in:
MasterPhooey
2026-04-14 22:55:49 -05:00
parent 15b2fe9c9a
commit 45583027a4
6 changed files with 1255 additions and 647 deletions

View File

@@ -0,0 +1,112 @@
#!/usr/bin/env python3
import argparse
import queue
import subprocess
import sys
import threading
from pathlib import Path
def _model_args(generator_args):
values = []
for idx, arg in enumerate(generator_args):
if arg == "--model" and idx + 1 < len(generator_args):
values.append(generator_args[idx + 1])
return values
def _is_onnx_run(generator_args):
    """Tell whether any ``--model`` value in the arguments ends with ``.onnx``."""
    for model in _model_args(generator_args):
        if str(model).endswith(".onnx"):
            return True
    return False
def _format_line(line):
if line.startswith("DEBUG:piper.voice:"):
return None
for prefix in ("DEBUG:__main__:", "INFO:__main__:", "WARNING:__main__:", "ERROR:__main__:"):
if line.startswith(prefix):
return " " + line[len(prefix):].strip()
return line
def _reader(stdout, sink):
try:
for raw in stdout:
sink.put(raw.rstrip("\n"))
finally:
sink.put(None)
def _progress_step(max_samples):
if max_samples <= 20:
return 1
if max_samples <= 100:
return 5
return 10
def main():
    """Run the sample generator as a child process and relay its output.

    Command-line arguments:
        --generator     Path of the generator script, run with this interpreter.
        --output-dir    Directory whose ``*.wav`` files are counted for progress.
        --max-samples   Target sample count shown in progress messages.
        generator_args  Everything else (after an optional ``--``), passed
                        through to the generator unchanged.

    The child's stdout and stderr are merged, filtered through
    ``_format_line`` and printed. When a ``--model`` argument names an
    ``.onnx`` file, "Generated N/M samples..." lines are printed as wav
    files appear in the output directory.

    Returns:
        The generator's exit code.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--generator", required=True)
    parser.add_argument("--output-dir", required=True)
    parser.add_argument("--max-samples", required=True, type=int)
    parser.add_argument("generator_args", nargs=argparse.REMAINDER)
    args = parser.parse_args()

    generator_args = list(args.generator_args)
    # argparse.REMAINDER keeps the "--" separator; drop it before forwarding.
    if generator_args and generator_args[0] == "--":
        generator_args = generator_args[1:]

    cmd = [sys.executable, args.generator, *generator_args]
    proc = subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        bufsize=1,
    )
    assert proc.stdout is not None

    line_queue = queue.Queue()
    reader = threading.Thread(target=_reader, args=(proc.stdout, line_queue), daemon=True)
    reader.start()

    output_dir = Path(args.output_dir)
    use_sample_progress = _is_onnx_run(generator_args)
    step = _progress_step(args.max_samples)
    last_reported = 0
    # True once the reader thread has queued its None end-of-stream marker.
    eof_seen = False

    def _emit(raw_line):
        formatted = _format_line(raw_line)
        if formatted:
            print(formatted, flush=True)

    def _count_wavs():
        return sum(1 for _ in output_dir.glob("*.wav"))

    while True:
        if eof_seen:
            # Output is finished; keep polling (and reporting progress)
            # until the child itself exits.
            try:
                proc.wait(timeout=0.2)
            except subprocess.TimeoutExpired:
                pass
            else:
                break
        else:
            try:
                line = line_queue.get(timeout=0.2)
            except queue.Empty:
                # A timeout is not end-of-stream. Stop only once the child
                # has exited; leftover output is drained after the loop.
                if proc.poll() is not None:
                    break
            else:
                if line is None:
                    eof_seen = True
                else:
                    _emit(line)

        if use_sample_progress:
            current = _count_wavs()
            should_report = current > last_reported and (
                current >= args.max_samples
                or current - last_reported >= step
            )
            if should_report:
                print(f" Generated {current}/{args.max_samples} samples...", flush=True)
                last_reported = current

    rc = proc.wait()

    # The child can exit before the reader has queued its final lines. Give
    # the reader a bounded moment to finish (a grandchild holding the pipe
    # open must not hang us), then print whatever is still queued.
    if not eof_seen:
        reader.join(timeout=1.0)
    while True:
        try:
            line = line_queue.get_nowait()
        except queue.Empty:
            break
        if line is not None:
            _emit(line)

    final_count = _count_wavs() if use_sample_progress else 0
    if use_sample_progress and final_count > last_reported:
        print(f" Generated {final_count}/{args.max_samples} samples...", flush=True)
    return rc
# Script entry point: the value returned by main() becomes the exit status.
if __name__ == "__main__":
    raise SystemExit(main())

View File

@@ -138,11 +138,16 @@ export GRPC_VERBOSITY=ERROR
echo " Generating samples"
rm -rf "${SAMPLES_DIR}" || :
mkdir -p "${SAMPLES_DIR}" || :
"${PSG}/generate_samples.py" "${WAKE_WORD}" \
python "${PROGDIR}/run_generator_with_progress.py" \
--generator "${PSG}/generate_samples.py" \
--output-dir "${SAMPLES_DIR}" \
--max-samples ${SAMPLES} \
-- \
"${WAKE_WORD}" \
"${MODEL_ARGS[@]}" \
--max-samples ${SAMPLES} \
--batch-size ${BATCH_SIZE} \
--output-dir "${SAMPLES_DIR}" 2>&1 | sed -r -e "s/(DEBUG|INFO):__main__:/ /g"
--output-dir "${SAMPLES_DIR}"
generated_files=$(find "${SAMPLES_DIR}" -name '*.wav' | wc -l)
if [ "${generated_files}" -ne "${SAMPLES}" ] ; then

View File

@@ -10,7 +10,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
&& rm -rf /var/lib/apt/lists/* \
&& mkdir -p /data
# Recorder port
# Trainer UI port
EXPOSE 8789
# Script root
@@ -23,7 +23,7 @@ COPY --chown=root:root --chmod=0755 .bashrc /root/
COPY --chown=root:root --chmod=0755 \
train_wake_word \
run_recorder.sh \
recorder_server.py \
trainer_server.py \
requirements.txt \
/root/mww-scripts/
@@ -33,8 +33,8 @@ COPY --chown=root:root cli/ /root/mww-scripts/cli/
# Make all CLI scripts executable (avoids "Permission denied")
RUN chmod -R a+x /root/mww-scripts/cli
# Static UI for recorder
# Static UI for trainer
COPY --chown=root:root --chmod=0644 static/index.html /root/mww-scripts/static/index.html
# recorder server
# trainer server
CMD ["/bin/bash", "-lc", "/root/mww-scripts/run_recorder.sh"]

View File

@@ -18,7 +18,7 @@ FASTAPI_VERSION="${REC_FASTAPI_VERSION:-0.115.6}"
UVICORN_VERSION="${REC_UVICORN_VERSION:-0.30.6}"
PY_MULTIPART_VERSION="${REC_PY_MULTIPART_VERSION:-0.0.9}"
echo "microWakeWord Recorder (Docker)"
echo "microWakeWord Trainer UI (Docker)"
echo "-> ROOTDIR: ${ROOTDIR}"
echo "-> DATA_DIR: ${DATA_DIR}"
echo "-> URL: http://localhost:${PORT}/"
@@ -26,10 +26,10 @@ echo "-> URL: http://localhost:${PORT}/"
mkdir -p "${DATA_DIR}"
# -----------------------------
# Recorder venv (separate)
# Trainer UI venv (separate)
# -----------------------------
if [[ ! -x "${PY}" ]]; then
echo "Creating recorder venv: ${VENV_DIR}"
echo "Creating trainer UI venv: ${VENV_DIR}"
python3 -m venv "${VENV_DIR}"
fi
@@ -37,7 +37,7 @@ fi
source "${VENV_DIR}/bin/activate"
if [[ ! -f "${PIN_FILE}" ]]; then
echo "Installing pinned recorder deps"
echo "Installing pinned trainer UI deps"
${PIP} install -U pip setuptools wheel
${PIP} install \
"fastapi==${FASTAPI_VERSION}" \
@@ -45,20 +45,20 @@ if [[ ! -f "${PIN_FILE}" ]]; then
"python-multipart==${PY_MULTIPART_VERSION}"
touch "${PIN_FILE}"
else
echo "Reusing existing recorder venv (no upgrades)"
echo "Reusing existing trainer UI venv (no upgrades)"
fi
# -----------------------------
# Recorder server env
# Trainer server env
# -----------------------------
export DATA_DIR="${DATA_DIR}"
export STATIC_DIR="${ROOTDIR}/static"
export PERSONAL_DIR="${DATA_DIR}/personal_samples"
# IMPORTANT: leave training venv creation to /api/train inside recorder_server.py
# IMPORTANT: leave training venv creation to /api/train inside trainer_server.py
# but still set TRAIN_CMD so the server knows how to invoke training once ready
export TRAIN_CMD="source '${DATA_DIR}/.venv/bin/activate' && train_wake_word --data-dir='${DATA_DIR}'"
echo "Launching uvicorn on ${HOST}:${PORT}"
cd "${ROOTDIR}"
exec "${VENV_DIR}/bin/uvicorn" recorder_server:app --host "${HOST}" --port "${PORT}"
exec "${VENV_DIR}/bin/uvicorn" trainer_server:app --host "${HOST}" --port "${PORT}"

File diff suppressed because it is too large Load Diff

View File

@@ -1,13 +1,18 @@
# recorder_server.py
# trainer_server.py
import io
import os
import re
import json
import shutil
import subprocess
import tempfile
import threading
import time
import wave
from datetime import datetime
from pathlib import Path
from typing import Dict, Any, List, Optional, Tuple
from urllib.request import Request, urlopen
from fastapi import FastAPI, UploadFile, File, Form
from fastapi.responses import HTMLResponse, JSONResponse
@@ -26,6 +31,17 @@ PERSONAL_DIR = Path(os.environ.get("PERSONAL_DIR", str(DATA_DIR / "personal_samp
# CLI folder inside repo
CLI_DIR = Path(os.environ.get("CLI_DIR", str(ROOT_DIR / "cli"))).resolve()
PIPER_ROOT = DATA_DIR / "tools" / "piper-sample-generator"
PIPER_VOICES_DIR = PIPER_ROOT / "voices"
PIPER_VOICES_INDEX_URL = os.environ.get(
"PIPER_VOICES_INDEX_URL",
"https://huggingface.co/rhasspy/piper-voices/raw/main/voices.json",
)
PIPER_VOICES_ROOT_URL = os.environ.get(
"PIPER_VOICES_ROOT_URL",
"https://huggingface.co/rhasspy/piper-voices/resolve/main",
)
PIPER_CATALOG_CACHE_TTL_SECONDS = int(os.environ.get("PIPER_CATALOG_CACHE_TTL_SECONDS", "900"))
DATASET_CLEANUP_ARCHIVES = os.environ.get("REC_DATASET_CLEANUP_ARCHIVES", "false").lower() in ("1", "true", "yes", "y")
DATASET_CLEANUP_INTERMEDIATE = os.environ.get("REC_DATASET_CLEANUP_INTERMEDIATE_FILES", "false").lower() in ("1", "true", "yes", "y")
@@ -39,13 +55,16 @@ DEFAULT_LANGUAGE = os.environ.get("MWW_LANGUAGE", "en")
TAKES_PER_SPEAKER_DEFAULT = int(os.environ.get("REC_TAKES_PER_SPEAKER", "10"))
SPEAKERS_TOTAL_DEFAULT = int(os.environ.get("REC_SPEAKERS_TOTAL", "1"))
TARGET_SAMPLE_RATE = 16000
TARGET_CHANNELS = 1
TARGET_SAMPLE_WIDTH_BYTES = 2
# Tail lines shown to UI
TRAIN_LOG_TAIL_LINES = int(os.environ.get("REC_TRAIN_LOG_TAIL_LINES", "400"))
# Safety cap for reads (bytes) to avoid giant file reads
TRAIN_LOG_MAX_BYTES = int(os.environ.get("REC_TRAIN_LOG_MAX_BYTES", str(512 * 1024))) # 512KB
app = FastAPI(title="microWakeWord Personal Recorder")
app = FastAPI(title="microWakeWord Personal Samples")
STATIC_DIR.mkdir(parents=True, exist_ok=True)
app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")
@@ -84,6 +103,12 @@ STATE: Dict[str, Any] = {
}
STATE_LOCK = threading.Lock()
SAMPLES_LOCK = threading.Lock()
PIPER_CATALOG_LOCK = threading.Lock()
PIPER_CATALOG_CACHE: Dict[str, Any] = {
"fetched_at": 0.0,
"entries": None,
}
def _reset_personal_samples_dir():
@@ -95,6 +120,362 @@ def _reset_personal_samples_dir():
pass
def _list_personal_samples() -> List[str]:
    """Return the sorted file names of all ``*.wav`` files in PERSONAL_DIR.

    The directory is created first if it does not exist.
    """
    PERSONAL_DIR.mkdir(parents=True, exist_ok=True)
    names = [wav_path.name for wav_path in PERSONAL_DIR.glob("*.wav")]
    names.sort()
    return names
def _sync_personal_samples_state() -> List[str]:
    """Refresh STATE's take list and count from the files on disk.

    Returns the list of sample file names that was stored.
    """
    on_disk = _list_personal_samples()
    with STATE_LOCK:
        STATE.update(takes=on_disk, takes_received=len(on_disk))
    return on_disk
def _registered_language_family(language: Dict[str, Any]) -> str:
family = str(language.get("family") or "").strip().lower()
if family:
return family
code = str(language.get("code") or "").strip()
return code.split("_", 1)[0].lower() if code else ""
def _register_language(
languages: Dict[str, Dict[str, Any]],
*,
family: str,
name: str,
region: str = "",
count: int = 1,
):
if not family:
return
entry = languages.setdefault(
family,
{
"code": family,
"label": f"{name} ({family})",
"name": name,
"voice_count": 0,
"regions": [],
},
)
entry["voice_count"] += count
if region and region not in entry["regions"]:
entry["regions"].append(region)
def _fetch_piper_catalog() -> Optional[Dict[str, Any]]:
    """Download and parse the voice index at PIPER_VOICES_INDEX_URL.

    Returns the parsed JSON when it is a dict, otherwise ``None``.
    Network and parse errors propagate to the caller.
    """
    request = Request(
        PIPER_VOICES_INDEX_URL,
        headers={"User-Agent": "microWakeWord-Trainer/1.0"},
    )
    with urlopen(request, timeout=15) as response:
        payload = response.read()
        parsed = json.loads(payload.decode("utf-8"))
    if isinstance(parsed, dict):
        return parsed
    return None
def _load_piper_catalog() -> Optional[Dict[str, Any]]:
    """Return the Piper voice catalog, refreshing the cache when it expires.

    A cached catalog younger than PIPER_CATALOG_CACHE_TTL_SECONDS is
    returned directly. Otherwise the catalog is fetched (outside the lock,
    so other callers are not blocked on the network) and cached.

    If the fetch fails, the previously cached catalog is returned, or an
    empty dict when nothing was ever cached. In both cases the attempt
    time is recorded so the next fetch is not tried until the TTL elapses;
    otherwise every call would repeat the blocking request while the
    upstream index is unreachable.
    """
    now = time.time()
    with PIPER_CATALOG_LOCK:
        cached = PIPER_CATALOG_CACHE.get("entries")
        fetched_at = float(PIPER_CATALOG_CACHE.get("fetched_at") or 0.0)
        if cached is not None and (now - fetched_at) < PIPER_CATALOG_CACHE_TTL_SECONDS:
            return cached

    try:
        fresh = _fetch_piper_catalog()
    except Exception:
        # Best-effort: the catalog is optional, so fall back to the cache.
        fresh = None

    with PIPER_CATALOG_LOCK:
        if fresh is not None:
            PIPER_CATALOG_CACHE["entries"] = fresh
            PIPER_CATALOG_CACHE["fetched_at"] = now
            return fresh

        if PIPER_CATALOG_CACHE.get("entries") is None:
            PIPER_CATALOG_CACHE["entries"] = {}
        # Stamp the failed attempt for stale and empty caches alike, so
        # retries are throttled by the TTL.
        PIPER_CATALOG_CACHE["fetched_at"] = now
        return PIPER_CATALOG_CACHE["entries"]
def _available_languages() -> List[Dict[str, Any]]:
    """List the selectable languages, English first.

    English is always present. Other languages come from the
    ``*.onnx.json`` metadata files in PIPER_VOICES_DIR (each adds one to
    the language's voice count) and from the Piper catalog (adds the
    language and its regions without changing the count). Non-English
    entries are ordered by lower-cased name, then code.
    """
    languages: Dict[str, Dict[str, Any]] = {
        "en": {
            "code": "en",
            "label": "English (en)",
            "name": "English",
            "voice_count": 1,
            "regions": [],
        }
    }

    def _add_non_english(language: Dict[str, Any], count: int) -> None:
        # Register one language record, skipping English and unknown families.
        family = _registered_language_family(language)
        if not family or family == "en":
            return
        name = str(language.get("name_english") or language.get("name_native") or family.upper()).strip()
        region = str(language.get("country_english") or language.get("region") or "").strip()
        _register_language(languages, family=family, name=name, region=region, count=count)

    if PIPER_VOICES_DIR.exists():
        for meta_path in sorted(PIPER_VOICES_DIR.glob("*.onnx.json")):
            try:
                data = json.loads(meta_path.read_text(encoding="utf-8"))
            except Exception:
                # Unreadable metadata: ignore this voice.
                continue
            _add_non_english(data.get("language") or {}, 1)

    catalog = _load_piper_catalog() or {}
    for entry in catalog.values():
        if isinstance(entry, dict):
            _add_non_english(entry.get("language") or {}, 0)

    others = [entry for code, entry in languages.items() if code != "en"]
    others.sort(key=lambda entry: (entry["name"].lower(), entry["code"]))
    return [languages["en"], *others]
def _normalize_language(language: str | None) -> str:
    """Map a requested language code onto one that is available.

    The request is stripped and lower-cased (DEFAULT_LANGUAGE is used when
    it is empty). Returns the request if available, else DEFAULT_LANGUAGE
    if available, else ``"en"``.
    """
    requested = (language or DEFAULT_LANGUAGE).strip().lower() or DEFAULT_LANGUAGE
    known_codes = {item["code"] for item in _available_languages()}
    for candidate in (requested, DEFAULT_LANGUAGE):
        if candidate in known_codes:
            return candidate
    return "en"
def _catalog_voice_files(language_family: str) -> List[Tuple[str, str]]:
    """List the catalog's ONNX files for one non-English language family.

    Returns ``(file_name, download_url)`` pairs sorted by file name,
    covering ``.onnx`` and ``.onnx.json`` files only. English and empty
    families yield an empty list.
    """
    if not language_family or language_family == "en":
        return []

    url_by_name: Dict[str, str] = {}
    for entry in (_load_piper_catalog() or {}).values():
        if not isinstance(entry, dict):
            continue
        if _registered_language_family(entry.get("language") or {}) != language_family:
            continue
        for rel_path in (entry.get("files") or {}).keys():
            wanted = isinstance(rel_path, str) and rel_path.endswith((".onnx", ".onnx.json"))
            if wanted:
                url_by_name[Path(rel_path).name] = f"{PIPER_VOICES_ROOT_URL}/{rel_path}?download=true"

    return sorted(url_by_name.items(), key=lambda item: item[0])
def _download_to_path(url: str, dest_path: Path):
dest_path.parent.mkdir(parents=True, exist_ok=True)
tmp_path = dest_path.with_suffix(dest_path.suffix + ".tmp")
req = Request(url, headers={"User-Agent": "microWakeWord-Trainer/1.0"})
with urlopen(req, timeout=60) as resp, open(tmp_path, "wb") as out:
shutil.copyfileobj(resp, out)
tmp_path.replace(dest_path)
def _ensure_non_english_language_voices(language_family: str, log) -> Dict[str, int]:
    """Make sure the Piper voices for ``language_family`` are on disk.

    Files listed in the catalog that are missing or empty locally are
    downloaded into PIPER_VOICES_DIR. When the catalog lists nothing for
    the language, already-installed voices are used instead.

    Args:
        language_family: Language family code to prepare.
        log: Callable taking one message string.

    Returns:
        Counts under ``downloaded_files``, ``existing_files`` and ``voices``.

    Raises:
        RuntimeError: the catalog lists no files for the language and no
            matching voices are installed.
    """
    downloads = _catalog_voice_files(language_family)
    if PIPER_VOICES_DIR.exists():
        local_voices = sorted(PIPER_VOICES_DIR.glob(f"{language_family}_*.onnx"))
    else:
        local_voices = []

    if not downloads:
        if not local_voices:
            raise RuntimeError(
                f"No Piper ONNX voices found for language '{language_family}' in the upstream catalog."
            )
        installed = len(local_voices)
        log(f"===== Piper Voices ({language_family}) =====")
        log(f"→ Using {installed} installed voice(s) for language '{language_family}'")
        return {
            "downloaded_files": 0,
            "existing_files": installed,
            "voices": installed,
        }

    PIPER_VOICES_DIR.mkdir(parents=True, exist_ok=True)
    voice_total = sum(1 for name, _ in downloads if name.endswith(".onnx"))
    log(f"===== Piper Voices ({language_family}) =====")
    log(f"→ Ensuring {voice_total} voice(s) for language '{language_family}'")

    fetched = 0
    already_present = 0
    for file_name, url in downloads:
        target = PIPER_VOICES_DIR / file_name
        if target.exists() and target.stat().st_size > 0:
            already_present += 1
        else:
            log(f"→ Downloading {file_name}")
            _download_to_path(url, target)
            fetched += 1

    log(
        f"✓ Piper voices ready for '{language_family}' "
        f"({fetched} file(s) downloaded, {already_present} already present)"
    )
    return {
        "downloaded_files": fetched,
        "existing_files": already_present,
        "voices": voice_total,
    }
def _find_ffmpeg() -> Optional[str]:
candidates = [
shutil.which("ffmpeg"),
"/usr/bin/ffmpeg",
"/usr/local/bin/ffmpeg",
"/opt/homebrew/bin/ffmpeg",
"/opt/homebrew/opt/ffmpeg@7/bin/ffmpeg",
"/opt/homebrew/opt/ffmpeg/bin/ffmpeg",
]
for candidate in candidates:
if candidate and Path(candidate).exists():
return candidate
return None
def _inspect_wav_bytes(data: bytes) -> Optional[Dict[str, Any]]:
try:
with wave.open(io.BytesIO(data), "rb") as wf:
frames = wf.getnframes()
rate = wf.getframerate()
duration = (frames / rate) if rate else 0.0
return {
"container": "wav",
"sample_rate": rate,
"channels": wf.getnchannels(),
"sample_width_bits": wf.getsampwidth() * 8,
"compression": wf.getcomptype(),
"frames": frames,
"duration_s": round(duration, 3),
}
except Exception:
return None
def _is_target_wav(info: Optional[Dict[str, Any]]) -> bool:
    """Check whether ``info`` describes audio already in the target format.

    The target is an uncompressed WAV with TARGET_SAMPLE_RATE,
    TARGET_CHANNELS and TARGET_SAMPLE_WIDTH_BYTES, containing at least one
    frame. A missing or empty ``info`` is not a match.
    """
    if not info:
        return False

    required = (
        ("container", "wav"),
        ("sample_rate", TARGET_SAMPLE_RATE),
        ("channels", TARGET_CHANNELS),
        ("sample_width_bits", TARGET_SAMPLE_WIDTH_BYTES * 8),
        ("compression", "NONE"),
    )
    for field, wanted in required:
        if not info.get(field) == wanted:
            return False
    # Checked last: the frame count is only compared once the rest matches.
    return bool(info.get("frames", 0) > 0)
def _next_personal_sample_name(original_name: str) -> str:
    """Build the next ``sample_NNNN[_stem].wav`` name for an upload.

    The index is one more than the highest ``sample_NNNN`` prefix among
    the existing personal samples (starting at 1). The sanitized stem of
    the uploaded file name, cut to 32 characters, is appended unless it is
    empty or equals ``wakeword``.
    """
    highest = 0
    for existing in _list_personal_samples():
        found = re.match(r"sample_(\d{4})", existing)
        if found:
            highest = max(highest, int(found.group(1)))
    next_index = highest + 1

    stem = safe_name(Path(original_name or "sample").stem)
    if stem and stem != "wakeword":
        suffix = f"_{stem[:32]}"
    else:
        suffix = ""
    return f"sample_{next_index:04d}{suffix}.wav"
def _format_hint_from_filename(original_name: str) -> Dict[str, Any]:
suffix = (Path(original_name or "").suffix or "").lower().lstrip(".")
return {
"container": suffix or "unknown",
"sample_rate": None,
"channels": None,
"sample_width_bits": None,
"compression": None,
"frames": None,
"duration_s": None,
}
def _normalize_audio_to_target_wav(data: bytes, original_name: str) -> bytes:
    """Convert uploaded audio to the target WAV format using ffmpeg.

    The bytes are written to a temporary directory, converted to
    ``pcm_s16le`` with TARGET_CHANNELS and TARGET_SAMPLE_RATE, and the
    converted file's bytes are returned.

    Raises:
        RuntimeError: ffmpeg is not available, or the conversion fails
            (the message is the last line of ffmpeg's output).
    """
    ffmpeg = _find_ffmpeg()
    if not ffmpeg:
        raise RuntimeError(
            "ffmpeg is required to convert uploads that are not already 16 kHz mono 16-bit PCM WAV."
        )

    # Keep the upload's extension on the temporary source file.
    extension = Path(original_name or "").suffix or ".audio"
    with tempfile.TemporaryDirectory(prefix="mww_upload_") as tmpdir:
        workdir = Path(tmpdir)
        src_path = workdir / f"source{extension}"
        dst_path = workdir / "normalized.wav"
        src_path.write_bytes(data)

        command = [ffmpeg, "-y", "-i", str(src_path), "-vn"]
        command += ["-ac", str(TARGET_CHANNELS)]
        command += ["-ar", str(TARGET_SAMPLE_RATE)]
        command += ["-c:a", "pcm_s16le", str(dst_path)]

        result = subprocess.run(command, capture_output=True, text=True)
        if result.returncode != 0 or not dst_path.exists():
            detail = (result.stderr or result.stdout or "ffmpeg conversion failed").strip()
            if detail:
                raise RuntimeError(detail.splitlines()[-1])
            raise RuntimeError("ffmpeg conversion failed")
        return dst_path.read_bytes()
def _save_personal_sample(data: bytes, original_name: str, out_name: Optional[str] = None) -> Dict[str, Any]:
    """Store an uploaded sample in PERSONAL_DIR in the target WAV format.

    Audio already in the target format is written unchanged; anything else
    is converted first. The file is saved as ``out_name`` when given,
    otherwise under the next generated sample name.

    Returns:
        A dict with the saved name, whether a conversion happened, the
        original name, the detected and final formats, and a message.

    Raises:
        ValueError: ``data`` is empty, or the result is still not in the
            target format after conversion.
    """
    if not data:
        raise ValueError("Empty or invalid audio file.")

    original_info = _inspect_wav_bytes(data) or _format_hint_from_filename(original_name)
    already_target = _is_target_wav(original_info)
    if already_target:
        final_bytes = data
    else:
        final_bytes = _normalize_audio_to_target_wav(data, original_name)

    final_info = _inspect_wav_bytes(final_bytes)
    if not _is_target_wav(final_info):
        raise ValueError("Uploaded audio could not be normalized to 16 kHz mono 16-bit PCM WAV.")

    # Name selection and the write share one lock so two uploads cannot
    # pick the same generated name.
    with SAMPLES_LOCK:
        PERSONAL_DIR.mkdir(parents=True, exist_ok=True)
        final_name = out_name or _next_personal_sample_name(original_name)
        (PERSONAL_DIR / final_name).write_bytes(final_bytes)

    if already_target:
        message = "Already in the correct 16 kHz mono 16-bit PCM WAV format"
    else:
        message = "Converted to 16 kHz mono 16-bit PCM WAV"

    return {
        "saved_as": final_name,
        "converted": not already_target,
        "original_name": original_name or final_name,
        "detected_format": original_info,
        "final_format": final_info,
        "message": message,
    }
def _clear_training_log():
"""
Truncate recorder_training.log for a fresh session.
@@ -405,6 +786,8 @@ def _run_training_background(safe_word: str, allow_no_personal: bool):
try:
_ensure_training_venv(log_path)
_ensure_training_datasets(log_path)
if language != "en":
_ensure_non_english_language_voices(language, _append_train_log)
if wake_word_title:
cmd_str = f"{TRAIN_CMD} --language='{language}' '{safe_word}' '{wake_word_title}'"
@@ -474,7 +857,8 @@ def start_session(payload: Dict[str, Any]):
speakers_total = int(payload.get("speakers_total") or SPEAKERS_TOTAL_DEFAULT)
takes_per_speaker = int(payload.get("takes_per_speaker") or TAKES_PER_SPEAKER_DEFAULT)
language = (payload.get("language") or DEFAULT_LANGUAGE).strip().lower()
language = _normalize_language(payload.get("language"))
available_languages = _available_languages()
speakers_total = max(1, min(10, speakers_total))
takes_per_speaker = max(1, min(50, takes_per_speaker))
@@ -485,10 +869,8 @@ def start_session(payload: Dict[str, Any]):
STATE["language"] = language
STATE["speakers_total"] = speakers_total
STATE["takes_per_speaker"] = takes_per_speaker
STATE["takes_received"] = 0
STATE["takes"] = []
_reset_personal_samples_dir()
takes = _sync_personal_samples_state()
# Always wipe log on start_session (even if same wakeword)
_clear_training_log()
@@ -501,6 +883,9 @@ def start_session(payload: Dict[str, Any]):
"speakers_total": speakers_total,
"takes_per_speaker": takes_per_speaker,
"takes_total": speakers_total * takes_per_speaker,
"takes_received": len(takes),
"takes": takes,
"available_languages": available_languages,
"personal_dir": str(PERSONAL_DIR),
"data_dir": str(DATA_DIR),
}
@@ -508,17 +893,22 @@ def start_session(payload: Dict[str, Any]):
@app.get("/api/session")
def get_session():
takes = _sync_personal_samples_state()
available_languages = _available_languages()
with STATE_LOCK:
current_language = _normalize_language(STATE["language"])
STATE["language"] = current_language
return {
"ok": True,
"raw_phrase": STATE["raw_phrase"],
"safe_word": STATE["safe_word"],
"language": STATE["language"],
"language": current_language,
"speakers_total": STATE["speakers_total"],
"takes_per_speaker": STATE["takes_per_speaker"],
"takes_received": STATE["takes_received"],
"takes": list(STATE["takes"]),
"takes_received": len(takes),
"takes": list(takes),
"training": dict(STATE["training"]),
"available_languages": available_languages,
}
@@ -542,23 +932,34 @@ async def upload_take(
if take_index < 1 or take_index > takes_per_speaker:
return JSONResponse({"ok": False, "error": f"take_index must be 1..{takes_per_speaker}"}, status_code=400)
PERSONAL_DIR.mkdir(parents=True, exist_ok=True)
out_name = f"speaker{speaker_index:02d}_take{take_index:02d}.wav"
out_path = PERSONAL_DIR / out_name
data = await file.read()
if not data or len(data) < 44:
return JSONResponse({"ok": False, "error": "Empty/invalid file"}, status_code=400)
try:
result = _save_personal_sample(data, file.filename or out_name, out_name=out_name)
except Exception as e:
return JSONResponse({"ok": False, "error": str(e)}, status_code=400)
out_path.write_bytes(data)
takes = _sync_personal_samples_state()
return {"ok": True, **result, "takes_received": len(takes)}
@app.post("/api/upload_personal_sample")
async def upload_personal_sample(file: UploadFile = File(...)):
with STATE_LOCK:
if out_name not in STATE["takes"]:
STATE["takes"].append(out_name)
STATE["takes_received"] = len(STATE["takes"])
safe_word = STATE["safe_word"]
return {"ok": True, "saved_as": out_name, "takes_received": STATE["takes_received"]}
if not safe_word:
return JSONResponse({"ok": False, "error": "No active session. Call /api/start_session first."}, status_code=400)
data = await file.read()
try:
result = _save_personal_sample(data, file.filename or "sample")
except Exception as e:
return JSONResponse({"ok": False, "error": str(e)}, status_code=400)
takes = _sync_personal_samples_state()
return {"ok": True, **result, "takes_received": len(takes)}
@app.post("/api/train")
@@ -581,27 +982,13 @@ def train_now(payload: Dict[str, Any] = None):
if not safe_word:
return JSONResponse({"ok": False, "error": "No active session"}, status_code=400)
min_required = max(1, min(3, takes_total))
if takes_received == 0 and not allow_no_personal:
return JSONResponse(
{
"ok": False,
"error": f"No personal voice samples recorded (0/{takes_total}).",
"error": "No personal voice samples uploaded yet.",
"code": "NO_PERSONAL_SAMPLES",
"message": "You can train without personal voices, or record samples first.",
"takes_total": takes_total,
},
status_code=400,
)
if 0 < takes_received < min_required:
return JSONResponse(
{
"ok": False,
"error": f"Not enough takes yet ({takes_received}/{takes_total}).",
"code": "NOT_ENOUGH_TAKES",
"min_required": min_required,
"message": "You can train without personal voices, or upload samples first.",
"takes_total": takes_total,
},
status_code=400,
@@ -614,7 +1001,7 @@ def train_now(payload: Dict[str, Any] = None):
"ok": True,
"started": True,
"safe_word": safe_word,
"personal_samples_used": takes_received >= min_required,
"personal_samples_used": takes_received > 0,
"allow_no_personal": allow_no_personal,
}
@@ -656,13 +1043,12 @@ def train_status():
tr["log_text"] = "\n".join(new_lines) # ONLY new lines
tr["log_tail_preview"] = "\n".join(full_tail) # optional: handy for debugging
tr["log_lines"] = full_tail
return {"ok": True, "training": tr}
@app.post("/api/reset_recordings")
def reset_recordings():
_reset_personal_samples_dir()
with STATE_LOCK:
STATE["takes_received"] = 0
STATE["takes"] = []
return {"ok": True}
takes = _sync_personal_samples_state()
return {"ok": True, "takes_received": len(takes), "takes": takes}