18 Commits
v4 ... main

Author SHA1 Message Date
MasterPhooey
874f273d0b Bump ESPHome pin to 2026.5.1 2026-06-03 10:21:23 -05:00
MasterPhooey
04249f414d Add new ReSpeaker firmware flasher templates 2026-05-19 15:49:57 -05:00
MasterPhooey
6a0d60d569 Add live wake word URL card 2026-05-19 07:42:20 -05:00
MasterPhooey
8df17599c2 Update Tater repo logo 2026-05-16 09:44:35 -05:00
MasterPhooey
280e8f8de4 Update README logo 2026-05-16 07:59:22 -05:00
Tater Totterson
b582a6cade Update Docker image name in workflow 2026-05-16 01:03:11 -05:00
MasterPhooey
196ab8c0e7 Add VAD trimming and Docker publishing 2026-05-16 00:32:05 -05:00
MasterPhooey
134f607bef 2026.4.3 2026-05-03 09:31:02 -05:00
MasterPhooey
4a9e2f2cde 2026.4.3 2026-05-03 07:55:07 -05:00
MasterPhooey
7c246856df cache update 2026-05-02 09:27:11 -05:00
MasterPhooey
3705dabc09 sat1 cache fix 2026-05-01 21:34:17 -05:00
MasterPhooey
1dcf48209f wake sound 2026-05-01 18:31:13 -05:00
MasterPhooey
4f44bef8d5 build cache 2026-05-01 18:03:37 -05:00
MasterPhooey
98fa879db1 wake sound 2026-05-01 17:01:15 -05:00
MasterPhooey
dfac549430 wake sound 2026-05-01 16:49:57 -05:00
MasterPhooey
775a78326b firmware url fixes 2026-05-01 16:24:36 -05:00
MasterPhooey
429be4cc67 502 2026-04-25 12:48:06 -05:00
Tater Totterson
2e6179ec32 Enhance README with images and link
Added additional images and a link to the README for better presentation.
2026-04-25 10:05:57 -05:00
8 changed files with 1759 additions and 113 deletions

48
.github/workflows/docker-publish.yml vendored Normal file
View File

@@ -0,0 +1,48 @@
# Builds the trainer image and pushes it to GHCR on every push to main,
# or when started manually from the Actions tab.
name: Publish Docker Image

on:
  push:
    branches:
      - main
  workflow_dispatch:

permissions:
  contents: read
  # Needed so GITHUB_TOKEN can push to the package registry.
  packages: write

concurrency:
  # A newer run for the same ref cancels the one still in progress.
  group: docker-publish-${{ github.ref }}
  cancel-in-progress: true

env:
  REGISTRY: ghcr.io
  IMAGE_NAME: tatertotterson/microwakeword

jobs:
  docker:
    name: Docker image
    runs-on: ubuntu-latest
    steps:
      - name: Check out repository
        uses: actions/checkout@v4

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3

      - name: Log in to GHCR
        uses: docker/login-action@v3
        with:
          registry: ${{ env.REGISTRY }}
          username: ${{ github.actor }}
          password: ${{ secrets.GITHUB_TOKEN }}

      - name: Build and push image
        uses: docker/build-push-action@v6
        with:
          context: .
          file: dockerfile
          platforms: linux/amd64
          push: true
          tags: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:latest
          # Layer cache is kept in the GitHub Actions cache under one shared scope.
          cache-from: type=gha,scope=mww-trainer-nvidia-docker
          cache-to: type=gha,mode=max,scope=mww-trainer-nvidia-docker

1
.gitignore vendored
View File

