From 196ab8c0e7518c7c1b2f07f2f8832476e118d774 Mon Sep 17 00:00:00 2001 From: MasterPhooey Date: Sat, 16 May 2026 00:32:05 -0500 Subject: [PATCH] Add VAD trimming and Docker publishing --- .github/workflows/docker-publish.yml | 48 ++ .gitignore | 1 + run.sh | 66 ++- static/index.html | 652 +++++++++++++++++++++++++-- trainer_server.py | 197 +++++++- 5 files changed, 914 insertions(+), 50 deletions(-) create mode 100644 .github/workflows/docker-publish.yml diff --git a/.github/workflows/docker-publish.yml b/.github/workflows/docker-publish.yml new file mode 100644 index 0000000..d448336 --- /dev/null +++ b/.github/workflows/docker-publish.yml @@ -0,0 +1,48 @@ +name: Publish Docker Image + +on: + push: + branches: + - main + workflow_dispatch: + +permissions: + contents: read + packages: write + +concurrency: + group: docker-publish-${{ github.ref }} + cancel-in-progress: true + +env: + REGISTRY: ghcr.io + IMAGE_NAME: tatertotterson/microwakeword-trainer-nvidia-docker + +jobs: + docker: + name: Docker image + runs-on: ubuntu-latest + steps: + - name: Check out repository + uses: actions/checkout@v4 + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Log in to GHCR + uses: docker/login-action@v3 + with: + registry: ${{ env.REGISTRY }} + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Build and push image + uses: docker/build-push-action@v6 + with: + context: . + file: dockerfile + platforms: linux/amd64 + push: true + tags: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:latest + cache-from: type=gha,scope=mww-trainer-nvidia-docker + cache-to: type=gha,mode=max,scope=mww-trainer-nvidia-docker diff --git a/.gitignore b/.gitignore index 73999fe..30180d5 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ personal_samples/* data/ +trim_history/ .DS_Store \ No newline at end of file diff --git a/run.sh b/run.sh index 367ab29..cf8437e 100644 --- a/run.sh +++ b/run.sh @@ -26,6 +26,16 @@ echo "-> URL: http://localhost:${PORT}/" mkdir -p "${DATA_DIR}" +install_ui_deps() { + ${PIP} install \ + "fastapi==${FASTAPI_VERSION}" \ + "uvicorn[standard]==${UVICORN_VERSION}" \ + "python-multipart==${PY_MULTIPART_VERSION}" \ + "esphome==${ESPHOME_VERSION}" \ + "silero-vad>=5.0.0" \ + "numpy>=1.24.0" +} + # ----------------------------- # Trainer UI venv (separate) # ----------------------------- @@ -40,32 +50,54 @@ source "${VENV_DIR}/bin/activate" if [[ ! -f "${PIN_FILE}" ]]; then echo "Installing pinned trainer UI deps" ${PIP} install -U pip setuptools wheel - ${PIP} install \ - "fastapi==${FASTAPI_VERSION}" \ - "uvicorn[standard]==${UVICORN_VERSION}" \ - "python-multipart==${PY_MULTIPART_VERSION}" \ - "esphome==${ESPHOME_VERSION}" + install_ui_deps touch "${PIN_FILE}" else echo "Reusing existing trainer UI venv (no upgrades)" - if ! "${PY}" - "${ESPHOME_VERSION}" <<'PY' >/dev/null 2>&1 -import importlib.metadata + if ! "${PY}" - "${FASTAPI_VERSION}" "${UVICORN_VERSION}" "${PY_MULTIPART_VERSION}" "${ESPHOME_VERSION}" <<'PY' >/dev/null 2>&1 +import importlib.metadata as md import sys -expected = sys.argv[1] -installed = importlib.metadata.version("esphome") -raise SystemExit(0 if installed == expected else 1) +fastapi_version, uvicorn_version, multipart_version, esphome_version = sys.argv[1:5] + +def version_tuple(value): + parts = [] + for token in str(value).replace("-", ".").split("."): + if token.isdigit(): + parts.append(int(token)) + else: + digits = "".join(ch for ch in token if ch.isdigit()) + if digits: + parts.append(int(digits)) + break + return tuple(parts) + +exact = { + "fastapi": fastapi_version, + "uvicorn": uvicorn_version, + "python-multipart": multipart_version, + "esphome": esphome_version, +} +minimum = { + "silero-vad": "5.0.0", + "numpy": "1.24.0", +} +present = ("torch", "zeroconf") + +for package, expected in exact.items(): + if md.version(package) != expected: + raise SystemExit(1) +for package, minimum_version in minimum.items(): + if version_tuple(md.version(package)) < version_tuple(minimum_version): + raise SystemExit(1) +for package in present: + md.version(package) PY then - echo "Firmware tab dependencies missing or stale; installing ESPHome firmware dependencies" - ${PIP} install \ - "fastapi==${FASTAPI_VERSION}" \ - "uvicorn[standard]==${UVICORN_VERSION}" \ - "python-multipart==${PY_MULTIPART_VERSION}" \ - "esphome==${ESPHOME_VERSION}" + echo "UI dependencies missing or stale; installing recorder dependencies" + install_ui_deps fi fi - # ----------------------------- # Trainer server env # ----------------------------- diff --git a/static/index.html b/static/index.html index 89bd076..3e0f61c 100644 --- a/static/index.html +++ b/static/index.html @@ -622,6 +622,67 @@ color: var(--orange2); } + .paginationControls { + display: flex; + align-items: center; + justify-content: center; + gap: 16px; + padding: 12px 0; + } + + .paginationControls .pageBtn { + padding: 6px 14px; + border-radius: 6px; + font-size: 13px; + background: rgba(255,255,255,0.06); + border: 1px solid rgba(255,255,255,0.12); + color: var(--text, #fff); + cursor: pointer; + } + + .paginationControls .pageBtn:disabled { + opacity: 0.3; + cursor: default; + } + + .paginationControls .pageInfo { + font-size: 13px; + color: var(--muted, #888); + } + + .paginationControls .pageJump { + font-size: 13px; + color: var(--muted, #888); + display: flex; + align-items: center; + gap: 4px; + } + + .paginationControls .pageInput { + width: 40px; + padding: 4px 6px; + font-size: 13px; + text-align: center; + background: rgba(255,255,255,0.06); + border: 1px solid rgba(255,255,255,0.12); + border-radius: 4px; + color: var(--text, #fff); + } + + .paginationControls .pageJumpBtn { + padding: 4px 10px; + font-size: 13px; + background: rgba(255,255,255,0.1); + border: 1px solid rgba(255,255,255,0.15); + border-radius: 4px; + color: var(--text, #fff); + cursor: pointer; + } + + .paginationControls .pageJumpBtn:hover { + background: rgba(255,255,255,0.18); + } + .tabs { display: flex; flex-wrap: wrap; @@ -960,6 +1021,78 @@ 100% { transform: translateY(0) scale(1) rotate(0deg); } } + .trimOverlay { + position: fixed; inset: 0; padding: 22px; + display: flex; align-items: center; justify-content: center; + background: rgba(4,5,10,0.6); backdrop-filter: blur(10px); + opacity: 0; visibility: hidden; pointer-events: none; + transition: opacity 0.18s ease, visibility 0.18s ease; + z-index: 11000; + } + .trimOverlay.open { opacity: 1; visibility: visible; pointer-events: auto; } + + .trimDialog { + width: min(960px, calc(100vw - 36px)); + max-height: min(90vh, 900px); + display: grid; grid-template-rows: auto 1fr auto auto auto; gap: 12px; + padding: 18px; border-radius: 22px; + border: 1px solid rgba(255,255,255,0.12); + background: linear-gradient(180deg, rgba(17,20,28,0.82), rgba(8,10,16,0.94)); + box-shadow: 0 28px 84px rgba(0,0,0,0.58); + backdrop-filter: blur(18px) saturate(1.12); + } + .trimHeader { display: flex; justify-content: space-between; align-items: flex-start; gap: 16px; } + .trimTitle { margin: 0; font-size: 18px; } + .trimHint { margin: 6px 0 0; font-size: 13px; color: var(--muted); } + + .trimCanvasWrap { + position: relative; width: 100%; min-height: 120px; + border-radius: 14px; border: 1px solid rgba(255,255,255,0.08); + background: rgba(0,0,0,0.4); overflow: hidden; + } + .trimCanvas { width: 100%; height: 100%; display: block; } + + .trimHandle { + position: absolute; top: 0; width: 48px; height: 100%; + cursor: ew-resize; pointer-events: auto; touch-action: none; + transform: translateX(-50%); + } + .trimHandle::after { + content: ''; position: absolute; top: 10%; bottom: 10%; left: 50%; + transform: translateX(-50%); width: 3px; border-radius: 2px; + background: var(--orange); box-shadow: 0 0 6px rgba(255,138,42,0.5); + } + .trimHandle::before { + content: ''; position: absolute; top: 50%; left: 50%; + transform: translate(-50%,-50%); width: 10px; height: 24px; + border-radius: 5px; border: 1px solid rgba(255,138,42,0.5); + background: rgba(255,138,42,0.15); + } + + .trimTimeInfo { + display: flex; align-items: center; justify-content: center; + gap: 12px; font-size: 14px; + font-family: ui-monospace, SFMono-Regular, Menlo, monospace; + } + .trimSeparator { color: var(--muted); } + .trimVadInfo { display: flex; align-items: center; gap: 8px; font-size: 12px; } + .trimActions { display: flex; gap: 8px; flex-wrap: wrap; } + .trimActions button { flex: 1; min-width: 120px; } + + .pill.trimBadge { + color: #89d4ff; + border-color: rgba(137,212,255,0.25); + background: rgba(137,212,255,0.08); + } + .trimBtn { + border-color: rgba(255,138,42,0.3); + background: rgba(255,138,42,0.1); + } + .trimBtn:hover { + border-color: rgba(255,138,42,0.5); + background: rgba(255,138,42,0.18); + } + @media (max-width: 720px) { .wrap { padding: 18px 14px 30px; } input[type="text"] { width: 100%; } @@ -1033,6 +1166,15 @@ .trainFooter button { width: 100%; } + .trimOverlay { padding: 8px; } + .trimDialog { + width: 100%; height: 96vh; + padding: 14px; grid-template-rows: auto 1fr auto auto auto; + } + .trimCanvasWrap { min-height: 100px; } + .trimHeader { flex-direction: column; align-items: stretch; } + .trimActions { flex-direction: column; } + .trimActions button { width: 100%; min-width: unset; } } @@ -1218,7 +1360,7 @@
-
+
1
@@ -1241,6 +1383,7 @@
No samples saved yet.
+
@@ -1414,6 +1557,40 @@ + + diff --git a/trainer_server.py b/trainer_server.py index c82b0eb..ac5f65f 100644 --- a/trainer_server.py +++ b/trainer_server.py @@ -37,10 +37,11 @@ STATIC_DIR = Path(os.environ.get("STATIC_DIR", str(ROOT_DIR / "static"))).resolv PERSONAL_DIR = Path(os.environ.get("PERSONAL_DIR", str(DATA_DIR / "personal_samples"))).resolve() CAPTURED_DIR = Path(os.environ.get("CAPTURED_DIR", str(DATA_DIR / "captured_audio"))).resolve() NEGATIVE_DIR = Path(os.environ.get("NEGATIVE_DIR", str(DATA_DIR / "negative_samples"))).resolve() +TRIM_HISTORY_DIR = Path(os.environ.get("TRIM_HISTORY_DIR", str(DATA_DIR / "trim_history"))).resolve() +TRIM_HISTORY_DIR.mkdir(parents=True, exist_ok=True) TRAINED_WAKE_WORDS_DIR = Path( os.environ.get("TRAINED_WAKE_WORDS_DIR", str(DATA_DIR / "trained_wake_words")) ).resolve() - CLI_DIR = Path(os.environ.get("CLI_DIR", str(ROOT_DIR / "cli"))).resolve() PIPER_ROOT = DATA_DIR / "tools" / "piper-sample-generator" PIPER_VOICES_DIR = PIPER_ROOT / "voices" @@ -169,11 +170,58 @@ FIRMWARE_LOCK = threading.Lock() FIRMWARE_SESSIONS: Dict[str, Dict[str, Any]] = {} ANSI_ESCAPE_RE = re.compile(r"\x1B(?:\[[0-?]*[ -/]*[@-~]|[@-Z\\-_])") +# --- Silero VAD (lazy-loaded) --- +_silero_vad_model = None +_silero_vad_utils = None +_SILERO_VAD_LOCK = threading.Lock() +VAD_SELECTION_PAD_START_S = 0.08 +VAD_SELECTION_PAD_END_S = 0.08 + + +def _load_silero_vad(): + """Lazy-load Silero VAD model on first use. Returns (model, utils).""" + global _silero_vad_model, _silero_vad_utils + if _silero_vad_model is not None: + return _silero_vad_model, _silero_vad_utils + with _SILERO_VAD_LOCK: + if _silero_vad_model is not None: + return _silero_vad_model, _silero_vad_utils + import torch + import silero_vad + model = silero_vad.load_silero_vad() + model.eval() + _silero_vad_model = model + _silero_vad_utils = {"torch": torch} + return model, _silero_vad_utils + + +def _detect_speech_segments(wav_bytes: bytes) -> List[Dict[str, float]]: + """Run Silero VAD on 16 kHz mono WAV bytes. Return {start, end} seconds.""" + model, utils = _load_silero_vad() + torch = utils["torch"] + import numpy as np + from silero_vad.utils_vad import get_speech_timestamps + + with wave.open(io.BytesIO(wav_bytes), "rb") as wf: + raw = wf.readframes(wf.getnframes()) + samples = np.frombuffer(raw, dtype=np.int16).astype(np.float32) / 32768.0 + audio_tensor = torch.from_numpy(samples) + + timestamps = get_speech_timestamps( + audio_tensor, + model, + sampling_rate=16000, + threshold=0.5, + min_speech_duration_ms=150, + min_silence_duration_ms=100, + return_seconds=True, + ) + return [{"start": round(ts["start"], 3), "end": round(ts["end"], 3)} for ts in timestamps] + class _FirmwareYamlLoader(yaml.SafeLoader): pass - class _FirmwareYamlDumper(yaml.SafeDumper): pass @@ -1009,7 +1057,7 @@ def _list_captured_items() -> List[Dict[str, Any]]: def _sample_item_from_path(audio_path: Path, bucket: str) -> Dict[str, Any]: meta = _load_sidecar_json(audio_path) stat = audio_path.stat() - final_format = meta.get("final_format") or _inspect_wav_bytes(audio_path.read_bytes()) or {} + final_format = meta.get("final_format") or meta.get("detected_format") or _inspect_wav_bytes(audio_path.read_bytes()) or {} return { "bucket": bucket, "saved_as": audio_path.name, @@ -1021,6 +1069,8 @@ def _sample_item_from_path(audio_path: Path, bucket: str) -> Dict[str, Any]: "reviewed_at": meta.get("reviewed_at") or "", "created_at": datetime.fromtimestamp(stat.st_mtime, tz=timezone.utc).isoformat(), "converted": bool(meta.get("converted")), + "trimmed": bool(meta.get("trimmed")), + "source_file": meta.get("source_file") or "", "final_format": final_format, "message": meta.get("message") or "", "size_bytes": stat.st_size, @@ -1036,9 +1086,10 @@ def _list_sample_items(directory: Path, bucket: str) -> List[Dict[str, Any]]: items.append(_sample_item_from_path(audio_path, bucket)) except Exception: continue + # Untrimmed first (stable sort preserves mtime order within each group). + items.sort(key=lambda x: x.get("trimmed", False)) return items - def _samples_payload() -> Dict[str, Any]: takes = _sync_personal_samples_state() personal_items = _list_sample_items(PERSONAL_DIR, "personal") @@ -2768,7 +2819,143 @@ def delete_sample(bucket: str, file_name: str): _remove_audio_with_sidecar(path) except FileNotFoundError as e: return JSONResponse({"ok": False, "error": str(e)}, status_code=404) - return _samples_payload() + return {"ok": True, "deleted_bucket": bucket, "deleted_file": file_name, "message": f"Deleted {file_name}"} + + +@app.post("/api/samples/{bucket}/{file_name}/vad") +def vad_segments(bucket: str, file_name: str): + bucket_map = {"personal": PERSONAL_DIR, "negative": NEGATIVE_DIR} + directory = bucket_map.get(bucket) + if directory is None: + return JSONResponse({"ok": False, "error": "Unknown sample bucket."}, status_code=404) + try: + path = _resolve_audio_path(directory, file_name) + except FileNotFoundError as e: + return JSONResponse({"ok": False, "error": str(e)}, status_code=404) + + wav_bytes = path.read_bytes() + try: + all_segments = _detect_speech_segments(wav_bytes) + except Exception as e: + return JSONResponse({"ok": False, "error": f"VAD failed: {str(e)}"}, status_code=500) + + # Only return the first segment longer than 250 ms. Add deterministic + # padding so VAD guides trimming without clipping quiet wake-word edges. + filtered = [s for s in all_segments if (s["end"] - s["start"]) >= 0.25] + if not filtered: + return {"ok": True, "file_name": file_name, "segments": [], "segment_count": 0} + seg = filtered[0] + info = _inspect_wav_bytes(wav_bytes) or {} + duration_s = float(info.get("duration_s") or 0.0) + start = max(0.0, round(seg["start"] - VAD_SELECTION_PAD_START_S, 3)) + end = round(seg["end"] + VAD_SELECTION_PAD_END_S, 3) + if duration_s > 0: + end = min(duration_s, end) + if end <= start: + end = start + 0.001 + segment = {"start": start, "end": end} + return {"ok": True, "file_name": file_name, "segments": [segment], "segment_count": 1} + + +@app.post("/api/samples/trim") +async def trim_sample_upload( + file: UploadFile = File(...), + bucket: str = Form(...), + source_file: str = Form(...), + start_time: str | None = Form(None), + end_time: str | None = Form(None), +): + bucket_map = {"personal": PERSONAL_DIR, "negative": NEGATIVE_DIR} + directory = bucket_map.get(bucket) + if directory is None: + return JSONResponse({"ok": False, "error": "Unknown sample bucket."}, status_code=404) + + data = await file.read() + if not data: + return JSONResponse({"ok": False, "error": "Empty audio file."}, status_code=400) + + info = _inspect_wav_bytes(data) + if not info: + try: + data = _normalize_audio_to_target_wav(data, file.filename or "trimmed.wav") + except Exception as e: + return JSONResponse({"ok": False, "error": f"Audio normalization failed: {e}"}, status_code=400) + elif not _is_target_wav(info): + try: + data = _normalize_audio_to_target_wav(data, file.filename or "trimmed.wav") + except Exception as e: + return JSONResponse({"ok": False, "error": f"Audio normalization failed: {e}"}, status_code=400) + + try: + orig_path = _resolve_audio_path(directory, source_file) + except FileNotFoundError as e: + return JSONResponse({"ok": False, "error": str(e)}, status_code=404) + + TRIM_HISTORY_DIR.mkdir(parents=True, exist_ok=True) + ts = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%S%f") + backup_name = f"{ts}_{source_file}" + backup_path = TRIM_HISTORY_DIR / backup_name + shutil.copy2(orig_path, backup_path) + + orig_sidecar = _audio_sidecar_path(orig_path) + if orig_sidecar.exists(): + shutil.copy2(orig_sidecar, _audio_sidecar_path(backup_path)) + + orig_path.write_bytes(data) + + old_sidecar = _load_sidecar_json(orig_path) + sidecar = { + **old_sidecar, + "trimmed": True, + "source_file": source_file, + "source_bucket": bucket, + "trim_start_s": float(start_time) if start_time else None, + "trim_end_s": float(end_time) if end_time else None, + "undo_backup_file": backup_name, + } + _write_sidecar_json(orig_path, sidecar) + + updated_item = _sample_item_from_path(orig_path, bucket) + updated_item["trimmed"] = True + updated_item["source_file"] = source_file + return {"ok": True, "updated_sample": updated_item, "message": f"Trimmed {source_file}"} + + +@app.post("/api/samples/revert") +def revert_trim( + bucket: str = Form(...), + file_name: str = Form(...), +): + bucket_map = {"personal": PERSONAL_DIR, "negative": NEGATIVE_DIR} + directory = bucket_map.get(bucket) + if directory is None: + return JSONResponse({"ok": False, "error": "Unknown sample bucket."}, status_code=404) + + try: + file_path = _resolve_audio_path(directory, file_name) + except FileNotFoundError as e: + return JSONResponse({"ok": False, "error": str(e)}, status_code=404) + + sidecar = _load_sidecar_json(file_path) + backup_name = sidecar.get("undo_backup_file") + if not backup_name: + return JSONResponse({"ok": False, "error": "No trim backup found for this sample."}, status_code=400) + + backup_path = TRIM_HISTORY_DIR / backup_name + if not backup_path.exists(): + return JSONResponse({"ok": False, "error": "Trim backup file missing."}, status_code=404) + + shutil.copy2(backup_path, file_path) + backup_sidecar = _audio_sidecar_path(backup_path) + if backup_sidecar.exists(): + shutil.copy2(backup_sidecar, _audio_sidecar_path(file_path)) + + backup_path.unlink() + if backup_sidecar.exists(): + backup_sidecar.unlink() + + updated_item = _sample_item_from_path(file_path, bucket) + return {"ok": True, "updated_sample": updated_item, "message": f"Reverted {file_name}"} @app.post("/api/captured_audio/{file_name}/approve_personal")