mirror of
https://github.com/TaterTotterson/microWakeWord-Trainer-Nvidia-Docker.git
synced 2026-06-12 20:10:19 -06:00
remove recorder/upload samples
This commit is contained in:
112
cli/run_generator_with_progress.py
Normal file
112
cli/run_generator_with_progress.py
Normal file
@@ -0,0 +1,112 @@
|
||||
#!/usr/bin/env python3
|
||||
import argparse
|
||||
import queue
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def _model_args(generator_args):
|
||||
values = []
|
||||
for idx, arg in enumerate(generator_args):
|
||||
if arg == "--model" and idx + 1 < len(generator_args):
|
||||
values.append(generator_args[idx + 1])
|
||||
return values
|
||||
|
||||
|
||||
def _is_onnx_run(generator_args):
|
||||
return any(str(value).endswith(".onnx") for value in _model_args(generator_args))
|
||||
|
||||
|
||||
def _format_line(line):
|
||||
if line.startswith("DEBUG:piper.voice:"):
|
||||
return None
|
||||
for prefix in ("DEBUG:__main__:", "INFO:__main__:", "WARNING:__main__:", "ERROR:__main__:"):
|
||||
if line.startswith(prefix):
|
||||
return " " + line[len(prefix):].strip()
|
||||
return line
|
||||
|
||||
|
||||
def _reader(stdout, sink):
|
||||
try:
|
||||
for raw in stdout:
|
||||
sink.put(raw.rstrip("\n"))
|
||||
finally:
|
||||
sink.put(None)
|
||||
|
||||
|
||||
def _progress_step(max_samples):
|
||||
if max_samples <= 20:
|
||||
return 1
|
||||
if max_samples <= 100:
|
||||
return 5
|
||||
return 10
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--generator", required=True)
|
||||
parser.add_argument("--output-dir", required=True)
|
||||
parser.add_argument("--max-samples", required=True, type=int)
|
||||
parser.add_argument("generator_args", nargs=argparse.REMAINDER)
|
||||
args = parser.parse_args()
|
||||
|
||||
generator_args = list(args.generator_args)
|
||||
if generator_args and generator_args[0] == "--":
|
||||
generator_args = generator_args[1:]
|
||||
|
||||
cmd = [sys.executable, args.generator, *generator_args]
|
||||
proc = subprocess.Popen(
|
||||
cmd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
bufsize=1,
|
||||
)
|
||||
assert proc.stdout is not None
|
||||
|
||||
line_queue = queue.Queue()
|
||||
reader = threading.Thread(target=_reader, args=(proc.stdout, line_queue), daemon=True)
|
||||
reader.start()
|
||||
|
||||
output_dir = Path(args.output_dir)
|
||||
use_sample_progress = _is_onnx_run(generator_args)
|
||||
step = _progress_step(args.max_samples)
|
||||
last_reported = 0
|
||||
stream_done = False
|
||||
|
||||
while not stream_done or proc.poll() is None:
|
||||
try:
|
||||
line = line_queue.get(timeout=0.2)
|
||||
except queue.Empty:
|
||||
line = None
|
||||
|
||||
if line is None:
|
||||
if not stream_done and not line_queue.empty():
|
||||
continue
|
||||
stream_done = proc.poll() is not None or stream_done
|
||||
else:
|
||||
formatted = _format_line(line)
|
||||
if formatted:
|
||||
print(formatted, flush=True)
|
||||
|
||||
if use_sample_progress:
|
||||
current = len(list(output_dir.glob("*.wav")))
|
||||
should_report = current > last_reported and (
|
||||
current >= args.max_samples
|
||||
or current - last_reported >= step
|
||||
)
|
||||
if should_report:
|
||||
print(f" Generated {current}/{args.max_samples} samples...", flush=True)
|
||||
last_reported = current
|
||||
|
||||
rc = proc.wait()
|
||||
final_count = len(list(output_dir.glob("*.wav"))) if use_sample_progress else 0
|
||||
if use_sample_progress and final_count > last_reported:
|
||||
print(f" Generated {final_count}/{args.max_samples} samples...", flush=True)
|
||||
return rc
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
@@ -138,11 +138,16 @@ export GRPC_VERBOSITY=ERROR
|
||||
echo " Generating samples"
|
||||
rm -rf "${SAMPLES_DIR}" || :
|
||||
mkdir -p "${SAMPLES_DIR}" || :
|
||||
"${PSG}/generate_samples.py" "${WAKE_WORD}" \
|
||||
python "${PROGDIR}/run_generator_with_progress.py" \
|
||||
--generator "${PSG}/generate_samples.py" \
|
||||
--output-dir "${SAMPLES_DIR}" \
|
||||
--max-samples ${SAMPLES} \
|
||||
-- \
|
||||
"${WAKE_WORD}" \
|
||||
"${MODEL_ARGS[@]}" \
|
||||
--max-samples ${SAMPLES} \
|
||||
--batch-size ${BATCH_SIZE} \
|
||||
--output-dir "${SAMPLES_DIR}" 2>&1 | sed -r -e "s/(DEBUG|INFO):__main__:/ /g"
|
||||
--output-dir "${SAMPLES_DIR}"
|
||||
|
||||
generated_files=$(find "${SAMPLES_DIR}" -name '*.wav' | wc -l)
|
||||
if [ "${generated_files}" -ne "${SAMPLES}" ] ; then
|
||||
|
||||
@@ -10,7 +10,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
&& rm -rf /var/lib/apt/lists/* \
|
||||
&& mkdir -p /data
|
||||
|
||||
# Recorder port
|
||||
# Trainer UI port
|
||||
EXPOSE 8789
|
||||
|
||||
# Script root
|
||||
@@ -23,7 +23,7 @@ COPY --chown=root:root --chmod=0755 .bashrc /root/
|
||||
COPY --chown=root:root --chmod=0755 \
|
||||
train_wake_word \
|
||||
run_recorder.sh \
|
||||
recorder_server.py \
|
||||
trainer_server.py \
|
||||
requirements.txt \
|
||||
/root/mww-scripts/
|
||||
|
||||
@@ -33,8 +33,8 @@ COPY --chown=root:root cli/ /root/mww-scripts/cli/
|
||||
# Make all CLI scripts executable (avoids "Permission denied")
|
||||
RUN chmod -R a+x /root/mww-scripts/cli
|
||||
|
||||
# Static UI for recorder
|
||||
# Static UI for trainer
|
||||
COPY --chown=root:root --chmod=0644 static/index.html /root/mww-scripts/static/index.html
|
||||
|
||||
# recorder server
|
||||
# trainer server
|
||||
CMD ["/bin/bash", "-lc", "/root/mww-scripts/run_recorder.sh"]
|
||||
|
||||
@@ -18,7 +18,7 @@ FASTAPI_VERSION="${REC_FASTAPI_VERSION:-0.115.6}"
|
||||
UVICORN_VERSION="${REC_UVICORN_VERSION:-0.30.6}"
|
||||
PY_MULTIPART_VERSION="${REC_PY_MULTIPART_VERSION:-0.0.9}"
|
||||
|
||||
echo "microWakeWord Recorder (Docker)"
|
||||
echo "microWakeWord Trainer UI (Docker)"
|
||||
echo "-> ROOTDIR: ${ROOTDIR}"
|
||||
echo "-> DATA_DIR: ${DATA_DIR}"
|
||||
echo "-> URL: http://localhost:${PORT}/"
|
||||
@@ -26,10 +26,10 @@ echo "-> URL: http://localhost:${PORT}/"
|
||||
mkdir -p "${DATA_DIR}"
|
||||
|
||||
# -----------------------------
|
||||
# Recorder venv (separate)
|
||||
# Trainer UI venv (separate)
|
||||
# -----------------------------
|
||||
if [[ ! -x "${PY}" ]]; then
|
||||
echo "Creating recorder venv: ${VENV_DIR}"
|
||||
echo "Creating trainer UI venv: ${VENV_DIR}"
|
||||
python3 -m venv "${VENV_DIR}"
|
||||
fi
|
||||
|
||||
@@ -37,7 +37,7 @@ fi
|
||||
source "${VENV_DIR}/bin/activate"
|
||||
|
||||
if [[ ! -f "${PIN_FILE}" ]]; then
|
||||
echo "Installing pinned recorder deps"
|
||||
echo "Installing pinned trainer UI deps"
|
||||
${PIP} install -U pip setuptools wheel
|
||||
${PIP} install \
|
||||
"fastapi==${FASTAPI_VERSION}" \
|
||||
@@ -45,20 +45,20 @@ if [[ ! -f "${PIN_FILE}" ]]; then
|
||||
"python-multipart==${PY_MULTIPART_VERSION}"
|
||||
touch "${PIN_FILE}"
|
||||
else
|
||||
echo "Reusing existing recorder venv (no upgrades)"
|
||||
echo "Reusing existing trainer UI venv (no upgrades)"
|
||||
fi
|
||||
|
||||
# -----------------------------
|
||||
# Recorder server env
|
||||
# Trainer server env
|
||||
# -----------------------------
|
||||
export DATA_DIR="${DATA_DIR}"
|
||||
export STATIC_DIR="${ROOTDIR}/static"
|
||||
export PERSONAL_DIR="${DATA_DIR}/personal_samples"
|
||||
|
||||
# IMPORTANT: leave training venv creation to /api/train inside recorder_server.py
|
||||
# IMPORTANT: leave training venv creation to /api/train inside trainer_server.py
|
||||
# but still set TRAIN_CMD so the server knows how to invoke training once ready
|
||||
export TRAIN_CMD="source '${DATA_DIR}/.venv/bin/activate' && train_wake_word --data-dir='${DATA_DIR}'"
|
||||
|
||||
echo "Launching uvicorn on ${HOST}:${PORT}"
|
||||
cd "${ROOTDIR}"
|
||||
exec "${VENV_DIR}/bin/uvicorn" recorder_server:app --host "${HOST}" --port "${PORT}"
|
||||
exec "${VENV_DIR}/bin/uvicorn" trainer_server:app --host "${HOST}" --port "${PORT}"
|
||||
|
||||
1291
static/index.html
1291
static/index.html
File diff suppressed because it is too large
Load Diff
@@ -1,13 +1,18 @@
|
||||
# recorder_server.py
|
||||
# trainer_server.py
|
||||
import io
|
||||
import os
|
||||
import re
|
||||
import json
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
import wave
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
from urllib.request import Request, urlopen
|
||||
|
||||
from fastapi import FastAPI, UploadFile, File, Form
|
||||
from fastapi.responses import HTMLResponse, JSONResponse
|
||||
@@ -26,6 +31,17 @@ PERSONAL_DIR = Path(os.environ.get("PERSONAL_DIR", str(DATA_DIR / "personal_samp
|
||||
|
||||
# CLI folder inside repo
|
||||
CLI_DIR = Path(os.environ.get("CLI_DIR", str(ROOT_DIR / "cli"))).resolve()
|
||||
PIPER_ROOT = DATA_DIR / "tools" / "piper-sample-generator"
|
||||
PIPER_VOICES_DIR = PIPER_ROOT / "voices"
|
||||
PIPER_VOICES_INDEX_URL = os.environ.get(
|
||||
"PIPER_VOICES_INDEX_URL",
|
||||
"https://huggingface.co/rhasspy/piper-voices/raw/main/voices.json",
|
||||
)
|
||||
PIPER_VOICES_ROOT_URL = os.environ.get(
|
||||
"PIPER_VOICES_ROOT_URL",
|
||||
"https://huggingface.co/rhasspy/piper-voices/resolve/main",
|
||||
)
|
||||
PIPER_CATALOG_CACHE_TTL_SECONDS = int(os.environ.get("PIPER_CATALOG_CACHE_TTL_SECONDS", "900"))
|
||||
|
||||
DATASET_CLEANUP_ARCHIVES = os.environ.get("REC_DATASET_CLEANUP_ARCHIVES", "false").lower() in ("1", "true", "yes", "y")
|
||||
DATASET_CLEANUP_INTERMEDIATE = os.environ.get("REC_DATASET_CLEANUP_INTERMEDIATE_FILES", "false").lower() in ("1", "true", "yes", "y")
|
||||
@@ -39,13 +55,16 @@ DEFAULT_LANGUAGE = os.environ.get("MWW_LANGUAGE", "en")
|
||||
|
||||
TAKES_PER_SPEAKER_DEFAULT = int(os.environ.get("REC_TAKES_PER_SPEAKER", "10"))
|
||||
SPEAKERS_TOTAL_DEFAULT = int(os.environ.get("REC_SPEAKERS_TOTAL", "1"))
|
||||
TARGET_SAMPLE_RATE = 16000
|
||||
TARGET_CHANNELS = 1
|
||||
TARGET_SAMPLE_WIDTH_BYTES = 2
|
||||
|
||||
# Tail lines shown to UI
|
||||
TRAIN_LOG_TAIL_LINES = int(os.environ.get("REC_TRAIN_LOG_TAIL_LINES", "400"))
|
||||
# Safety cap for reads (bytes) to avoid giant file reads
|
||||
TRAIN_LOG_MAX_BYTES = int(os.environ.get("REC_TRAIN_LOG_MAX_BYTES", str(512 * 1024))) # 512KB
|
||||
|
||||
app = FastAPI(title="microWakeWord Personal Recorder")
|
||||
app = FastAPI(title="microWakeWord Personal Samples")
|
||||
|
||||
STATIC_DIR.mkdir(parents=True, exist_ok=True)
|
||||
app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")
|
||||
@@ -84,6 +103,12 @@ STATE: Dict[str, Any] = {
|
||||
}
|
||||
|
||||
STATE_LOCK = threading.Lock()
|
||||
SAMPLES_LOCK = threading.Lock()
|
||||
PIPER_CATALOG_LOCK = threading.Lock()
|
||||
PIPER_CATALOG_CACHE: Dict[str, Any] = {
|
||||
"fetched_at": 0.0,
|
||||
"entries": None,
|
||||
}
|
||||
|
||||
|
||||
def _reset_personal_samples_dir():
|
||||
@@ -95,6 +120,362 @@ def _reset_personal_samples_dir():
|
||||
pass
|
||||
|
||||
|
||||
|
||||
def _list_personal_samples() -> List[str]:
|
||||
PERSONAL_DIR.mkdir(parents=True, exist_ok=True)
|
||||
return sorted(p.name for p in PERSONAL_DIR.glob("*.wav"))
|
||||
|
||||
|
||||
def _sync_personal_samples_state() -> List[str]:
|
||||
takes = _list_personal_samples()
|
||||
with STATE_LOCK:
|
||||
STATE["takes"] = takes
|
||||
STATE["takes_received"] = len(takes)
|
||||
return takes
|
||||
|
||||
|
||||
def _registered_language_family(language: Dict[str, Any]) -> str:
|
||||
family = str(language.get("family") or "").strip().lower()
|
||||
if family:
|
||||
return family
|
||||
code = str(language.get("code") or "").strip()
|
||||
return code.split("_", 1)[0].lower() if code else ""
|
||||
|
||||
|
||||
def _register_language(
|
||||
languages: Dict[str, Dict[str, Any]],
|
||||
*,
|
||||
family: str,
|
||||
name: str,
|
||||
region: str = "",
|
||||
count: int = 1,
|
||||
):
|
||||
if not family:
|
||||
return
|
||||
entry = languages.setdefault(
|
||||
family,
|
||||
{
|
||||
"code": family,
|
||||
"label": f"{name} ({family})",
|
||||
"name": name,
|
||||
"voice_count": 0,
|
||||
"regions": [],
|
||||
},
|
||||
)
|
||||
entry["voice_count"] += count
|
||||
if region and region not in entry["regions"]:
|
||||
entry["regions"].append(region)
|
||||
|
||||
|
||||
def _fetch_piper_catalog() -> Optional[Dict[str, Any]]:
|
||||
req = Request(
|
||||
PIPER_VOICES_INDEX_URL,
|
||||
headers={"User-Agent": "microWakeWord-Trainer/1.0"},
|
||||
)
|
||||
with urlopen(req, timeout=15) as resp:
|
||||
data = json.loads(resp.read().decode("utf-8"))
|
||||
return data if isinstance(data, dict) else None
|
||||
|
||||
|
||||
def _load_piper_catalog() -> Optional[Dict[str, Any]]:
|
||||
now = time.time()
|
||||
with PIPER_CATALOG_LOCK:
|
||||
cached = PIPER_CATALOG_CACHE.get("entries")
|
||||
fetched_at = float(PIPER_CATALOG_CACHE.get("fetched_at") or 0.0)
|
||||
if cached is not None and (now - fetched_at) < PIPER_CATALOG_CACHE_TTL_SECONDS:
|
||||
return cached
|
||||
|
||||
try:
|
||||
fresh = _fetch_piper_catalog()
|
||||
except Exception:
|
||||
fresh = None
|
||||
|
||||
with PIPER_CATALOG_LOCK:
|
||||
if fresh is not None:
|
||||
PIPER_CATALOG_CACHE["entries"] = fresh
|
||||
PIPER_CATALOG_CACHE["fetched_at"] = now
|
||||
return fresh
|
||||
if PIPER_CATALOG_CACHE.get("entries") is None:
|
||||
PIPER_CATALOG_CACHE["entries"] = {}
|
||||
PIPER_CATALOG_CACHE["fetched_at"] = now
|
||||
return PIPER_CATALOG_CACHE.get("entries")
|
||||
|
||||
|
||||
def _available_languages() -> List[Dict[str, Any]]:
|
||||
languages: Dict[str, Dict[str, Any]] = {
|
||||
"en": {
|
||||
"code": "en",
|
||||
"label": "English (en)",
|
||||
"name": "English",
|
||||
"voice_count": 1,
|
||||
"regions": [],
|
||||
}
|
||||
}
|
||||
|
||||
if PIPER_VOICES_DIR.exists():
|
||||
for meta_path in sorted(PIPER_VOICES_DIR.glob("*.onnx.json")):
|
||||
try:
|
||||
data = json.loads(meta_path.read_text(encoding="utf-8"))
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
language = data.get("language") or {}
|
||||
family = _registered_language_family(language)
|
||||
if not family or family == "en":
|
||||
continue
|
||||
|
||||
name = str(language.get("name_english") or language.get("name_native") or family.upper()).strip()
|
||||
region = str(language.get("country_english") or language.get("region") or "").strip()
|
||||
_register_language(languages, family=family, name=name, region=region, count=1)
|
||||
|
||||
catalog = _load_piper_catalog() or {}
|
||||
for entry in catalog.values():
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
language = entry.get("language") or {}
|
||||
family = _registered_language_family(language)
|
||||
if not family or family == "en":
|
||||
continue
|
||||
name = str(language.get("name_english") or language.get("name_native") or family.upper()).strip()
|
||||
region = str(language.get("country_english") or language.get("region") or "").strip()
|
||||
_register_language(languages, family=family, name=name, region=region, count=0)
|
||||
|
||||
ordered = [languages["en"]]
|
||||
ordered.extend(
|
||||
sorted(
|
||||
(entry for code, entry in languages.items() if code != "en"),
|
||||
key=lambda entry: (entry["name"].lower(), entry["code"]),
|
||||
)
|
||||
)
|
||||
return ordered
|
||||
|
||||
|
||||
def _normalize_language(language: str | None) -> str:
|
||||
requested = (language or DEFAULT_LANGUAGE).strip().lower() or DEFAULT_LANGUAGE
|
||||
available_codes = {item["code"] for item in _available_languages()}
|
||||
if requested in available_codes:
|
||||
return requested
|
||||
if DEFAULT_LANGUAGE in available_codes:
|
||||
return DEFAULT_LANGUAGE
|
||||
return "en"
|
||||
|
||||
|
||||
def _catalog_voice_files(language_family: str) -> List[Tuple[str, str]]:
|
||||
if not language_family or language_family == "en":
|
||||
return []
|
||||
|
||||
downloads: Dict[str, str] = {}
|
||||
catalog = _load_piper_catalog() or {}
|
||||
for entry in catalog.values():
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
language = entry.get("language") or {}
|
||||
family = _registered_language_family(language)
|
||||
if family != language_family:
|
||||
continue
|
||||
files = entry.get("files") or {}
|
||||
for rel_path in files.keys():
|
||||
if not isinstance(rel_path, str):
|
||||
continue
|
||||
if not (rel_path.endswith(".onnx") or rel_path.endswith(".onnx.json")):
|
||||
continue
|
||||
downloads[Path(rel_path).name] = f"{PIPER_VOICES_ROOT_URL}/{rel_path}?download=true"
|
||||
|
||||
return sorted(downloads.items(), key=lambda item: item[0])
|
||||
|
||||
|
||||
def _download_to_path(url: str, dest_path: Path):
|
||||
dest_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
tmp_path = dest_path.with_suffix(dest_path.suffix + ".tmp")
|
||||
req = Request(url, headers={"User-Agent": "microWakeWord-Trainer/1.0"})
|
||||
with urlopen(req, timeout=60) as resp, open(tmp_path, "wb") as out:
|
||||
shutil.copyfileobj(resp, out)
|
||||
tmp_path.replace(dest_path)
|
||||
|
||||
|
||||
def _ensure_non_english_language_voices(language_family: str, log) -> Dict[str, int]:
|
||||
downloads = _catalog_voice_files(language_family)
|
||||
local_voices = sorted(PIPER_VOICES_DIR.glob(f"{language_family}_*.onnx")) if PIPER_VOICES_DIR.exists() else []
|
||||
if not downloads:
|
||||
if local_voices:
|
||||
log(f"===== Piper Voices ({language_family}) =====")
|
||||
log(f"→ Using {len(local_voices)} installed voice(s) for language '{language_family}'")
|
||||
return {
|
||||
"downloaded_files": 0,
|
||||
"existing_files": len(local_voices),
|
||||
"voices": len(local_voices),
|
||||
}
|
||||
raise RuntimeError(
|
||||
f"No Piper ONNX voices found for language '{language_family}' in the upstream catalog."
|
||||
)
|
||||
|
||||
PIPER_VOICES_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
downloaded_files = 0
|
||||
existing_files = 0
|
||||
voice_names = sorted(name for name, _ in downloads if name.endswith(".onnx"))
|
||||
|
||||
log(f"===== Piper Voices ({language_family}) =====")
|
||||
log(f"→ Ensuring {len(voice_names)} voice(s) for language '{language_family}'")
|
||||
|
||||
for file_name, url in downloads:
|
||||
dest_path = PIPER_VOICES_DIR / file_name
|
||||
if dest_path.exists() and dest_path.stat().st_size > 0:
|
||||
existing_files += 1
|
||||
continue
|
||||
log(f"→ Downloading {file_name}")
|
||||
_download_to_path(url, dest_path)
|
||||
downloaded_files += 1
|
||||
|
||||
log(
|
||||
f"✓ Piper voices ready for '{language_family}' "
|
||||
f"({downloaded_files} file(s) downloaded, {existing_files} already present)"
|
||||
)
|
||||
return {
|
||||
"downloaded_files": downloaded_files,
|
||||
"existing_files": existing_files,
|
||||
"voices": len(voice_names),
|
||||
}
|
||||
|
||||
|
||||
def _find_ffmpeg() -> Optional[str]:
|
||||
candidates = [
|
||||
shutil.which("ffmpeg"),
|
||||
"/usr/bin/ffmpeg",
|
||||
"/usr/local/bin/ffmpeg",
|
||||
"/opt/homebrew/bin/ffmpeg",
|
||||
"/opt/homebrew/opt/ffmpeg@7/bin/ffmpeg",
|
||||
"/opt/homebrew/opt/ffmpeg/bin/ffmpeg",
|
||||
]
|
||||
for candidate in candidates:
|
||||
if candidate and Path(candidate).exists():
|
||||
return candidate
|
||||
return None
|
||||
|
||||
|
||||
def _inspect_wav_bytes(data: bytes) -> Optional[Dict[str, Any]]:
|
||||
try:
|
||||
with wave.open(io.BytesIO(data), "rb") as wf:
|
||||
frames = wf.getnframes()
|
||||
rate = wf.getframerate()
|
||||
duration = (frames / rate) if rate else 0.0
|
||||
return {
|
||||
"container": "wav",
|
||||
"sample_rate": rate,
|
||||
"channels": wf.getnchannels(),
|
||||
"sample_width_bits": wf.getsampwidth() * 8,
|
||||
"compression": wf.getcomptype(),
|
||||
"frames": frames,
|
||||
"duration_s": round(duration, 3),
|
||||
}
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _is_target_wav(info: Optional[Dict[str, Any]]) -> bool:
|
||||
return bool(
|
||||
info
|
||||
and info.get("container") == "wav"
|
||||
and info.get("sample_rate") == TARGET_SAMPLE_RATE
|
||||
and info.get("channels") == TARGET_CHANNELS
|
||||
and info.get("sample_width_bits") == TARGET_SAMPLE_WIDTH_BYTES * 8
|
||||
and info.get("compression") == "NONE"
|
||||
and info.get("frames", 0) > 0
|
||||
)
|
||||
|
||||
|
||||
def _next_personal_sample_name(original_name: str) -> str:
|
||||
current = _list_personal_samples()
|
||||
next_index = 1
|
||||
for name in current:
|
||||
match = re.match(r"sample_(\d{4})", name)
|
||||
if match:
|
||||
next_index = max(next_index, int(match.group(1)) + 1)
|
||||
|
||||
stem = safe_name(Path(original_name or "sample").stem)
|
||||
suffix = f"_{stem[:32]}" if stem and stem != "wakeword" else ""
|
||||
return f"sample_{next_index:04d}{suffix}.wav"
|
||||
|
||||
|
||||
def _format_hint_from_filename(original_name: str) -> Dict[str, Any]:
|
||||
suffix = (Path(original_name or "").suffix or "").lower().lstrip(".")
|
||||
return {
|
||||
"container": suffix or "unknown",
|
||||
"sample_rate": None,
|
||||
"channels": None,
|
||||
"sample_width_bits": None,
|
||||
"compression": None,
|
||||
"frames": None,
|
||||
"duration_s": None,
|
||||
}
|
||||
|
||||
|
||||
def _normalize_audio_to_target_wav(data: bytes, original_name: str) -> bytes:
|
||||
ffmpeg = _find_ffmpeg()
|
||||
if not ffmpeg:
|
||||
raise RuntimeError(
|
||||
"ffmpeg is required to convert uploads that are not already 16 kHz mono 16-bit PCM WAV."
|
||||
)
|
||||
|
||||
suffix = (Path(original_name or "").suffix or ".audio")
|
||||
with tempfile.TemporaryDirectory(prefix="mww_upload_") as tmpdir:
|
||||
src_path = Path(tmpdir) / f"source{suffix}"
|
||||
dst_path = Path(tmpdir) / "normalized.wav"
|
||||
src_path.write_bytes(data)
|
||||
|
||||
cmd = [
|
||||
ffmpeg,
|
||||
"-y",
|
||||
"-i",
|
||||
str(src_path),
|
||||
"-vn",
|
||||
"-ac",
|
||||
str(TARGET_CHANNELS),
|
||||
"-ar",
|
||||
str(TARGET_SAMPLE_RATE),
|
||||
"-c:a",
|
||||
"pcm_s16le",
|
||||
str(dst_path),
|
||||
]
|
||||
proc = subprocess.run(cmd, capture_output=True, text=True)
|
||||
if proc.returncode != 0 or not dst_path.exists():
|
||||
err = (proc.stderr or proc.stdout or "ffmpeg conversion failed").strip()
|
||||
raise RuntimeError(err.splitlines()[-1] if err else "ffmpeg conversion failed")
|
||||
|
||||
return dst_path.read_bytes()
|
||||
|
||||
|
||||
def _save_personal_sample(data: bytes, original_name: str, out_name: Optional[str] = None) -> Dict[str, Any]:
|
||||
if not data:
|
||||
raise ValueError("Empty or invalid audio file.")
|
||||
|
||||
original_info = _inspect_wav_bytes(data) or _format_hint_from_filename(original_name)
|
||||
normalized = _is_target_wav(original_info)
|
||||
final_bytes = data if normalized else _normalize_audio_to_target_wav(data, original_name)
|
||||
final_info = _inspect_wav_bytes(final_bytes)
|
||||
|
||||
if not _is_target_wav(final_info):
|
||||
raise ValueError("Uploaded audio could not be normalized to 16 kHz mono 16-bit PCM WAV.")
|
||||
|
||||
with SAMPLES_LOCK:
|
||||
PERSONAL_DIR.mkdir(parents=True, exist_ok=True)
|
||||
final_name = out_name or _next_personal_sample_name(original_name)
|
||||
out_path = PERSONAL_DIR / final_name
|
||||
out_path.write_bytes(final_bytes)
|
||||
|
||||
return {
|
||||
"saved_as": final_name,
|
||||
"converted": not normalized,
|
||||
"original_name": original_name or final_name,
|
||||
"detected_format": original_info,
|
||||
"final_format": final_info,
|
||||
"message": (
|
||||
"Converted to 16 kHz mono 16-bit PCM WAV"
|
||||
if not normalized
|
||||
else "Already in the correct 16 kHz mono 16-bit PCM WAV format"
|
||||
),
|
||||
}
|
||||
|
||||
def _clear_training_log():
|
||||
"""
|
||||
Truncate recorder_training.log for a fresh session.
|
||||
@@ -405,6 +786,8 @@ def _run_training_background(safe_word: str, allow_no_personal: bool):
|
||||
try:
|
||||
_ensure_training_venv(log_path)
|
||||
_ensure_training_datasets(log_path)
|
||||
if language != "en":
|
||||
_ensure_non_english_language_voices(language, _append_train_log)
|
||||
|
||||
if wake_word_title:
|
||||
cmd_str = f"{TRAIN_CMD} --language='{language}' '{safe_word}' '{wake_word_title}'"
|
||||
@@ -474,7 +857,8 @@ def start_session(payload: Dict[str, Any]):
|
||||
|
||||
speakers_total = int(payload.get("speakers_total") or SPEAKERS_TOTAL_DEFAULT)
|
||||
takes_per_speaker = int(payload.get("takes_per_speaker") or TAKES_PER_SPEAKER_DEFAULT)
|
||||
language = (payload.get("language") or DEFAULT_LANGUAGE).strip().lower()
|
||||
language = _normalize_language(payload.get("language"))
|
||||
available_languages = _available_languages()
|
||||
|
||||
speakers_total = max(1, min(10, speakers_total))
|
||||
takes_per_speaker = max(1, min(50, takes_per_speaker))
|
||||
@@ -485,10 +869,8 @@ def start_session(payload: Dict[str, Any]):
|
||||
STATE["language"] = language
|
||||
STATE["speakers_total"] = speakers_total
|
||||
STATE["takes_per_speaker"] = takes_per_speaker
|
||||
STATE["takes_received"] = 0
|
||||
STATE["takes"] = []
|
||||
|
||||
_reset_personal_samples_dir()
|
||||
takes = _sync_personal_samples_state()
|
||||
|
||||
# Always wipe log on start_session (even if same wakeword)
|
||||
_clear_training_log()
|
||||
@@ -501,6 +883,9 @@ def start_session(payload: Dict[str, Any]):
|
||||
"speakers_total": speakers_total,
|
||||
"takes_per_speaker": takes_per_speaker,
|
||||
"takes_total": speakers_total * takes_per_speaker,
|
||||
"takes_received": len(takes),
|
||||
"takes": takes,
|
||||
"available_languages": available_languages,
|
||||
"personal_dir": str(PERSONAL_DIR),
|
||||
"data_dir": str(DATA_DIR),
|
||||
}
|
||||
@@ -508,17 +893,22 @@ def start_session(payload: Dict[str, Any]):
|
||||
|
||||
@app.get("/api/session")
|
||||
def get_session():
|
||||
takes = _sync_personal_samples_state()
|
||||
available_languages = _available_languages()
|
||||
with STATE_LOCK:
|
||||
current_language = _normalize_language(STATE["language"])
|
||||
STATE["language"] = current_language
|
||||
return {
|
||||
"ok": True,
|
||||
"raw_phrase": STATE["raw_phrase"],
|
||||
"safe_word": STATE["safe_word"],
|
||||
"language": STATE["language"],
|
||||
"language": current_language,
|
||||
"speakers_total": STATE["speakers_total"],
|
||||
"takes_per_speaker": STATE["takes_per_speaker"],
|
||||
"takes_received": STATE["takes_received"],
|
||||
"takes": list(STATE["takes"]),
|
||||
"takes_received": len(takes),
|
||||
"takes": list(takes),
|
||||
"training": dict(STATE["training"]),
|
||||
"available_languages": available_languages,
|
||||
}
|
||||
|
||||
|
||||
@@ -542,23 +932,34 @@ async def upload_take(
|
||||
if take_index < 1 or take_index > takes_per_speaker:
|
||||
return JSONResponse({"ok": False, "error": f"take_index must be 1..{takes_per_speaker}"}, status_code=400)
|
||||
|
||||
PERSONAL_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
out_name = f"speaker{speaker_index:02d}_take{take_index:02d}.wav"
|
||||
out_path = PERSONAL_DIR / out_name
|
||||
|
||||
data = await file.read()
|
||||
if not data or len(data) < 44:
|
||||
return JSONResponse({"ok": False, "error": "Empty/invalid file"}, status_code=400)
|
||||
try:
|
||||
result = _save_personal_sample(data, file.filename or out_name, out_name=out_name)
|
||||
except Exception as e:
|
||||
return JSONResponse({"ok": False, "error": str(e)}, status_code=400)
|
||||
|
||||
out_path.write_bytes(data)
|
||||
takes = _sync_personal_samples_state()
|
||||
return {"ok": True, **result, "takes_received": len(takes)}
|
||||
|
||||
|
||||
@app.post("/api/upload_personal_sample")
|
||||
async def upload_personal_sample(file: UploadFile = File(...)):
|
||||
with STATE_LOCK:
|
||||
if out_name not in STATE["takes"]:
|
||||
STATE["takes"].append(out_name)
|
||||
STATE["takes_received"] = len(STATE["takes"])
|
||||
safe_word = STATE["safe_word"]
|
||||
|
||||
return {"ok": True, "saved_as": out_name, "takes_received": STATE["takes_received"]}
|
||||
if not safe_word:
|
||||
return JSONResponse({"ok": False, "error": "No active session. Call /api/start_session first."}, status_code=400)
|
||||
|
||||
data = await file.read()
|
||||
try:
|
||||
result = _save_personal_sample(data, file.filename or "sample")
|
||||
except Exception as e:
|
||||
return JSONResponse({"ok": False, "error": str(e)}, status_code=400)
|
||||
|
||||
takes = _sync_personal_samples_state()
|
||||
return {"ok": True, **result, "takes_received": len(takes)}
|
||||
|
||||
|
||||
@app.post("/api/train")
|
||||
@@ -581,27 +982,13 @@ def train_now(payload: Dict[str, Any] = None):
|
||||
if not safe_word:
|
||||
return JSONResponse({"ok": False, "error": "No active session"}, status_code=400)
|
||||
|
||||
min_required = max(1, min(3, takes_total))
|
||||
|
||||
if takes_received == 0 and not allow_no_personal:
|
||||
return JSONResponse(
|
||||
{
|
||||
"ok": False,
|
||||
"error": f"No personal voice samples recorded (0/{takes_total}).",
|
||||
"error": "No personal voice samples uploaded yet.",
|
||||
"code": "NO_PERSONAL_SAMPLES",
|
||||
"message": "You can train without personal voices, or record samples first.",
|
||||
"takes_total": takes_total,
|
||||
},
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
if 0 < takes_received < min_required:
|
||||
return JSONResponse(
|
||||
{
|
||||
"ok": False,
|
||||
"error": f"Not enough takes yet ({takes_received}/{takes_total}).",
|
||||
"code": "NOT_ENOUGH_TAKES",
|
||||
"min_required": min_required,
|
||||
"message": "You can train without personal voices, or upload samples first.",
|
||||
"takes_total": takes_total,
|
||||
},
|
||||
status_code=400,
|
||||
@@ -614,7 +1001,7 @@ def train_now(payload: Dict[str, Any] = None):
|
||||
"ok": True,
|
||||
"started": True,
|
||||
"safe_word": safe_word,
|
||||
"personal_samples_used": takes_received >= min_required,
|
||||
"personal_samples_used": takes_received > 0,
|
||||
"allow_no_personal": allow_no_personal,
|
||||
}
|
||||
|
||||
@@ -656,13 +1043,12 @@ def train_status():
|
||||
|
||||
tr["log_text"] = "\n".join(new_lines) # ONLY new lines
|
||||
tr["log_tail_preview"] = "\n".join(full_tail) # optional: handy for debugging
|
||||
tr["log_lines"] = full_tail
|
||||
return {"ok": True, "training": tr}
|
||||
|
||||
|
||||
@app.post("/api/reset_recordings")
|
||||
def reset_recordings():
|
||||
_reset_personal_samples_dir()
|
||||
with STATE_LOCK:
|
||||
STATE["takes_received"] = 0
|
||||
STATE["takes"] = []
|
||||
return {"ok": True}
|
||||
takes = _sync_personal_samples_state()
|
||||
return {"ok": True, "takes_received": len(takes), "takes": takes}
|
||||
Reference in New Issue
Block a user