@@ -1,3 +1,4 @@
personal_samples/* personal_samples/*
data/ data/
trim_history/
.DS_Store .DS_Store

View File

@@ -1,7 +1,11 @@
<div align="center"> <div align="center">
<h1>microWakeWord NVIDIA Docker Trainer UI</h1> <a href="https://taterassistant.com">
<img width="800" alt="microWakeWord NVIDIA trainer screenshot" src="https://github.com/user-attachments/assets/694f4cb7-e4d8-4e2b-80ec-b40fb41cbfff" /> <img src="images/tater-repo-logo.png" alt="microWakeWord Trainer" width="460"/>
</a>
</div> </div>
<h3 align="center">
<a href="https://taterassistant.com">taterassistant.com</a>
</h3>
Train custom microWakeWord models in Docker with NVIDIA/CUDA acceleration, generated Piper samples, device-captured samples, reviewed false-wake negatives, live training logs, and ESPHome firmware flashing. Train custom microWakeWord models in Docker with NVIDIA/CUDA acceleration, generated Piper samples, device-captured samples, reviewed false-wake negatives, live training logs, and ESPHome firmware flashing.

View File

@@ -6,7 +6,7 @@ ENV DEBIAN_FRONTEND=noninteractive
# System deps # System deps
RUN apt-get update && apt-get install -y --no-install-recommends \ RUN apt-get update && apt-get install -y --no-install-recommends \
python3.12 python3.12-venv python3.12-dev python3-pip python-is-python3 \ python3.12 python3.12-venv python3.12-dev python3-pip python-is-python3 \
git wget curl unzip patch ca-certificates nano less \ git wget curl unzip patch ninja-build ca-certificates nano less \
&& rm -rf /var/lib/apt/lists/* \ && rm -rf /var/lib/apt/lists/* \
&& mkdir -p /data && mkdir -p /data

BIN
images/tater-repo-logo.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 590 KiB

68
run.sh
View File

@@ -17,7 +17,7 @@ PIN_FILE="${VENV_DIR}/.pinned_installed"
FASTAPI_VERSION="${REC_FASTAPI_VERSION:-0.115.6}" FASTAPI_VERSION="${REC_FASTAPI_VERSION:-0.115.6}"
UVICORN_VERSION="${REC_UVICORN_VERSION:-0.30.6}" UVICORN_VERSION="${REC_UVICORN_VERSION:-0.30.6}"
PY_MULTIPART_VERSION="${REC_PY_MULTIPART_VERSION:-0.0.9}" PY_MULTIPART_VERSION="${REC_PY_MULTIPART_VERSION:-0.0.9}"
ESPHOME_VERSION="${REC_ESPHOME_VERSION:-2026.4.0}" ESPHOME_VERSION="${REC_ESPHOME_VERSION:-2026.5.1}"
echo "microWakeWord Trainer UI (Docker)" echo "microWakeWord Trainer UI (Docker)"
echo "-> ROOTDIR: ${ROOTDIR}" echo "-> ROOTDIR: ${ROOTDIR}"
@@ -26,6 +26,16 @@ echo "-> URL: http://localhost:${PORT}/"
mkdir -p "${DATA_DIR}" mkdir -p "${DATA_DIR}"
# Install the trainer UI's Python packages with ${PIP}.
# fastapi, uvicorn[standard], python-multipart and esphome are pinned to the
# exact *_VERSION variables set above; silero-vad and numpy only carry a
# minimum version.
install_ui_deps() {
${PIP} install \
"fastapi==${FASTAPI_VERSION}" \
"uvicorn[standard]==${UVICORN_VERSION}" \
"python-multipart==${PY_MULTIPART_VERSION}" \
"esphome==${ESPHOME_VERSION}" \
"silero-vad>=5.0.0" \
"numpy>=1.24.0"
}
# ----------------------------- # -----------------------------
# Trainer UI venv (separate) # Trainer UI venv (separate)
# ----------------------------- # -----------------------------
@@ -40,32 +50,54 @@ source "${VENV_DIR}/bin/activate"
if [[ ! -f "${PIN_FILE}" ]]; then if [[ ! -f "${PIN_FILE}" ]]; then
echo "Installing pinned trainer UI deps" echo "Installing pinned trainer UI deps"
${PIP} install -U pip setuptools wheel ${PIP} install -U pip setuptools wheel
${PIP} install \ install_ui_deps
"fastapi==${FASTAPI_VERSION}" \
"uvicorn[standard]==${UVICORN_VERSION}" \
"python-multipart==${PY_MULTIPART_VERSION}" \
"esphome==${ESPHOME_VERSION}"
touch "${PIN_FILE}" touch "${PIN_FILE}"
else else
echo "Reusing existing trainer UI venv (no upgrades)" echo "Reusing existing trainer UI venv (no upgrades)"
if ! "${PY}" - "${ESPHOME_VERSION}" <<'PY' >/dev/null 2>&1 if ! "${PY}" - "${FASTAPI_VERSION}" "${UVICORN_VERSION}" "${PY_MULTIPART_VERSION}" "${ESPHOME_VERSION}" <<'PY' >/dev/null 2>&1
import importlib.metadata import importlib.metadata as md
import sys import sys
expected = sys.argv[1] fastapi_version, uvicorn_version, multipart_version, esphome_version = sys.argv[1:5]
installed = importlib.metadata.version("esphome")
raise SystemExit(0 if installed == expected else 1) def version_tuple(value):
parts = []
for token in str(value).replace("-", ".").split("."):
if token.isdigit():
parts.append(int(token))
else:
digits = "".join(ch for ch in token if ch.isdigit())
if digits:
parts.append(int(digits))
break
return tuple(parts)
exact = {
"fastapi": fastapi_version,
"uvicorn": uvicorn_version,
"python-multipart": multipart_version,
"esphome": esphome_version,
}
minimum = {
"silero-vad": "5.0.0",
"numpy": "1.24.0",
}
present = ("torch", "zeroconf")
for package, expected in exact.items():
if md.version(package) != expected:
raise SystemExit(1)
for package, minimum_version in minimum.items():
if version_tuple(md.version(package)) < version_tuple(minimum_version):
raise SystemExit(1)
for package in present:
md.version(package)
PY PY
then then
echo "Firmware tab dependencies missing or stale; installing ESPHome firmware dependencies" echo "UI dependencies missing or stale; installing recorder dependencies"
${PIP} install \ install_ui_deps
"fastapi==${FASTAPI_VERSION}" \
"uvicorn[standard]==${UVICORN_VERSION}" \
"python-multipart==${PY_MULTIPART_VERSION}" \
"esphome==${ESPHOME_VERSION}"
fi fi
fi fi
# ----------------------------- # -----------------------------
# Trainer server env # Trainer server env
# ----------------------------- # -----------------------------

File diff suppressed because it is too large Load Diff

View File

@@ -37,10 +37,11 @@ STATIC_DIR = Path(os.environ.get("STATIC_DIR", str(ROOT_DIR / "static"))).resolv
PERSONAL_DIR = Path(os.environ.get("PERSONAL_DIR", str(DATA_DIR / "personal_samples"))).resolve() PERSONAL_DIR = Path(os.environ.get("PERSONAL_DIR", str(DATA_DIR / "personal_samples"))).resolve()
CAPTURED_DIR = Path(os.environ.get("CAPTURED_DIR", str(DATA_DIR / "captured_audio"))).resolve() CAPTURED_DIR = Path(os.environ.get("CAPTURED_DIR", str(DATA_DIR / "captured_audio"))).resolve()
NEGATIVE_DIR = Path(os.environ.get("NEGATIVE_DIR", str(DATA_DIR / "negative_samples"))).resolve() NEGATIVE_DIR = Path(os.environ.get("NEGATIVE_DIR", str(DATA_DIR / "negative_samples"))).resolve()
TRIM_HISTORY_DIR = Path(os.environ.get("TRIM_HISTORY_DIR", str(DATA_DIR / "trim_history"))).resolve()
TRIM_HISTORY_DIR.mkdir(parents=True, exist_ok=True)
TRAINED_WAKE_WORDS_DIR = Path( TRAINED_WAKE_WORDS_DIR = Path(
os.environ.get("TRAINED_WAKE_WORDS_DIR", str(DATA_DIR / "trained_wake_words")) os.environ.get("TRAINED_WAKE_WORDS_DIR", str(DATA_DIR / "trained_wake_words"))
).resolve() ).resolve()
CLI_DIR = Path(os.environ.get("CLI_DIR", str(ROOT_DIR / "cli"))).resolve() CLI_DIR = Path(os.environ.get("CLI_DIR", str(ROOT_DIR / "cli"))).resolve()
PIPER_ROOT = DATA_DIR / "tools" / "piper-sample-generator" PIPER_ROOT = DATA_DIR / "tools" / "piper-sample-generator"
PIPER_VOICES_DIR = PIPER_ROOT / "voices" PIPER_VOICES_DIR = PIPER_ROOT / "voices"
@@ -85,12 +86,17 @@ FIRMWARE_MAX_LOG_LINES = int(os.environ.get("FIRMWARE_MAX_LOG_LINES", "500"))
FIRMWARE_GITHUB_OWNER = os.environ.get("FIRMWARE_GITHUB_OWNER", "TaterTotterson") FIRMWARE_GITHUB_OWNER = os.environ.get("FIRMWARE_GITHUB_OWNER", "TaterTotterson")
FIRMWARE_GITHUB_REPO = os.environ.get("FIRMWARE_GITHUB_REPO", "microWakeWords") FIRMWARE_GITHUB_REPO = os.environ.get("FIRMWARE_GITHUB_REPO", "microWakeWords")
FIRMWARE_GITHUB_REF = os.environ.get("FIRMWARE_GITHUB_REF", "main") FIRMWARE_GITHUB_REF = os.environ.get("FIRMWARE_GITHUB_REF", "main")
WAKE_SOUND_CATALOG_CACHE_TTL_SECONDS = int(os.environ.get("WAKE_SOUND_CATALOG_CACHE_TTL_SECONDS", "600"))
FIRMWARE_PLATFORMIO_DIR = FIRMWARE_CACHE_DIR / "platformio" FIRMWARE_PLATFORMIO_DIR = FIRMWARE_CACHE_DIR / "platformio"
FIRMWARE_HOME_DIR = FIRMWARE_CACHE_DIR / "home" FIRMWARE_HOME_DIR = FIRMWARE_CACHE_DIR / "home"
FIRMWARE_XDG_CACHE_DIR = FIRMWARE_CACHE_DIR / "cache" FIRMWARE_XDG_CACHE_DIR = FIRMWARE_CACHE_DIR / "cache"
FIRMWARE_ESPHOME_DATA_DIR = FIRMWARE_CACHE_DIR / "esphome_data"
FIRMWARE_PROFILE_FILE = Path( FIRMWARE_PROFILE_FILE = Path(
os.environ.get("FIRMWARE_PROFILE_FILE", str(FIRMWARE_CACHE_DIR / "profiles.json")) os.environ.get("FIRMWARE_PROFILE_FILE", str(FIRMWARE_CACHE_DIR / "profiles.json"))
).resolve() ).resolve()
WAKE_SOUND_MANIFEST_PATHS = ("wake_sound_manifest.json", "wake-sound-manifest.json")
WAKE_SOUND_CATALOG_CACHE: Dict[str, Any] = {"ts": 0.0, "payload": {}}
WAKE_SOUND_CATALOG_LOCK = threading.Lock()
TRAIN_LOG_TAIL_LINES = int(os.environ.get("REC_TRAIN_LOG_TAIL_LINES", "400")) TRAIN_LOG_TAIL_LINES = int(os.environ.get("REC_TRAIN_LOG_TAIL_LINES", "400"))
TRAIN_LOG_MAX_BYTES = int(os.environ.get("REC_TRAIN_LOG_MAX_BYTES", str(512 * 1024))) TRAIN_LOG_MAX_BYTES = int(os.environ.get("REC_TRAIN_LOG_MAX_BYTES", str(512 * 1024)))
@@ -113,6 +119,33 @@ FIRMWARE_TEMPLATE_SPECS = (
"fixed_keys": {"node_name"}, "fixed_keys": {"node_name"},
"auto_keys": {"ha_voice_ip"}, "auto_keys": {"ha_voice_ip"},
}, },
{
"key": "respeaker_lite",
"label": "ReSpeaker Lite (respeakerLite-TaterTimer.yaml)",
"path": "respeakerLite-TaterTimer.yaml",
"identity_key": "device_name",
"friendly_key": "friendly_name",
"fixed_keys": {"device_name"},
"auto_keys": {"ha_voice_ip"},
},
{
"key": "koala",
"label": "Koala Satellite (koala-TaterTimer.yaml)",
"path": "koala-TaterTimer.yaml",
"identity_key": "device_name",
"friendly_key": "friendly_name",
"fixed_keys": {"device_name"},
"auto_keys": {"ha_voice_ip"},
},
{
"key": "respeaker_xvf3800",
"label": "ReSpeaker XVF3800 (respeakerXVF3800-TaterTimer.yaml)",
"path": "respeakerXVF3800-TaterTimer.yaml",
"identity_key": "device_name",
"friendly_key": "friendly_name",
"fixed_keys": {"device_name"},
"auto_keys": {"ha_voice_ip"},
},
) )
app = FastAPI(title="microWakeWord Personal Samples") app = FastAPI(title="microWakeWord Personal Samples")
@@ -164,11 +197,58 @@ FIRMWARE_LOCK = threading.Lock()
FIRMWARE_SESSIONS: Dict[str, Dict[str, Any]] = {} FIRMWARE_SESSIONS: Dict[str, Dict[str, Any]] = {}
ANSI_ESCAPE_RE = re.compile(r"\x1B(?:\[[0-?]*[ -/]*[@-~]|[@-Z\\-_])") ANSI_ESCAPE_RE = re.compile(r"\x1B(?:\[[0-?]*[ -/]*[@-~]|[@-Z\\-_])")
# --- Silero VAD (lazy-loaded) ---
# Process-wide cache for the VAD model and its helper dict; both stay None
# until _load_silero_vad() populates them.
_silero_vad_model = None
_silero_vad_utils = None
# Held by _load_silero_vad() while it performs the first load.
_SILERO_VAD_LOCK = threading.Lock()
# Padding values in seconds (per the _S suffix). NOTE(review): presumably
# applied before/after a VAD-selected span; the usage is outside this view.
VAD_SELECTION_PAD_START_S = 0.08
VAD_SELECTION_PAD_END_S = 0.08
def _load_silero_vad():
    """Lazy-load the Silero VAD model on first use.

    Returns:
        A ``(model, utils)`` tuple. ``utils`` is a dict whose ``"torch"``
        entry is the imported ``torch`` module.

    The first caller imports ``torch`` and ``silero_vad`` and loads the model
    while holding ``_SILERO_VAD_LOCK``; later callers return the cached pair
    without taking the lock.
    """
    global _silero_vad_model, _silero_vad_utils
    model = _silero_vad_model
    if model is not None:
        return model, _silero_vad_utils
    with _SILERO_VAD_LOCK:
        # Re-check under the lock: another thread may have finished loading
        # while this one was waiting.
        if _silero_vad_model is not None:
            return _silero_vad_model, _silero_vad_utils
        import torch
        import silero_vad

        model = silero_vad.load_silero_vad()
        model.eval()
        utils = {"torch": torch}
        # Publish utils BEFORE the model. The lock-free fast path above keys
        # on the model global, so it must never observe a model without utils.
        _silero_vad_utils = utils
        _silero_vad_model = model
        return model, utils
def _detect_speech_segments(wav_bytes: bytes) -> List[Dict[str, float]]:
    """Run Silero VAD on 16 kHz mono WAV bytes. Return {start, end} seconds.

    Args:
        wav_bytes: A complete WAV file held in memory.

    Returns:
        One ``{"start": ..., "end": ...}`` dict per detected speech span, with
        both values rounded to 3 decimals. Empty audio yields ``[]``.

    Raises:
        wave.Error: If ``wav_bytes`` is not a readable WAV file.
    """
    with wave.open(io.BytesIO(wav_bytes), "rb") as wf:
        raw = wf.readframes(wf.getnframes())
    if not raw:
        # Nothing to analyse; skip loading the model entirely.
        return []

    model, utils = _load_silero_vad()
    torch = utils["torch"]
    import numpy as np
    from silero_vad.utils_vad import get_speech_timestamps

    # NOTE(review): the frames are interpreted as 16-bit PCM at 16 kHz without
    # checking the WAV header; callers are expected to pass that format.
    samples = np.frombuffer(raw, dtype=np.int16).astype(np.float32) / 32768.0
    audio_tensor = torch.from_numpy(samples)
    timestamps = get_speech_timestamps(
        audio_tensor,
        model,
        sampling_rate=16000,
        threshold=0.5,
        min_speech_duration_ms=150,
        min_silence_duration_ms=100,
        return_seconds=True,
    )
    return [{"start": round(ts["start"], 3), "end": round(ts["end"], 3)} for ts in timestamps]
class _FirmwareYamlLoader(yaml.SafeLoader): class _FirmwareYamlLoader(yaml.SafeLoader):
pass pass
class _FirmwareYamlDumper(yaml.SafeDumper): class _FirmwareYamlDumper(yaml.SafeDumper):
pass pass
@@ -1004,7 +1084,7 @@ def _list_captured_items() -> List[Dict[str, Any]]:
def _sample_item_from_path(audio_path: Path, bucket: str) -> Dict[str, Any]: def _sample_item_from_path(audio_path: Path, bucket: str) -> Dict[str, Any]:
meta = _load_sidecar_json(audio_path) meta = _load_sidecar_json(audio_path)
stat = audio_path.stat() stat = audio_path.stat()
final_format = meta.get("final_format") or _inspect_wav_bytes(audio_path.read_bytes()) or {} final_format = meta.get("final_format") or meta.get("detected_format") or _inspect_wav_bytes(audio_path.read_bytes()) or {}
return { return {
"bucket": bucket, "bucket": bucket,
"saved_as": audio_path.name, "saved_as": audio_path.name,
@@ -1016,6 +1096,8 @@ def _sample_item_from_path(audio_path: Path, bucket: str) -> Dict[str, Any]:
"reviewed_at": meta.get("reviewed_at") or "", "reviewed_at": meta.get("reviewed_at") or "",
"created_at": datetime.fromtimestamp(stat.st_mtime, tz=timezone.utc).isoformat(), "created_at": datetime.fromtimestamp(stat.st_mtime, tz=timezone.utc).isoformat(),
"converted": bool(meta.get("converted")), "converted": bool(meta.get("converted")),
"trimmed": bool(meta.get("trimmed")),
"source_file": meta.get("source_file") or "",
"final_format": final_format, "final_format": final_format,
"message": meta.get("message") or "", "message": meta.get("message") or "",
"size_bytes": stat.st_size, "size_bytes": stat.st_size,
@@ -1031,9 +1113,10 @@ def _list_sample_items(directory: Path, bucket: str) -> List[Dict[str, Any]]:
items.append(_sample_item_from_path(audio_path, bucket)) items.append(_sample_item_from_path(audio_path, bucket))
except Exception: except Exception:
continue continue
# Untrimmed first (stable sort preserves mtime order within each group).
items.sort(key=lambda x: x.get("trimmed", False))
return items return items
def _samples_payload() -> Dict[str, Any]: def _samples_payload() -> Dict[str, Any]:
takes = _sync_personal_samples_state() takes = _sync_personal_samples_state()
personal_items = _list_sample_items(PERSONAL_DIR, "personal") personal_items = _list_sample_items(PERSONAL_DIR, "personal")
@@ -1449,6 +1532,93 @@ def _load_firmware_template_text(spec: Dict[str, Any]) -> tuple[str, str]:
raise RuntimeError(f"Could not download firmware template from {url}: {exc}") from exc raise RuntimeError(f"Could not download firmware template from {url}: {exc}") from exc
def _wake_sound_label_from_slug(slug: str) -> str:
    """Turn a file-style slug into a display label.

    Runs of ``_``, ``-`` and ``.`` become single spaces and the result is
    title-cased. An empty slug, or one made only of separators, yields
    ``"Wake Sound"``.
    """
    text = str(slug or "").strip()
    if not text:
        return "Wake Sound"
    return re.sub(r"[_\-.]+", " ", text).strip().title() or "Wake Sound"
def _wake_sound_entries_from_manifest(payload: Any) -> List[Dict[str, str]]:
    """Normalize a wake-sound manifest into picker entries.

    Args:
        payload: Parsed manifest JSON. Either a list of row dicts, or a dict
            holding such a list under ``entries``, ``wake_sounds``, ``sounds``,
            ``audio`` or ``items`` (the first key whose value is a list wins).

    Returns:
        ``{"value": url, "label": text}`` dicts, de-duplicated by URL and
        sorted by case-insensitive label, then URL. Non-dict rows and rows
        with neither a URL nor a ``path`` are dropped.
    """
    rows: List[Any] = []
    if isinstance(payload, list):
        rows = list(payload)
    elif isinstance(payload, dict):
        for key in ("entries", "wake_sounds", "sounds", "audio", "items"):
            candidate = payload.get(key)
            if isinstance(candidate, list):
                rows = list(candidate)
                break

    entries: List[Dict[str, str]] = []
    seen = set()
    for row in rows:
        if not isinstance(row, dict):
            continue
        url = str(
            row.get("url")
            or row.get("download_url")
            or row.get("audio_url")
            or row.get("sound_url")
            or row.get("wake_sound_url")
            or row.get("wake_word_triggered_sound_file")
            or ""
        ).strip()
        path = str(row.get("path") or "").strip()
        if not url and path:
            # Rows that only give a repo-relative path are resolved to a raw URL.
            url = _firmware_raw_url(path)
        if not url or url in seen:
            continue
        seen.add(url)
        slug = str(row.get("slug") or row.get("name") or row.get("key") or Path(path or url).stem).strip()
        entries.append(
            {
                "value": url,
                # An explicit label/title wins; otherwise one is derived from the slug.
                "label": str(row.get("label") or row.get("title") or _wake_sound_label_from_slug(slug)).strip(),
            }
        )
    return sorted(entries, key=lambda item: (item["label"].lower(), item["value"]))
def _load_wake_sound_catalog() -> Dict[str, Any]:
    """Return the prebuilt wake-sound catalog, fetching it when the cache is stale.

    Returns:
        A dict with ``entries`` (list of picker entries), ``warning`` (empty on
        success) and ``source_label`` (the manifest URL that supplied the
        entries, or empty).

    Each name in ``WAKE_SOUND_MANIFEST_PATHS`` is tried in order and the first
    manifest that yields entries wins. The result, including an empty failure
    result, is cached for ``WAKE_SOUND_CATALOG_CACHE_TTL_SECONDS``.
    """
    now = time.time()
    with WAKE_SOUND_CATALOG_LOCK:
        cached_ts = float(WAKE_SOUND_CATALOG_CACHE.get("ts") or 0.0)
        cached_payload = WAKE_SOUND_CATALOG_CACHE.get("payload")
        if isinstance(cached_payload, dict) and (now - cached_ts) < WAKE_SOUND_CATALOG_CACHE_TTL_SECONDS:
            # Hand out a copy so callers cannot mutate the cached payload.
            return copy.deepcopy(cached_payload)

    # The network fetch runs outside the lock so readers are not blocked on it.
    failures: List[str] = []
    for manifest_path in WAKE_SOUND_MANIFEST_PATHS:
        manifest_url = _firmware_raw_url(manifest_path)
        try:
            payload = json.loads(_fetch_text_url(manifest_url, timeout=20))
            entries = _wake_sound_entries_from_manifest(payload)
            if entries:
                catalog = {"entries": entries, "warning": "", "source_label": manifest_url}
                with WAKE_SOUND_CATALOG_LOCK:
                    WAKE_SOUND_CATALOG_CACHE["ts"] = now
                    WAKE_SOUND_CATALOG_CACHE["payload"] = copy.deepcopy(catalog)
                return catalog
        except Exception as exc:
            # Best effort: remember why this manifest failed and try the next one.
            failures.append(f"{manifest_path}: {exc}")

    catalog = {
        "entries": [],
        "warning": failures[0] if failures else "Wake sound catalog unavailable.",
        "source_label": "",
    }
    with WAKE_SOUND_CATALOG_LOCK:
        WAKE_SOUND_CATALOG_CACHE["ts"] = now
        WAKE_SOUND_CATALOG_CACHE["payload"] = copy.deepcopy(catalog)
    return catalog
def _wake_sound_picker_options(catalog: Dict[str, Any]) -> List[Dict[str, str]]:
    """Build the option list for the wake-sound picker.

    The first option is always the ``__custom__`` / "Custom URL" entry. It is
    followed by a shallow copy of every dict found in ``catalog["entries"]``;
    a missing or non-list ``entries`` value contributes nothing.
    """
    entries = catalog.get("entries")
    if not isinstance(entries, list):
        entries = []
    options: List[Dict[str, str]] = [{"value": "__custom__", "label": "Custom URL"}]
    options.extend(dict(row) for row in entries if isinstance(row, dict))
    return options
def _extract_substitution_sections(raw_text: str) -> Dict[str, str]: def _extract_substitution_sections(raw_text: str) -> Dict[str, str]:
section_map: Dict[str, str] = {} section_map: Dict[str, str] = {}
in_substitutions = False in_substitutions = False
@@ -1488,19 +1658,115 @@ def _load_firmware_profiles() -> Dict[str, Dict[str, str]]:
return {} return {}
def _save_firmware_profile(template_key: str, values: Dict[str, str]) -> None: def _save_firmware_profile(profile_key: str, values: Dict[str, str]) -> None:
FIRMWARE_PROFILE_FILE.parent.mkdir(parents=True, exist_ok=True) FIRMWARE_PROFILE_FILE.parent.mkdir(parents=True, exist_ok=True)
profiles = _load_firmware_profiles() profiles = _load_firmware_profiles()
profiles[template_key] = {str(key): str(value) for key, value in values.items() if str(key)} profiles[profile_key] = {str(key): str(value) for key, value in values.items() if str(key)}
FIRMWARE_PROFILE_FILE.write_text(json.dumps(profiles, indent=2, sort_keys=True), encoding="utf-8") FIRMWARE_PROFILE_FILE.write_text(json.dumps(profiles, indent=2, sort_keys=True), encoding="utf-8")
def _normalize_firmware_profile_update(template_key: str, values: Dict[str, Any]) -> Dict[str, str]: def _firmware_profile_target(raw_host: Any = "", raw_port: Any = "") -> tuple[str, str]:
ctx = _load_firmware_template_context(template_key) host = str(raw_host or "").strip()
port = str(raw_port or "").strip()
if "://" in host:
parsed = urlparse(host)
host = parsed.hostname or ""
if not port and parsed.port:
port = str(parsed.port)
host = host.strip().strip("/")
if host.count(":") == 1 and not port:
maybe_host, maybe_port = host.rsplit(":", 1)
if maybe_port.isdigit():
host = maybe_host
port = maybe_port
host = host.strip("[]").strip().lower()
if not host:
return "", ""
with contextlib.suppress(Exception):
parsed_port = int(port or FIRMWARE_DEFAULT_OTA_PORT)
if parsed_port == 6053:
parsed_port = FIRMWARE_DEFAULT_OTA_PORT
port = str(parsed_port)
if not port:
port = str(FIRMWARE_DEFAULT_OTA_PORT)
return host, port
def _firmware_profile_key_for_target(raw_host: Any = "", raw_port: Any = "") -> str:
    """Return the ``device:<host>:<port>`` profile key for a flash target.

    The host and port are first normalized by ``_firmware_profile_target``.
    An empty string is returned when no host remains after normalization.
    """
    host, port = _firmware_profile_target(raw_host, raw_port)
    if not host:
        return ""
    return f"device:{host}:{port}"
def _firmware_profile_key(template_key: Any = "", raw_host: Any = "", raw_port: Any = "") -> str:
    """Return the profile key for a template/device combination.

    With both a target and a template the key is
    ``<device key>:template:<template>``. With only one of them, that one is
    returned on its own. The template part is stripped and lower-cased.
    """
    target_key = _firmware_profile_key_for_target(raw_host, raw_port)
    template = str(template_key or "").strip().lower()
    if target_key and template:
        return f"{target_key}:template:{template}"
    return target_key or template
def _firmware_cache_slug(*parts: Any) -> str:
    """Build a lower-case, filesystem-friendly slug from the given parts.

    Falsy or blank parts are skipped and the rest are joined with ``"__"``.
    Runs of characters outside ``[A-Za-z0-9_.-]`` collapse to ``"_"``, and
    leading/trailing ``.``, ``_`` and ``-`` are stripped. The result is cut to
    96 characters, or ``"default"`` when nothing is left.
    """
    cleaned = (str(part or "").strip() for part in parts)
    raw = "__".join(text for text in cleaned if text)
    slug = re.sub(r"[^A-Za-z0-9_.-]+", "_", raw).strip("._-")
    return (slug[:96] or "default").lower()
def _firmware_build_cache_path(
    template_key: str,
    normalized: Dict[str, str],
    host: str,
    port: Any = None,
    identity_key: str = "",
    friendly_key: str = "",
) -> Path:
    """Return the build-cache directory for one template/device pair.

    Args:
        template_key: Firmware template key; forms the first path segment.
        normalized: Normalized substitution values for the build.
        host: Flash target host.
        port: Flash target port, if any.
        identity_key: Substitution key that names the device, if the template has one.
        friendly_key: Substitution key holding the friendly name, if any.

    Returns:
        ``FIRMWARE_CACHE_DIR / "builds" / <template slug> / <target slug>``.
    """
    normalized_host, normalized_port = _firmware_profile_target(host, port)
    template_slug = _firmware_cache_slug(template_key, "template")
    identity_key = str(identity_key or "").strip()
    friendly_key = str(friendly_key or "").strip()
    # First non-empty candidate wins: template-declared keys, then common
    # naming keys, then the host, then a fixed placeholder.
    device_identity = (
        (normalized.get(identity_key) if identity_key else "")
        or (normalized.get(friendly_key) if friendly_key else "")
        or normalized.get("node_name")
        or normalized.get("device_name")
        or normalized.get("friendly_name")
        or normalized.get("name")
        or normalized_host
        or "device"
    )
    target_slug = _firmware_cache_slug(device_identity, normalized_host, normalized_port)
    return FIRMWARE_CACHE_DIR / "builds" / template_slug / target_slug
def _load_firmware_profile(template_key: str, profile_key: str = "") -> Dict[str, str]:
    """Load the saved profile for ``profile_key``, falling back to older key shapes.

    Lookup order:
      1. the exact ``profile_key``;
      2. for ``...:template:...`` keys, the device-only prefix before ``:template:``;
      3. the bare ``template_key``.

    Returns:
        A copy of the first dict found, or ``{}`` when none matches.
    """
    profiles = _load_firmware_profiles()
    profile = profiles.get(profile_key) if profile_key else None
    if isinstance(profile, dict):
        return dict(profile)
    if profile_key and ":template:" in profile_key:
        legacy_device_key = profile_key.split(":template:", 1)[0]
        legacy_device = profiles.get(legacy_device_key)
        if isinstance(legacy_device, dict):
            return dict(legacy_device)
    legacy = profiles.get(template_key)
    return dict(legacy) if isinstance(legacy, dict) else {}
def _firmware_profile_values_for_template(profile: Dict[str, Any], substitutions: Dict[str, Any]) -> Dict[str, str]:
    """Keep only the profile values that apply to the given template.

    A key is kept when it is present in ``profile`` and is either a (stripped)
    substitution key of the template or one of the bookkeeping keys
    ``__target_host``, ``__target_port``, ``wake_sound_catalog`` and
    ``wake_word_choice``. Values are stringified; falsy values become ``""``.
    """
    keep_keys = {str(key or "").strip() for key in substitutions}
    keep_keys |= {"__target_host", "__target_port", "wake_sound_catalog", "wake_word_choice"}
    return {
        key: str(profile.get(key) or "")
        for key in keep_keys
        if key and key in profile
    }
def _normalize_firmware_profile_update(template_key: str, values: Dict[str, Any], profile_key: str = "") -> Dict[str, str]:
ctx = _load_firmware_template_context(template_key, profile_key)
spec = ctx["spec"] spec = ctx["spec"]
profile = ctx.get("profile") if isinstance(ctx.get("profile"), dict) else {} profile = ctx.get("profile") if isinstance(ctx.get("profile"), dict) else {}
substitutions = ctx["substitutions"] substitutions = ctx["substitutions"]
normalized: Dict[str, str] = dict(profile) normalized = _firmware_profile_values_for_template(profile, substitutions)
fixed_keys = set(spec.get("fixed_keys") or set()) fixed_keys = set(spec.get("fixed_keys") or set())
identity_key = str(spec.get("identity_key") or "").strip() identity_key = str(spec.get("identity_key") or "").strip()
if identity_key: if identity_key:
@@ -1527,6 +1793,12 @@ def _normalize_firmware_profile_update(template_key: str, values: Dict[str, Any]
elif key_text not in normalized: elif key_text not in normalized:
normalized[key_text] = "true" if str(default).lower() == "true" else "false" normalized[key_text] = "true" if str(default).lower() == "true" else "false"
continue continue
if key_text == "wake_word_model_url":
if key_text in values:
normalized[key_text] = _local_trained_wake_word_url(values.get(key_text))
elif key_text not in normalized:
normalized[key_text] = ""
continue
if key_text in values: if key_text in values:
normalized[key_text] = str(values.get(key_text) or "").strip() normalized[key_text] = str(values.get(key_text) or "").strip()
elif key_text not in normalized: elif key_text not in normalized:
@@ -1535,6 +1807,11 @@ def _normalize_firmware_profile_update(template_key: str, values: Dict[str, Any]
wake_word_choice = str(values.get("wake_word_choice") or "").strip() wake_word_choice = str(values.get("wake_word_choice") or "").strip()
if wake_word_choice: if wake_word_choice:
normalized["wake_word_choice"] = wake_word_choice normalized["wake_word_choice"] = wake_word_choice
wake_sound_choice = str(values.get("wake_sound_catalog") or "").strip()
if wake_sound_choice:
normalized["wake_sound_catalog"] = wake_sound_choice
if wake_sound_choice != "__custom__" and "wake_word_triggered_sound_file" in substitutions:
normalized["wake_word_triggered_sound_file"] = wake_sound_choice
target_host = str(values.get("__target_host") or "").strip() target_host = str(values.get("__target_host") or "").strip()
target_port = str(values.get("__target_port") or "").strip() target_port = str(values.get("__target_port") or "").strip()
@@ -1548,7 +1825,38 @@ def _normalize_firmware_profile_update(template_key: str, values: Dict[str, Any]
return normalized return normalized
def _load_firmware_template_context(template_key: str) -> Dict[str, Any]: def _local_trained_wake_word_url(value: Any) -> str:
url = str(value or "").strip()
return url if "/api/trained_wake_words/" in url else ""
def _selected_trained_wake_word(
trained_wake_words: List[Dict[str, Any]],
profile: Dict[str, Any],
substitutions: Dict[str, Any],
) -> Dict[str, Any] | None:
if not trained_wake_words:
return None
saved_choice = str(profile.get("wake_word_choice") or "").strip()
current_wake_word_name = str(
profile.get("wake_word_name") or _template_default_string(substitutions.get("wake_word_name"))
).strip()
current_wake_word_url = str(profile.get("wake_word_model_url") or "").strip()
def match(predicate):
return next((row for row in trained_wake_words if predicate(row)), None)
return (
match(lambda row: str(row.get("key") or "") == saved_choice)
or match(lambda row: str(row.get("json_url") or "") == current_wake_word_url)
or match(lambda row: str(row.get("model_url") or "") == current_wake_word_url)
or match(lambda row: str(row.get("wake_word_name") or "") == current_wake_word_name)
or trained_wake_words[0]
)
def _load_firmware_template_context(template_key: str, profile_key: str = "") -> Dict[str, Any]:
spec = _firmware_template_spec(template_key) spec = _firmware_template_spec(template_key)
raw_text, source_label = _load_firmware_template_text(spec) raw_text, source_label = _load_firmware_template_text(spec)
parsed = yaml.load(raw_text, Loader=_FirmwareYamlLoader) parsed = yaml.load(raw_text, Loader=_FirmwareYamlLoader)
@@ -1564,12 +1872,12 @@ def _load_firmware_template_context(template_key: str) -> Dict[str, Any]:
"template_doc": parsed, "template_doc": parsed,
"substitutions": dict(substitutions), "substitutions": dict(substitutions),
"sections": _extract_substitution_sections(raw_text), "sections": _extract_substitution_sections(raw_text),
"profile": _load_firmware_profiles().get(str(spec.get("key") or ""), {}), "profile": _load_firmware_profile(str(spec.get("key") or ""), profile_key),
} }
def _firmware_template_fields(template_key: str, base_url: str = "") -> List[Dict[str, Any]]: def _firmware_template_fields(template_key: str, base_url: str = "", profile_key: str = "") -> List[Dict[str, Any]]:
ctx = _load_firmware_template_context(template_key) ctx = _load_firmware_template_context(template_key, profile_key)
spec = ctx["spec"] spec = ctx["spec"]
profile = ctx.get("profile") if isinstance(ctx.get("profile"), dict) else {} profile = ctx.get("profile") if isinstance(ctx.get("profile"), dict) else {}
fields: List[Dict[str, Any]] = [] fields: List[Dict[str, Any]] = []
@@ -1579,18 +1887,11 @@ def _firmware_template_fields(template_key: str, base_url: str = "") -> List[Dic
fixed_keys.add(identity_key) fixed_keys.add(identity_key)
hidden_keys = {"ha_voice_ip"} | set(spec.get("auto_keys") or set()) hidden_keys = {"ha_voice_ip"} | set(spec.get("auto_keys") or set())
trained_wake_words = _list_trained_wake_words(base_url) trained_wake_words = _list_trained_wake_words(base_url)
current_wake_word_name = str( wake_sound_catalog = _load_wake_sound_catalog()
profile.get("wake_word_name") or _template_default_string(ctx["substitutions"].get("wake_word_name")) selected_wake_word_row = _selected_trained_wake_word(trained_wake_words, profile, ctx["substitutions"])
).strip() selected_wake_word = str(selected_wake_word_row.get("key") or "") if selected_wake_word_row else ""
saved_wake_word_choice = str(profile.get("wake_word_choice") or "").strip()
selected_wake_word = next(
(row["key"] for row in trained_wake_words if row.get("key") == saved_wake_word_choice),
"",
) or next(
(row["key"] for row in trained_wake_words if row.get("wake_word_name") == current_wake_word_name),
"",
)
wake_picker_added = False wake_picker_added = False
wake_sound_picker_added = False
for key, raw_value in ctx["substitutions"].items(): for key, raw_value in ctx["substitutions"].items():
key_text = str(key or "").strip() key_text = str(key or "").strip()
@@ -1616,6 +1917,35 @@ def _firmware_template_fields(template_key: str, base_url: str = "") -> List[Dic
) )
wake_picker_added = True wake_picker_added = True
if key_text == "wake_word_triggered_sound_file" and not wake_sound_picker_added:
wake_sound_entries = wake_sound_catalog.get("entries") if isinstance(wake_sound_catalog.get("entries"), list) else []
current_sound_url = str(profile.get(key_text) or _template_default_string(raw_value) or "").strip()
saved_sound_choice = str(profile.get("wake_sound_catalog") or "").strip()
available_sound_urls = {str(row.get("value") or "") for row in wake_sound_entries if isinstance(row, dict)}
if saved_sound_choice in available_sound_urls or saved_sound_choice == "__custom__":
picker_value = saved_sound_choice
else:
picker_value = current_sound_url if current_sound_url in available_sound_urls else "__custom__"
description = (
f"Choose from {len(wake_sound_entries)} prebuilt wake sounds, or leave this on Custom URL and paste your own audio URL below."
if wake_sound_entries
else "Prebuilt wake-sound catalog is unavailable right now. You can still paste any custom audio URL below."
)
if wake_sound_catalog.get("warning") and not wake_sound_entries:
description = f"{description} {wake_sound_catalog['warning']}".strip()
fields.append(
{
"key": "wake_sound_catalog",
"label": "Prebuilt Wake Sound",
"type": "wake_sound_select",
"value": picker_value,
"options": _wake_sound_picker_options(wake_sound_catalog),
"description": description,
"section": "Wake Sound",
}
)
wake_sound_picker_added = True
default = _template_default_string(raw_value) default = _template_default_string(raw_value)
saved = str(profile.get(key_text) or "") saved = str(profile.get(key_text) or "")
field_type = "text" field_type = "text"
@@ -1640,11 +1970,20 @@ def _firmware_template_fields(template_key: str, base_url: str = "") -> List[Dic
placeholder = "Your Wi-Fi SSID" placeholder = "Your Wi-Fi SSID"
description = "Required before build + flash." description = "Required before build + flash."
elif key_text == "wake_word_model_url": elif key_text == "wake_word_model_url":
placeholder = "https://.../wake_word.json" value = str(selected_wake_word_row.get("json_url") or "") if selected_wake_word_row else ""
placeholder = "Train or select a local wake word first"
description = "Filled from the local trained wake-word picker."
elif key_text == "wake_word_name": elif key_text == "wake_word_name":
if selected_wake_word_row:
value = str(selected_wake_word_row.get("wake_word_name") or selected_wake_word_row.get("key") or "")
placeholder = "hey_tater" placeholder = "hey_tater"
elif key_text == "wake_word_triggered_sound_file":
placeholder = "https://.../wake-sound.mp3"
description = "Pick a prebuilt wake sound above or paste any custom audio URL."
section = ctx["sections"].get(key_text) or "Firmware" section = ctx["sections"].get(key_text) or "Firmware"
if key_text in {"wake_word_name", "wake_word_model_url"}: if key_text == "wake_word_triggered_sound_file":
section = "Wake Sound"
elif key_text in {"wake_word_name", "wake_word_model_url"}:
section = "Micro Wake Word" section = "Micro Wake Word"
elif key_text.endswith("_sound_file"): elif key_text.endswith("_sound_file"):
section = "Sounds" section = "Sounds"
@@ -1678,12 +2017,19 @@ def _esphome_pythonpath() -> str:
return os.pathsep.join(paths) return os.pathsep.join(paths)
def _render_firmware_config(template_key: str, values: Dict[str, Any], host: str, session_id: str) -> tuple[Path, Dict[str, str]]: def _render_firmware_config(
ctx = _load_firmware_template_context(template_key) template_key: str,
values: Dict[str, Any],
host: str,
session_id: str,
port: Any = None,
) -> tuple[Path, Dict[str, str], Path]:
profile_key = _firmware_profile_key(template_key, host, port)
ctx = _load_firmware_template_context(template_key, profile_key)
spec = ctx["spec"] spec = ctx["spec"]
profile = ctx.get("profile") if isinstance(ctx.get("profile"), dict) else {} profile = ctx.get("profile") if isinstance(ctx.get("profile"), dict) else {}
substitutions = ctx["substitutions"] substitutions = ctx["substitutions"]
normalized: Dict[str, str] = dict(profile) normalized = _firmware_profile_values_for_template(profile, substitutions)
fixed_keys = set(spec.get("fixed_keys") or set()) fixed_keys = set(spec.get("fixed_keys") or set())
identity_key = str(spec.get("identity_key") or "").strip() identity_key = str(spec.get("identity_key") or "").strip()
if identity_key: if identity_key:
@@ -1702,10 +2048,15 @@ def _render_firmware_config(template_key: str, values: Dict[str, Any], host: str
normalized[key_text] = "true" if bool(raw_value) else "false" normalized[key_text] = "true" if bool(raw_value) else "false"
elif key_text == "ha_voice_ip": elif key_text == "ha_voice_ip":
normalized[key_text] = host normalized[key_text] = host
elif key_text == "wake_word_model_url":
normalized[key_text] = _local_trained_wake_word_url(raw_value)
else: else:
normalized[key_text] = str(raw_value if raw_value is not None else "").strip() or _template_default_string( normalized[key_text] = str(raw_value if raw_value is not None else "").strip() or _template_default_string(
substitutions.get(key_text) substitutions.get(key_text)
) )
wake_sound_choice = str(values.get("wake_sound_catalog") or "").strip()
if wake_sound_choice and wake_sound_choice != "__custom__" and "wake_word_triggered_sound_file" in substitutions:
normalized["wake_word_triggered_sound_file"] = wake_sound_choice
missing = [] missing = []
if not normalized.get("wifi_ssid"): if not normalized.get("wifi_ssid"):
@@ -1714,22 +2065,36 @@ def _render_firmware_config(template_key: str, values: Dict[str, Any], host: str
missing.append("Wi-Fi password") missing.append("Wi-Fi password")
if not host: if not host:
missing.append("device IP or hostname") missing.append("device IP or hostname")
if "wake_word_model_url" in substitutions and not normalized.get("wake_word_model_url"):
missing.append("local trained wake word")
if missing: if missing:
raise RuntimeError(f"Missing required firmware values: {', '.join(missing)}.") raise RuntimeError(f"Missing required firmware values: {', '.join(missing)}.")
config = copy.deepcopy(ctx["template_doc"]) config = copy.deepcopy(ctx["template_doc"])
config["substitutions"] = {key: str(normalized.get(key, "")) for key in substitutions.keys()} config["substitutions"] = {key: str(normalized.get(key, "")) for key in substitutions.keys()}
build_path = _firmware_build_cache_path(
str(spec.get("key") or template_key),
normalized,
host,
port,
str(spec.get("identity_key") or ""),
str(spec.get("friendly_key") or ""),
)
esphome_block = config.get("esphome") if isinstance(config.get("esphome"), dict) else None esphome_block = config.get("esphome") if isinstance(config.get("esphome"), dict) else None
if isinstance(esphome_block, dict): if isinstance(esphome_block, dict):
esphome_block["build_path"] = str(FIRMWARE_CACHE_DIR / "builds" / session_id / str(spec.get("key") or template_key)) esphome_block["build_path"] = str(build_path)
config["esphome"] = esphome_block config["esphome"] = esphome_block
session_dir = FIRMWARE_CACHE_DIR / "configs" / session_id session_dir = FIRMWARE_CACHE_DIR / "configs" / session_id
session_dir.mkdir(parents=True, exist_ok=True) session_dir.mkdir(parents=True, exist_ok=True)
config_path = session_dir / f"{str(spec.get('key') or template_key)}.yaml" config_path = session_dir / f"{build_path.parent.name}__{build_path.name}.yaml"
config_path.write_text(yaml.dump(config, Dumper=_FirmwareYamlDumper, sort_keys=False, allow_unicode=True), encoding="utf-8") config_path.write_text(yaml.dump(config, Dumper=_FirmwareYamlDumper, sort_keys=False, allow_unicode=True), encoding="utf-8")
_save_firmware_profile(str(spec.get("key") or template_key), normalized) normalized_host, normalized_port = _firmware_profile_target(host, port)
return config_path, normalized if normalized_host:
normalized["__target_host"] = normalized_host
normalized["__target_port"] = normalized_port
_save_firmware_profile(profile_key or str(spec.get("key") or template_key), normalized)
return config_path, normalized, build_path
def _firmware_session_payload(session_id: str) -> Dict[str, Any]: def _firmware_session_payload(session_id: str) -> Dict[str, Any]:
@@ -1779,11 +2144,13 @@ def _firmware_runner_env(*, include_esphome_pythonpath: bool = False) -> Dict[st
FIRMWARE_HOME_DIR.mkdir(parents=True, exist_ok=True) FIRMWARE_HOME_DIR.mkdir(parents=True, exist_ok=True)
FIRMWARE_XDG_CACHE_DIR.mkdir(parents=True, exist_ok=True) FIRMWARE_XDG_CACHE_DIR.mkdir(parents=True, exist_ok=True)
FIRMWARE_PLATFORMIO_DIR.mkdir(parents=True, exist_ok=True) FIRMWARE_PLATFORMIO_DIR.mkdir(parents=True, exist_ok=True)
FIRMWARE_ESPHOME_DATA_DIR.mkdir(parents=True, exist_ok=True)
env = os.environ.copy() env = os.environ.copy()
env.pop("PYTHONPATH", None) env.pop("PYTHONPATH", None)
env["PYTHONUNBUFFERED"] = "1" env["PYTHONUNBUFFERED"] = "1"
env["HOME"] = str(FIRMWARE_HOME_DIR) env["HOME"] = str(FIRMWARE_HOME_DIR)
env["XDG_CACHE_HOME"] = str(FIRMWARE_XDG_CACHE_DIR) env["XDG_CACHE_HOME"] = str(FIRMWARE_XDG_CACHE_DIR)
env["ESPHOME_DATA_DIR"] = str(FIRMWARE_ESPHOME_DATA_DIR)
env["PLATFORMIO_CORE_DIR"] = str(FIRMWARE_PLATFORMIO_DIR) env["PLATFORMIO_CORE_DIR"] = str(FIRMWARE_PLATFORMIO_DIR)
env["PLATFORMIO_CACHE_DIR"] = str(FIRMWARE_PLATFORMIO_DIR / "cache") env["PLATFORMIO_CACHE_DIR"] = str(FIRMWARE_PLATFORMIO_DIR / "cache")
if include_esphome_pythonpath: if include_esphome_pythonpath:
@@ -1947,7 +2314,7 @@ def _run_firmware_build_flash_background(session_id: str):
return return
try: try:
config_path, _normalized = _render_firmware_config(template_key, values, host, session_id) config_path, normalized, build_path = _render_firmware_config(template_key, values, host, session_id, port)
except Exception as exc: except Exception as exc:
_append_firmware_log(session_id, f"✗ Failed to prepare firmware config: {exc}") _append_firmware_log(session_id, f"✗ Failed to prepare firmware config: {exc}")
with FIRMWARE_LOCK: with FIRMWARE_LOCK:
@@ -1974,6 +2341,9 @@ def _run_firmware_build_flash_background(session_id: str):
_append_firmware_log(session_id, f"→ Template: {template_key}") _append_firmware_log(session_id, f"→ Template: {template_key}")
_append_firmware_log(session_id, f"→ Device: {host}:{port}") _append_firmware_log(session_id, f"→ Device: {host}:{port}")
_append_firmware_log(session_id, f"→ Config: {config_path}") _append_firmware_log(session_id, f"→ Config: {config_path}")
_append_firmware_log(session_id, f"→ Build cache: {build_path}")
if normalized.get("wake_word_triggered_sound_file"):
_append_firmware_log(session_id, f"→ Wake sound: {normalized['wake_word_triggered_sound_file']}")
_append_firmware_log(session_id, "→ Running: " + " ".join(command)) _append_firmware_log(session_id, "→ Running: " + " ".join(command))
try: try:
@@ -2012,6 +2382,12 @@ def _run_firmware_build_flash_background(session_id: str):
"Tip: PlatformIO's ESP-IDF Python environment crashed while installing dependencies. " "Tip: PlatformIO's ESP-IDF Python environment crashed while installing dependencies. "
"Run Clean Build Files once, then retry the flash.", "Run Clean Build Files once, then retry the flash.",
) )
if "pioarduino/registry" in joined_lines and "ninja-" in joined_lines and "status code '502'" in joined_lines:
_append_firmware_log(
session_id,
"Tip: GitHub returned a 502 while PlatformIO was downloading Ninja. "
"This is an upstream package download failure; retry the build in a few minutes.",
)
_append_firmware_log(session_id, f"✗ Firmware build + flash failed (exit_code={rc})") _append_firmware_log(session_id, f"✗ Firmware build + flash failed (exit_code={rc})")
with FIRMWARE_LOCK: with FIRMWARE_LOCK:
live = FIRMWARE_SESSIONS.get(session_id) live = FIRMWARE_SESSIONS.get(session_id)
@@ -2470,7 +2846,143 @@ def delete_sample(bucket: str, file_name: str):
_remove_audio_with_sidecar(path) _remove_audio_with_sidecar(path)
except FileNotFoundError as e: except FileNotFoundError as e:
return JSONResponse({"ok": False, "error": str(e)}, status_code=404) return JSONResponse({"ok": False, "error": str(e)}, status_code=404)
return _samples_payload() return {"ok": True, "deleted_bucket": bucket, "deleted_file": file_name, "message": f"Deleted {file_name}"}
@app.post("/api/samples/{bucket}/{file_name}/vad")
def vad_segments(bucket: str, file_name: str):
    """Suggest a padded trim window for a stored sample using speech detection.

    Args:
        bucket: Sample bucket, either ``"personal"`` or ``"negative"``.
        file_name: Sample name, resolved inside the bucket's directory.

    Returns:
        A dict with ``ok``, ``file_name``, ``segments`` and ``segment_count``.
        ``segments`` holds at most one ``{"start", "end"}`` entry (seconds):
        the first detected segment lasting at least 250 ms, widened by the
        configured start/end padding and clamped to the clip. Failures are
        reported as a ``JSONResponse`` with ``ok`` set to ``False``.
    """
    bucket_map = {"personal": PERSONAL_DIR, "negative": NEGATIVE_DIR}
    directory = bucket_map.get(bucket)
    if directory is None:
        return JSONResponse({"ok": False, "error": "Unknown sample bucket."}, status_code=404)

    try:
        path = _resolve_audio_path(directory, file_name)
    except FileNotFoundError as e:
        return JSONResponse({"ok": False, "error": str(e)}, status_code=404)

    # The file can disappear or be unreadable between resolution and the read;
    # report that in the same JSON error shape as every other failure here.
    try:
        wav_bytes = path.read_bytes()
    except FileNotFoundError:
        return JSONResponse({"ok": False, "error": "Sample file not found."}, status_code=404)
    except OSError as e:
        return JSONResponse(
            {"ok": False, "error": f"Could not read sample: {e.strerror or e}"},
            status_code=500,
        )

    try:
        all_segments = _detect_speech_segments(wav_bytes)
    except Exception as e:
        return JSONResponse({"ok": False, "error": f"VAD failed: {str(e)}"}, status_code=500)

    # Only return the first segment lasting at least 250 ms. Add deterministic
    # padding so VAD guides trimming without clipping quiet wake-word edges.
    filtered = [s for s in all_segments if (s["end"] - s["start"]) >= 0.25]
    if not filtered:
        return {"ok": True, "file_name": file_name, "segments": [], "segment_count": 0}

    seg = filtered[0]
    info = _inspect_wav_bytes(wav_bytes) or {}
    duration_s = float(info.get("duration_s") or 0.0)
    start = max(0.0, round(seg["start"] - VAD_SELECTION_PAD_START_S, 3))
    end = round(seg["end"] + VAD_SELECTION_PAD_END_S, 3)
    if duration_s > 0:
        # Never suggest a window that runs past the end of the clip.
        end = min(duration_s, end)
    if end <= start:
        # Keep the window non-empty so the caller always gets a usable range.
        end = start + 0.001

    segment = {"start": start, "end": end}
    return {"ok": True, "file_name": file_name, "segments": [segment], "segment_count": 1}
@app.post("/api/samples/trim")
async def trim_sample_upload(
    file: UploadFile = File(...),
    bucket: str = Form(...),
    source_file: str = Form(...),
    start_time: str | None = Form(None),
    end_time: str | None = Form(None),
):
    """Replace a stored sample with an uploaded trimmed version, keeping an undo backup.

    Args:
        file: Uploaded trimmed audio. It is passed through
            ``_normalize_audio_to_target_wav`` when it is not already a
            target-format WAV.
        bucket: Sample bucket, either ``"personal"`` or ``"negative"``.
        source_file: Name of the existing sample being replaced.
        start_time: Optional trim start, a number in string form.
        end_time: Optional trim end, a number in string form.

    Returns:
        ``{"ok": True, "updated_sample": ..., "message": ...}`` on success, or
        a ``JSONResponse`` with ``ok`` set to ``False`` (400 for bad input,
        404 for an unknown bucket or missing sample).
    """
    import math  # Local import: only needed to validate the trim bounds.

    bucket_map = {"personal": PERSONAL_DIR, "negative": NEGATIVE_DIR}
    directory = bucket_map.get(bucket)
    if directory is None:
        return JSONResponse({"ok": False, "error": "Unknown sample bucket."}, status_code=404)

    # Parse the trim bounds BEFORE anything on disk changes. Failing here after
    # the sample was overwritten would leave it trimmed with no undo record.
    trim_bounds = {}
    for sidecar_key, raw_value in (("trim_start_s", start_time), ("trim_end_s", end_time)):
        if not raw_value:
            trim_bounds[sidecar_key] = None
            continue
        try:
            parsed = float(raw_value)
        except ValueError:
            parsed = math.nan
        if not math.isfinite(parsed):
            return JSONResponse(
                {"ok": False, "error": f"Invalid trim time: {raw_value!r}."},
                status_code=400,
            )
        trim_bounds[sidecar_key] = parsed

    data = await file.read()
    if not data:
        return JSONResponse({"ok": False, "error": "Empty audio file."}, status_code=400)

    # Anything that is not already a target-format WAV gets normalized.
    info = _inspect_wav_bytes(data)
    if not info or not _is_target_wav(info):
        try:
            data = _normalize_audio_to_target_wav(data, file.filename or "trimmed.wav")
        except Exception as e:
            return JSONResponse({"ok": False, "error": f"Audio normalization failed: {e}"}, status_code=400)

    try:
        orig_path = _resolve_audio_path(directory, source_file)
    except FileNotFoundError as e:
        return JSONResponse({"ok": False, "error": str(e)}, status_code=404)

    # Snapshot the current audio (and its sidecar, if any) so the trim can be undone.
    TRIM_HISTORY_DIR.mkdir(parents=True, exist_ok=True)
    ts = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%S%f")
    # Use the resolved file's own name, never the raw form value, so the backup
    # always lands directly inside TRIM_HISTORY_DIR.
    backup_name = f"{ts}_{orig_path.name}"
    backup_path = TRIM_HISTORY_DIR / backup_name
    shutil.copy2(orig_path, backup_path)
    orig_sidecar = _audio_sidecar_path(orig_path)
    if orig_sidecar.exists():
        shutil.copy2(orig_sidecar, _audio_sidecar_path(backup_path))

    orig_path.write_bytes(data)

    # Carry existing sidecar fields forward and record how to undo this trim.
    old_sidecar = _load_sidecar_json(orig_path)
    sidecar = {
        **old_sidecar,
        "trimmed": True,
        "source_file": source_file,
        "source_bucket": bucket,
        "trim_start_s": trim_bounds["trim_start_s"],
        "trim_end_s": trim_bounds["trim_end_s"],
        "undo_backup_file": backup_name,
    }
    _write_sidecar_json(orig_path, sidecar)

    updated_item = _sample_item_from_path(orig_path, bucket)
    updated_item["trimmed"] = True
    updated_item["source_file"] = source_file
    return {"ok": True, "updated_sample": updated_item, "message": f"Trimmed {source_file}"}
@app.post("/api/samples/revert")
def revert_trim(
    bucket: str = Form(...),
    file_name: str = Form(...),
):
    """Undo the most recent trim of a sample by restoring its backup.

    Args:
        bucket: Sample bucket, either ``"personal"`` or ``"negative"``.
        file_name: Name of the trimmed sample to restore.

    Returns:
        ``{"ok": True, "updated_sample": ..., "message": ...}`` on success, or
        a ``JSONResponse`` with ``ok`` set to ``False`` (400 when the sample
        has no usable backup reference, 404 for an unknown bucket, missing
        sample, or missing backup file).
    """
    bucket_map = {"personal": PERSONAL_DIR, "negative": NEGATIVE_DIR}
    directory = bucket_map.get(bucket)
    if directory is None:
        return JSONResponse({"ok": False, "error": "Unknown sample bucket."}, status_code=404)

    try:
        file_path = _resolve_audio_path(directory, file_name)
    except FileNotFoundError as e:
        return JSONResponse({"ok": False, "error": str(e)}, status_code=404)

    sidecar = _load_sidecar_json(file_path)
    backup_name = str(sidecar.get("undo_backup_file") or "")
    if not backup_name:
        return JSONResponse({"ok": False, "error": "No trim backup found for this sample."}, status_code=400)

    backup_path = TRIM_HISTORY_DIR / backup_name
    # The name comes from an on-disk sidecar; refuse anything that would
    # resolve outside the trim-history directory.
    if backup_path.parent != TRIM_HISTORY_DIR:
        return JSONResponse({"ok": False, "error": "Invalid trim backup reference."}, status_code=400)
    if not backup_path.exists():
        return JSONResponse({"ok": False, "error": "Trim backup file missing."}, status_code=404)

    shutil.copy2(backup_path, file_path)

    backup_sidecar = _audio_sidecar_path(backup_path)
    if backup_sidecar.exists():
        # Restoring the pre-trim sidecar also restores any earlier undo record.
        shutil.copy2(backup_sidecar, _audio_sidecar_path(file_path))
        backup_sidecar.unlink()
    else:
        # No pre-trim sidecar was saved, so drop the trim bookkeeping from the
        # current one. Otherwise it keeps pointing at the backup deleted below
        # and the sample still looks trimmed.
        trim_keys = {
            "trimmed",
            "source_file",
            "source_bucket",
            "trim_start_s",
            "trim_end_s",
            "undo_backup_file",
        }
        remaining = {key: value for key, value in sidecar.items() if key not in trim_keys}
        _write_sidecar_json(file_path, remaining)

    backup_path.unlink()

    updated_item = _sample_item_from_path(file_path, bucket)
    return {"ok": True, "updated_sample": updated_item, "message": f"Reverted {file_name}"}
@app.post("/api/captured_audio/{file_name}/approve_personal") @app.post("/api/captured_audio/{file_name}/approve_personal")
@@ -2512,28 +3024,30 @@ def firmware_devices():
@app.get("/api/firmware/templates") @app.get("/api/firmware/templates")
def firmware_templates(request: Request): def firmware_templates(request: Request, target_host: str = "", target_port: str = ""):
templates = [] templates = []
warnings = [] warnings = []
base_url = _request_base_url(request) base_url = _request_base_url(request)
wake_words = _list_trained_wake_words(base_url) wake_words = _list_trained_wake_words(base_url)
profiles = _load_firmware_profiles() selected_host, selected_port = _firmware_profile_target(target_host, target_port)
for spec in FIRMWARE_TEMPLATE_SPECS: for spec in FIRMWARE_TEMPLATE_SPECS:
key = str(spec.get("key") or "") key = str(spec.get("key") or "")
profile = profiles.get(key, {}) profile_key = _firmware_profile_key(key, target_host, target_port)
target_port = str(profile.get("__target_port") or "") profile = _load_firmware_profile(key, profile_key)
if target_port == "6053": row_target_host = selected_host or str(profile.get("__target_host") or "")
target_port = str(FIRMWARE_DEFAULT_OTA_PORT) row_target_port = selected_port or str(profile.get("__target_port") or "")
if row_target_port == "6053":
row_target_port = str(FIRMWARE_DEFAULT_OTA_PORT)
row = { row = {
"value": key, "value": key,
"label": str(spec.get("label") or key), "label": str(spec.get("label") or key),
"source_url": _firmware_raw_url(str(spec.get("path") or "")), "source_url": _firmware_raw_url(str(spec.get("path") or "")),
"target_host": str(profile.get("__target_host") or ""), "target_host": row_target_host,
"target_port": target_port, "target_port": row_target_port,
"fields": [], "fields": [],
} }
try: try:
row["fields"] = _firmware_template_fields(key, base_url) row["fields"] = _firmware_template_fields(key, base_url, profile_key)
except Exception as exc: except Exception as exc:
warnings.append(f"{row['label']}: {exc}") warnings.append(f"{row['label']}: {exc}")
templates.append(row) templates.append(row)
@@ -2553,11 +3067,12 @@ def firmware_profile(payload: Dict[str, Any]):
template_key = str(body.get("template_key") or "").strip() template_key = str(body.get("template_key") or "").strip()
_firmware_template_spec(template_key) _firmware_template_spec(template_key)
values = body.get("values") if isinstance(body.get("values"), dict) else {} values = body.get("values") if isinstance(body.get("values"), dict) else {}
saved = _normalize_firmware_profile_update(template_key, values) profile_key = _firmware_profile_key(template_key, values.get("__target_host"), values.get("__target_port"))
_save_firmware_profile(template_key, saved) saved = _normalize_firmware_profile_update(template_key, values, profile_key)
_save_firmware_profile(profile_key or template_key, saved)
except Exception as e: except Exception as e:
return JSONResponse({"ok": False, "error": str(e)}, status_code=400) return JSONResponse({"ok": False, "error": str(e)}, status_code=400)
return {"ok": True, "template_key": template_key, "saved_fields": sorted(saved.keys())} return {"ok": True, "template_key": template_key, "profile_key": profile_key or template_key, "saved_fields": sorted(saved.keys())}
@app.get("/api/trained_wake_words/catalog") @app.get("/api/trained_wake_words/catalog")
@@ -2625,7 +3140,7 @@ def firmware_clean():
return JSONResponse({"ok": False, "error": f"Wait for active firmware session(s) to finish: {', '.join(active[:3])}."}, status_code=400) return JSONResponse({"ok": False, "error": f"Wait for active firmware session(s) to finish: {', '.join(active[:3])}."}, status_code=400)
removed = [] removed = []
for child in ("configs", "builds", "platformio", "home", "cache"): for child in ("configs", "builds", "platformio", "home", "cache", "esphome_data"):
path = FIRMWARE_CACHE_DIR / child path = FIRMWARE_CACHE_DIR / child
if path.exists(): if path.exists():
shutil.rmtree(path, ignore_errors=True) shutil.rmtree(path, ignore_errors=True)