mirror of
https://github.com/TaterTotterson/microWakeWord-Trainer-Nvidia-Docker.git
synced 2026-06-12 20:10:19 -06:00
Compare commits
18 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
874f273d0b | ||
|
|
04249f414d | ||
|
|
6a0d60d569 | ||
|
|
8df17599c2 | ||
|
|
280e8f8de4 | ||
|
|
b582a6cade | ||
|
|
196ab8c0e7 | ||
|
|
134f607bef | ||
|
|
4a9e2f2cde | ||
|
|
7c246856df | ||
|
|
3705dabc09 | ||
|
|
1dcf48209f | ||
|
|
4f44bef8d5 | ||
|
|
98fa879db1 | ||
|
|
dfac549430 | ||
|
|
775a78326b | ||
|
|
429be4cc67 | ||
|
|
2e6179ec32 |
48
.github/workflows/docker-publish.yml
vendored
Normal file
48
.github/workflows/docker-publish.yml
vendored
Normal file
@@ -0,0 +1,48 @@
|
||||
name: Publish Docker Image
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
|
||||
concurrency:
|
||||
group: docker-publish-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
env:
|
||||
REGISTRY: ghcr.io
|
||||
IMAGE_NAME: tatertotterson/microwakeword
|
||||
|
||||
jobs:
|
||||
docker:
|
||||
name: Docker image
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Check out repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Log in to GHCR
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ${{ env.REGISTRY }}
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Build and push image
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
file: dockerfile
|
||||
platforms: linux/amd64
|
||||
push: true
|
||||
tags: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:latest
|
||||
cache-from: type=gha,scope=mww-trainer-nvidia-docker
|
||||
cache-to: type=gha,mode=max,scope=mww-trainer-nvidia-docker
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -1,3 +1,4 @@
|
||||
personal_samples/*
|
||||
data/
|
||||
trim_history/
|
||||
.DS_Store
|
||||
@@ -1,7 +1,11 @@
|
||||
<div align="center">
|
||||
<h1>microWakeWord NVIDIA Docker Trainer UI</h1>
|
||||
<img width="800" alt="microWakeWord NVIDIA trainer screenshot" src="https://github.com/user-attachments/assets/694f4cb7-e4d8-4e2b-80ec-b40fb41cbfff" />
|
||||
<a href="https://taterassistant.com">
|
||||
<img src="images/tater-repo-logo.png" alt="microWakeWord Trainer" width="460"/>
|
||||
</a>
|
||||
</div>
|
||||
<h3 align="center">
|
||||
<a href="https://taterassistant.com">taterassistant.com</a>
|
||||
</h3>
|
||||
|
||||
Train custom microWakeWord models in Docker with NVIDIA/CUDA acceleration, generated Piper samples, device-captured samples, reviewed false-wake negatives, live training logs, and ESPHome firmware flashing.
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ ENV DEBIAN_FRONTEND=noninteractive
|
||||
# System deps
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
python3.12 python3.12-venv python3.12-dev python3-pip python-is-python3 \
|
||||
git wget curl unzip patch ca-certificates nano less \
|
||||
git wget curl unzip patch ninja-build ca-certificates nano less \
|
||||
&& rm -rf /var/lib/apt/lists/* \
|
||||
&& mkdir -p /data
|
||||
|
||||
|
||||
BIN
images/tater-repo-logo.png
Normal file
BIN
images/tater-repo-logo.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 590 KiB |
68
run.sh
68
run.sh
@@ -17,7 +17,7 @@ PIN_FILE="${VENV_DIR}/.pinned_installed"
|
||||
FASTAPI_VERSION="${REC_FASTAPI_VERSION:-0.115.6}"
|
||||
UVICORN_VERSION="${REC_UVICORN_VERSION:-0.30.6}"
|
||||
PY_MULTIPART_VERSION="${REC_PY_MULTIPART_VERSION:-0.0.9}"
|
||||
ESPHOME_VERSION="${REC_ESPHOME_VERSION:-2026.4.0}"
|
||||
ESPHOME_VERSION="${REC_ESPHOME_VERSION:-2026.5.1}"
|
||||
|
||||
echo "microWakeWord Trainer UI (Docker)"
|
||||
echo "-> ROOTDIR: ${ROOTDIR}"
|
||||
@@ -26,6 +26,16 @@ echo "-> URL: http://localhost:${PORT}/"
|
||||
|
||||
mkdir -p "${DATA_DIR}"
|
||||
|
||||
# Install the trainer UI's Python dependencies with ${PIP}.
# fastapi, uvicorn[standard], python-multipart and esphome are pinned to the
# *_VERSION variables defined earlier in this script; silero-vad and numpy only
# carry a minimum version.
# NOTE(review): ${PIP} is expanded unquoted, so it may hold a multi-word
# command -- confirm before quoting it.
install_ui_deps() {
  ${PIP} install \
    "fastapi==${FASTAPI_VERSION}" \
    "uvicorn[standard]==${UVICORN_VERSION}" \
    "python-multipart==${PY_MULTIPART_VERSION}" \
    "esphome==${ESPHOME_VERSION}" \
    "silero-vad>=5.0.0" \
    "numpy>=1.24.0"
}
|
||||
|
||||
# -----------------------------
|
||||
# Trainer UI venv (separate)
|
||||
# -----------------------------
|
||||
@@ -40,32 +50,54 @@ source "${VENV_DIR}/bin/activate"
|
||||
if [[ ! -f "${PIN_FILE}" ]]; then
|
||||
echo "Installing pinned trainer UI deps"
|
||||
${PIP} install -U pip setuptools wheel
|
||||
${PIP} install \
|
||||
"fastapi==${FASTAPI_VERSION}" \
|
||||
"uvicorn[standard]==${UVICORN_VERSION}" \
|
||||
"python-multipart==${PY_MULTIPART_VERSION}" \
|
||||
"esphome==${ESPHOME_VERSION}"
|
||||
install_ui_deps
|
||||
touch "${PIN_FILE}"
|
||||
else
|
||||
echo "Reusing existing trainer UI venv (no upgrades)"
|
||||
if ! "${PY}" - "${ESPHOME_VERSION}" <<'PY' >/dev/null 2>&1
|
||||
import importlib.metadata
|
||||
if ! "${PY}" - "${FASTAPI_VERSION}" "${UVICORN_VERSION}" "${PY_MULTIPART_VERSION}" "${ESPHOME_VERSION}" <<'PY' >/dev/null 2>&1
|
||||
import importlib.metadata as md
|
||||
import sys
|
||||
|
||||
expected = sys.argv[1]
|
||||
installed = importlib.metadata.version("esphome")
|
||||
raise SystemExit(0 if installed == expected else 1)
|
||||
fastapi_version, uvicorn_version, multipart_version, esphome_version = sys.argv[1:5]
|
||||
|
||||
def version_tuple(value):
    """Return the leading numeric release components of a version string.

    The string is split on "." (after treating "-" as "."). Purely numeric
    tokens are collected in order. Parsing stops at the first token that is
    not purely numeric; only the digits at the *start* of that token are
    kept, so a local or pre-release suffix cannot leak into the number
    ("0+cu121" contributes 0, not 121; "0rc1" contributes 0, not 1).
    """
    parts = []
    for token in str(value).replace("-", ".").split("."):
        if token.isdigit():
            parts.append(int(token))
            continue
        # Keep only the digits that prefix the token; everything after the
        # first non-digit is a suffix and is ignored.
        leading = token[: len(token) - len(token.lstrip("0123456789"))]
        if leading:
            parts.append(int(leading))
        break
    return tuple(parts)
|
||||
|
||||
exact = {
|
||||
"fastapi": fastapi_version,
|
||||
"uvicorn": uvicorn_version,
|
||||
"python-multipart": multipart_version,
|
||||
"esphome": esphome_version,
|
||||
}
|
||||
minimum = {
|
||||
"silero-vad": "5.0.0",
|
||||
"numpy": "1.24.0",
|
||||
}
|
||||
present = ("torch", "zeroconf")
|
||||
|
||||
for package, expected in exact.items():
|
||||
if md.version(package) != expected:
|
||||
raise SystemExit(1)
|
||||
for package, minimum_version in minimum.items():
|
||||
if version_tuple(md.version(package)) < version_tuple(minimum_version):
|
||||
raise SystemExit(1)
|
||||
for package in present:
|
||||
md.version(package)
|
||||
PY
|
||||
then
|
||||
echo "Firmware tab dependencies missing or stale; installing ESPHome firmware dependencies"
|
||||
${PIP} install \
|
||||
"fastapi==${FASTAPI_VERSION}" \
|
||||
"uvicorn[standard]==${UVICORN_VERSION}" \
|
||||
"python-multipart==${PY_MULTIPART_VERSION}" \
|
||||
"esphome==${ESPHOME_VERSION}"
|
||||
echo "UI dependencies missing or stale; installing recorder dependencies"
|
||||
install_ui_deps
|
||||
fi
|
||||
fi
|
||||
|
||||
# -----------------------------
|
||||
# Trainer server env
|
||||
# -----------------------------
|
||||
|
||||
1130
static/index.html
1130
static/index.html
File diff suppressed because it is too large
Load Diff
@@ -37,10 +37,11 @@ STATIC_DIR = Path(os.environ.get("STATIC_DIR", str(ROOT_DIR / "static"))).resolv
|
||||
PERSONAL_DIR = Path(os.environ.get("PERSONAL_DIR", str(DATA_DIR / "personal_samples"))).resolve()
|
||||
CAPTURED_DIR = Path(os.environ.get("CAPTURED_DIR", str(DATA_DIR / "captured_audio"))).resolve()
|
||||
NEGATIVE_DIR = Path(os.environ.get("NEGATIVE_DIR", str(DATA_DIR / "negative_samples"))).resolve()
|
||||
TRIM_HISTORY_DIR = Path(os.environ.get("TRIM_HISTORY_DIR", str(DATA_DIR / "trim_history"))).resolve()
|
||||
TRIM_HISTORY_DIR.mkdir(parents=True, exist_ok=True)
|
||||
TRAINED_WAKE_WORDS_DIR = Path(
|
||||
os.environ.get("TRAINED_WAKE_WORDS_DIR", str(DATA_DIR / "trained_wake_words"))
|
||||
).resolve()
|
||||
|
||||
CLI_DIR = Path(os.environ.get("CLI_DIR", str(ROOT_DIR / "cli"))).resolve()
|
||||
PIPER_ROOT = DATA_DIR / "tools" / "piper-sample-generator"
|
||||
PIPER_VOICES_DIR = PIPER_ROOT / "voices"
|
||||
@@ -85,12 +86,17 @@ FIRMWARE_MAX_LOG_LINES = int(os.environ.get("FIRMWARE_MAX_LOG_LINES", "500"))
|
||||
FIRMWARE_GITHUB_OWNER = os.environ.get("FIRMWARE_GITHUB_OWNER", "TaterTotterson")
|
||||
FIRMWARE_GITHUB_REPO = os.environ.get("FIRMWARE_GITHUB_REPO", "microWakeWords")
|
||||
FIRMWARE_GITHUB_REF = os.environ.get("FIRMWARE_GITHUB_REF", "main")
|
||||
WAKE_SOUND_CATALOG_CACHE_TTL_SECONDS = int(os.environ.get("WAKE_SOUND_CATALOG_CACHE_TTL_SECONDS", "600"))
|
||||
FIRMWARE_PLATFORMIO_DIR = FIRMWARE_CACHE_DIR / "platformio"
|
||||
FIRMWARE_HOME_DIR = FIRMWARE_CACHE_DIR / "home"
|
||||
FIRMWARE_XDG_CACHE_DIR = FIRMWARE_CACHE_DIR / "cache"
|
||||
FIRMWARE_ESPHOME_DATA_DIR = FIRMWARE_CACHE_DIR / "esphome_data"
|
||||
FIRMWARE_PROFILE_FILE = Path(
|
||||
os.environ.get("FIRMWARE_PROFILE_FILE", str(FIRMWARE_CACHE_DIR / "profiles.json"))
|
||||
).resolve()
|
||||
WAKE_SOUND_MANIFEST_PATHS = ("wake_sound_manifest.json", "wake-sound-manifest.json")
|
||||
WAKE_SOUND_CATALOG_CACHE: Dict[str, Any] = {"ts": 0.0, "payload": {}}
|
||||
WAKE_SOUND_CATALOG_LOCK = threading.Lock()
|
||||
TRAIN_LOG_TAIL_LINES = int(os.environ.get("REC_TRAIN_LOG_TAIL_LINES", "400"))
|
||||
TRAIN_LOG_MAX_BYTES = int(os.environ.get("REC_TRAIN_LOG_MAX_BYTES", str(512 * 1024)))
|
||||
|
||||
@@ -113,6 +119,33 @@ FIRMWARE_TEMPLATE_SPECS = (
|
||||
"fixed_keys": {"node_name"},
|
||||
"auto_keys": {"ha_voice_ip"},
|
||||
},
|
||||
{
|
||||
"key": "respeaker_lite",
|
||||
"label": "ReSpeaker Lite (respeakerLite-TaterTimer.yaml)",
|
||||
"path": "respeakerLite-TaterTimer.yaml",
|
||||
"identity_key": "device_name",
|
||||
"friendly_key": "friendly_name",
|
||||
"fixed_keys": {"device_name"},
|
||||
"auto_keys": {"ha_voice_ip"},
|
||||
},
|
||||
{
|
||||
"key": "koala",
|
||||
"label": "Koala Satellite (koala-TaterTimer.yaml)",
|
||||
"path": "koala-TaterTimer.yaml",
|
||||
"identity_key": "device_name",
|
||||
"friendly_key": "friendly_name",
|
||||
"fixed_keys": {"device_name"},
|
||||
"auto_keys": {"ha_voice_ip"},
|
||||
},
|
||||
{
|
||||
"key": "respeaker_xvf3800",
|
||||
"label": "ReSpeaker XVF3800 (respeakerXVF3800-TaterTimer.yaml)",
|
||||
"path": "respeakerXVF3800-TaterTimer.yaml",
|
||||
"identity_key": "device_name",
|
||||
"friendly_key": "friendly_name",
|
||||
"fixed_keys": {"device_name"},
|
||||
"auto_keys": {"ha_voice_ip"},
|
||||
},
|
||||
)
|
||||
|
||||
app = FastAPI(title="microWakeWord Personal Samples")
|
||||
@@ -164,11 +197,58 @@ FIRMWARE_LOCK = threading.Lock()
|
||||
FIRMWARE_SESSIONS: Dict[str, Dict[str, Any]] = {}
|
||||
ANSI_ESCAPE_RE = re.compile(r"\x1B(?:\[[0-?]*[ -/]*[@-~]|[@-Z\\-_])")
|
||||
|
||||
# --- Silero VAD (lazy-loaded) ---
|
||||
_silero_vad_model = None
|
||||
_silero_vad_utils = None
|
||||
_SILERO_VAD_LOCK = threading.Lock()
|
||||
VAD_SELECTION_PAD_START_S = 0.08
|
||||
VAD_SELECTION_PAD_END_S = 0.08
|
||||
|
||||
|
||||
def _load_silero_vad():
    """Lazy-load the Silero VAD model on first use.

    Returns:
        A ``(model, utils)`` tuple, where ``utils`` is a dict holding the
        imported ``torch`` module under the key ``"torch"``.

    The result is stored in module globals so later calls skip the imports
    and the model load. The lock-free fast path only returns once *both*
    globals are set, and the slow path publishes ``utils`` before ``model``,
    so a concurrent caller can never observe a model paired with ``None``.
    """
    global _silero_vad_model, _silero_vad_utils
    model, utils = _silero_vad_model, _silero_vad_utils
    if model is not None and utils is not None:
        return model, utils
    with _SILERO_VAD_LOCK:
        # Re-check under the lock: another thread may have finished loading
        # while this one was waiting.
        if _silero_vad_model is None or _silero_vad_utils is None:
            import torch
            import silero_vad

            model = silero_vad.load_silero_vad()
            model.eval()
            _silero_vad_utils = {"torch": torch}
            _silero_vad_model = model
        return _silero_vad_model, _silero_vad_utils
|
||||
|
||||
|
||||
def _detect_speech_segments(wav_bytes: bytes) -> List[Dict[str, float]]:
    """Run Silero VAD over WAV bytes and return speech spans in seconds.

    Every frame is read from the WAV container, interpreted as signed
    16-bit samples and scaled by 1/32768 to float32 before being handed to
    ``get_speech_timestamps`` with a fixed 16000 Hz sampling rate.

    Returns:
        A list of ``{"start": ..., "end": ...}`` dicts rounded to 3 decimals.

    NOTE(review): the WAV header's rate, channel count and sample width are
    not inspected here -- callers are presumably expected to pass 16 kHz
    mono 16-bit audio; confirm against the call sites.
    """
    model, utils = _load_silero_vad()
    torch = utils["torch"]
    import numpy as np
    from silero_vad.utils_vad import get_speech_timestamps

    with wave.open(io.BytesIO(wav_bytes), "rb") as reader:
        frame_count = reader.getnframes()
        pcm = reader.readframes(frame_count)

    int_samples = np.frombuffer(pcm, dtype=np.int16)
    float_samples = int_samples.astype(np.float32) / 32768.0
    waveform = torch.from_numpy(float_samples)

    vad_options = {
        "sampling_rate": 16000,
        "threshold": 0.5,
        "min_speech_duration_ms": 150,
        "min_silence_duration_ms": 100,
        "return_seconds": True,
    }
    spans = get_speech_timestamps(waveform, model, **vad_options)

    segments = []
    for span in spans:
        segments.append({"start": round(span["start"], 3), "end": round(span["end"], 3)})
    return segments
|
||||
|
||||
|
||||
class _FirmwareYamlLoader(yaml.SafeLoader):
|
||||
pass
|
||||
|
||||
|
||||
class _FirmwareYamlDumper(yaml.SafeDumper):
|
||||
pass
|
||||
|
||||
@@ -1004,7 +1084,7 @@ def _list_captured_items() -> List[Dict[str, Any]]:
|
||||
def _sample_item_from_path(audio_path: Path, bucket: str) -> Dict[str, Any]:
|
||||
meta = _load_sidecar_json(audio_path)
|
||||
stat = audio_path.stat()
|
||||
final_format = meta.get("final_format") or _inspect_wav_bytes(audio_path.read_bytes()) or {}
|
||||
final_format = meta.get("final_format") or meta.get("detected_format") or _inspect_wav_bytes(audio_path.read_bytes()) or {}
|
||||
return {
|
||||
"bucket": bucket,
|
||||
"saved_as": audio_path.name,
|
||||
@@ -1016,6 +1096,8 @@ def _sample_item_from_path(audio_path: Path, bucket: str) -> Dict[str, Any]:
|
||||
"reviewed_at": meta.get("reviewed_at") or "",
|
||||
"created_at": datetime.fromtimestamp(stat.st_mtime, tz=timezone.utc).isoformat(),
|
||||
"converted": bool(meta.get("converted")),
|
||||
"trimmed": bool(meta.get("trimmed")),
|
||||
"source_file": meta.get("source_file") or "",
|
||||
"final_format": final_format,
|
||||
"message": meta.get("message") or "",
|
||||
"size_bytes": stat.st_size,
|
||||
@@ -1031,9 +1113,10 @@ def _list_sample_items(directory: Path, bucket: str) -> List[Dict[str, Any]]:
|
||||
items.append(_sample_item_from_path(audio_path, bucket))
|
||||
except Exception:
|
||||
continue
|
||||
# Untrimmed first (stable sort preserves mtime order within each group).
|
||||
items.sort(key=lambda x: x.get("trimmed", False))
|
||||
return items
|
||||
|
||||
|
||||
def _samples_payload() -> Dict[str, Any]:
|
||||
takes = _sync_personal_samples_state()
|
||||
personal_items = _list_sample_items(PERSONAL_DIR, "personal")
|
||||
@@ -1449,6 +1532,93 @@ def _load_firmware_template_text(spec: Dict[str, Any]) -> tuple[str, str]:
|
||||
raise RuntimeError(f"Could not download firmware template from {url}: {exc}") from exc
|
||||
|
||||
|
||||
def _wake_sound_label_from_slug(slug: str) -> str:
    """Turn a slug such as ``hey_tater`` into a title-cased display label.

    Runs of ``_``, ``-`` and ``.`` become single spaces. A blank slug, or one
    made only of separators, yields the fallback label ``"Wake Sound"``.
    """
    fallback = "Wake Sound"
    cleaned = str(slug or "").strip()
    if cleaned:
        words = re.sub(r"[_\-.]+", " ", cleaned).strip()
        return words.title() or fallback
    return fallback
|
||||
|
||||
|
||||
def _wake_sound_entries_from_manifest(payload: Any) -> List[Dict[str, str]]:
    """Extract ``{"value": url, "label": text}`` picker entries from a manifest.

    ``payload`` may be a list of rows, or a dict whose first list-valued key
    among ``entries``, ``wake_sounds``, ``sounds``, ``audio``, ``items`` holds
    the rows. Non-dict rows are ignored. A row's URL is the first truthy
    value among several URL-ish fields; when none is set, a ``path`` field is
    resolved through ``_firmware_raw_url``. Rows without a URL and repeated
    URLs are dropped. The result is sorted by lower-cased label, then URL.
    """
    container_keys = ("entries", "wake_sounds", "sounds", "audio", "items")
    url_fields = (
        "url",
        "download_url",
        "audio_url",
        "sound_url",
        "wake_sound_url",
        "wake_word_triggered_sound_file",
    )

    rows: List[Any] = []
    if isinstance(payload, list):
        rows = list(payload)
    elif isinstance(payload, dict):
        for container_key in container_keys:
            candidate = payload.get(container_key)
            if isinstance(candidate, list):
                rows = list(candidate)
                break

    collected: List[Dict[str, str]] = []
    known_urls = set()
    for row in rows:
        if not isinstance(row, dict):
            continue

        first_url = next((found for field in url_fields if (found := row.get(field))), "")
        url = str(first_url).strip()
        path = str(row.get("path") or "").strip()
        if path and not url:
            url = _firmware_raw_url(path)
        if not url or url in known_urls:
            continue
        known_urls.add(url)

        # The slug falls back to the file stem of the path (or URL).
        slug_source = row.get("slug") or row.get("name") or row.get("key") or Path(path or url).stem
        slug = str(slug_source).strip()
        label = str(row.get("label") or row.get("title") or _wake_sound_label_from_slug(slug)).strip()
        collected.append({"value": url, "label": label})

    collected.sort(key=lambda item: (item["label"].lower(), item["value"]))
    return collected
|
||||
|
||||
|
||||
def _load_wake_sound_catalog() -> Dict[str, Any]:
    """Return the wake-sound catalog, using an in-memory cache with a TTL.

    A cached payload younger than ``WAKE_SOUND_CATALOG_CACHE_TTL_SECONDS`` is
    returned as a deep copy. Otherwise each path in
    ``WAKE_SOUND_MANIFEST_PATHS`` is fetched and parsed in order, and the
    first manifest that yields entries wins. If none does, an empty catalog
    is returned whose ``warning`` is the first recorded failure, or a generic
    message when nothing raised. Both outcomes are cached, stamped with the
    time this call started.

    Returns:
        A dict with the keys ``entries``, ``warning`` and ``source_label``.
    """
    started_at = time.time()
    with WAKE_SOUND_CATALOG_LOCK:
        cached_at = float(WAKE_SOUND_CATALOG_CACHE.get("ts") or 0.0)
        snapshot = WAKE_SOUND_CATALOG_CACHE.get("payload")
        is_fresh = (started_at - cached_at) < WAKE_SOUND_CATALOG_CACHE_TTL_SECONDS
        if isinstance(snapshot, dict) and is_fresh:
            return copy.deepcopy(snapshot)

    def _remember(result: Dict[str, Any]) -> Dict[str, Any]:
        # Cache a private copy so callers may mutate what they receive.
        with WAKE_SOUND_CATALOG_LOCK:
            WAKE_SOUND_CATALOG_CACHE["ts"] = started_at
            WAKE_SOUND_CATALOG_CACHE["payload"] = copy.deepcopy(result)
        return result

    failures: List[str] = []
    for manifest_path in WAKE_SOUND_MANIFEST_PATHS:
        manifest_url = _firmware_raw_url(manifest_path)
        try:
            manifest = json.loads(_fetch_text_url(manifest_url, timeout=20))
            found = _wake_sound_entries_from_manifest(manifest)
            if found:
                return _remember({"entries": found, "warning": "", "source_label": manifest_url})
        except Exception as exc:
            failures.append(f"{manifest_path}: {exc}")

    warning = failures[0] if failures else "Wake sound catalog unavailable."
    return _remember({"entries": [], "warning": warning, "source_label": ""})
|
||||
|
||||
|
||||
def _wake_sound_picker_options(catalog: Dict[str, Any]) -> List[Dict[str, str]]:
    """Build the wake-sound dropdown options from a catalog.

    The list always starts with the ``__custom__`` ("Custom URL") option,
    followed by a shallow copy of each dict row in ``catalog["entries"]``
    when that value is a list.
    """
    options: List[Dict[str, str]] = [{"value": "__custom__", "label": "Custom URL"}]
    rows = catalog.get("entries")
    if isinstance(rows, list):
        options.extend(dict(row) for row in rows if isinstance(row, dict))
    return options
|
||||
|
||||
|
||||
def _extract_substitution_sections(raw_text: str) -> Dict[str, str]:
|
||||
section_map: Dict[str, str] = {}
|
||||
in_substitutions = False
|
||||
@@ -1488,19 +1658,115 @@ def _load_firmware_profiles() -> Dict[str, Dict[str, str]]:
|
||||
return {}
|
||||
|
||||
|
||||
def _save_firmware_profile(profile_key: str, values: Dict[str, str]) -> None:
    """Store ``values`` under ``profile_key`` in the firmware profile file.

    The existing profiles are loaded, the entry for ``profile_key`` is
    replaced with a string-to-string copy of ``values`` (items whose key
    stringifies to "" are dropped), and the whole mapping is written back as
    sorted, indented JSON. The parent directory is created if missing.
    """
    FIRMWARE_PROFILE_FILE.parent.mkdir(parents=True, exist_ok=True)
    all_profiles = _load_firmware_profiles()
    cleaned: Dict[str, str] = {}
    for name, setting in values.items():
        if str(name):
            cleaned[str(name)] = str(setting)
    all_profiles[profile_key] = cleaned
    serialized = json.dumps(all_profiles, indent=2, sort_keys=True)
    FIRMWARE_PROFILE_FILE.write_text(serialized, encoding="utf-8")
|
||||
|
||||
|
||||
def _normalize_firmware_profile_update(template_key: str, values: Dict[str, Any]) -> Dict[str, str]:
|
||||
ctx = _load_firmware_template_context(template_key)
|
||||
def _firmware_profile_target(raw_host: Any = "", raw_port: Any = "") -> tuple[str, str]:
    """Normalize a user-supplied device address into ``(host, port)`` strings.

    Args:
        raw_host: Hostname, IP, ``host:port``, ``[ipv6]:port`` or a full URL.
        raw_port: Optional explicit port; takes precedence over a port
            embedded in ``raw_host``.

    Returns:
        ``(host, port)`` with a lower-cased host and a decimal port string,
        or ``("", "")`` when no host can be determined. A missing,
        non-numeric or out-of-range port falls back to
        ``FIRMWARE_DEFAULT_OTA_PORT``, and port 6053 is remapped to that
        default as well.
    """
    host = str(raw_host or "").strip()
    port = str(raw_port or "").strip()

    if "://" in host:
        parsed = urlparse(host)
        host = parsed.hostname or ""
        # ``parsed.port`` raises ValueError for a malformed or out-of-range
        # port; treat that as "no port given" instead of crashing.
        try:
            url_port = parsed.port
        except ValueError:
            url_port = None
        if not port and url_port:
            port = str(url_port)

    host = host.strip().strip("/")
    if host.startswith("[") and "]" in host:
        # Bracketed IPv6 literal, optionally followed by ":port".
        literal, _, remainder = host[1:].partition("]")
        if not port and remainder.startswith(":") and remainder[1:].isdigit():
            port = remainder[1:]
        host = literal
    elif host.count(":") == 1 and not port:
        # Exactly one colon means "host:port"; bare IPv6 has several.
        maybe_host, maybe_port = host.rsplit(":", 1)
        if maybe_port.isdigit():
            host = maybe_host
            port = maybe_port

    host = host.strip("[]").strip().lower()
    if not host:
        return "", ""

    try:
        parsed_port = int(port) if port else int(FIRMWARE_DEFAULT_OTA_PORT)
    except (TypeError, ValueError):
        return host, str(FIRMWARE_DEFAULT_OTA_PORT)
    # NOTE(review): 6053 is presumably the device's API port rather than its
    # OTA port -- confirm; the remap itself is original behavior.
    if parsed_port == 6053 or not 0 < parsed_port < 65536:
        return host, str(FIRMWARE_DEFAULT_OTA_PORT)
    return host, str(parsed_port)
|
||||
|
||||
|
||||
def _firmware_profile_key_for_target(raw_host: Any = "", raw_port: Any = "") -> str:
    """Return the ``device:<host>:<port>`` profile key for a target.

    The host and port are normalized by ``_firmware_profile_target``. An
    empty string is returned when no host can be determined.
    """
    host, port = _firmware_profile_target(raw_host, raw_port)
    if not host:
        return ""
    return f"device:{host}:{port}"
|
||||
|
||||
|
||||
def _firmware_profile_key(template_key: Any = "", raw_host: Any = "", raw_port: Any = "") -> str:
    """Compose the profile lookup key for a template and/or target device.

    With both a resolvable target and a template the key is
    ``<device key>:template:<template>``. With only one of them, that part is
    returned alone. With neither, the result is "".
    """
    device_part = _firmware_profile_key_for_target(raw_host, raw_port)
    template_part = str(template_key or "").strip().lower()
    if not device_part:
        return template_part
    if not template_part:
        return device_part
    return f"{device_part}:template:{template_part}"
|
||||
|
||||
|
||||
def _firmware_cache_slug(*parts: Any) -> str:
    """Build a lower-case slug of at most 96 characters from ``parts``.

    Falsy or blank parts are skipped and the rest are joined with ``__``.
    Runs of characters outside ``A-Z a-z 0-9 _ . -`` collapse to a single
    ``_``, and leading/trailing ``.``, ``_`` and ``-`` are trimmed before
    truncation. An empty result becomes ``"default"``.
    """
    pieces = []
    for part in parts:
        text = str(part or "").strip()
        if text:
            pieces.append(text)
    collapsed = re.sub(r"[^A-Za-z0-9_.-]+", "_", "__".join(pieces))
    trimmed = collapsed.strip("._-")[:96]
    return trimmed.lower() if trimmed else "default"
|
||||
|
||||
|
||||
def _firmware_build_cache_path(
    template_key: str,
    normalized: Dict[str, str],
    host: str,
    port: Any = None,
    identity_key: str = "",
    friendly_key: str = "",
) -> Path:
    """Return the build cache directory for one template/device combination.

    The path is ``FIRMWARE_CACHE_DIR / "builds" / <template slug> / <target
    slug>``. The target slug combines a device identity with the normalized
    host and port. The identity is the first truthy value in ``normalized``
    under ``identity_key``, ``friendly_key``, ``node_name``, ``device_name``,
    ``friendly_name`` or ``name``, then the normalized host, then the literal
    ``"device"``.
    """
    normalized_host, normalized_port = _firmware_profile_target(host, port)
    template_slug = _firmware_cache_slug(template_key, "template")

    lookup_order = (
        str(identity_key or "").strip(),
        str(friendly_key or "").strip(),
        "node_name",
        "device_name",
        "friendly_name",
        "name",
    )
    device_identity = ""
    for field in lookup_order:
        # Blank identity/friendly keys are skipped rather than looked up.
        if not field:
            continue
        candidate = normalized.get(field)
        if candidate:
            device_identity = candidate
            break
    if not device_identity:
        device_identity = normalized_host or "device"

    target_slug = _firmware_cache_slug(device_identity, normalized_host, normalized_port)
    return FIRMWARE_CACHE_DIR / "builds" / template_slug / target_slug
|
||||
|
||||
|
||||
def _load_firmware_profile(template_key: str, profile_key: str = "") -> Dict[str, str]:
    """Return a copy of the best-matching saved firmware profile.

    Lookup order, returning the first entry that is a dict:
      1. ``profile_key`` itself, when given;
      2. the part of ``profile_key`` before ``":template:"``, when present;
      3. ``template_key``.
    An empty dict is returned when none of them matches.
    """
    profiles = _load_firmware_profiles()

    lookup_keys = []
    if profile_key:
        lookup_keys.append(profile_key)
        if ":template:" in profile_key:
            lookup_keys.append(profile_key.split(":template:", 1)[0])
    lookup_keys.append(template_key)

    for lookup_key in lookup_keys:
        stored = profiles.get(lookup_key)
        if isinstance(stored, dict):
            return dict(stored)
    return {}
|
||||
|
||||
|
||||
def _firmware_profile_values_for_template(profile: Dict[str, Any], substitutions: Dict[str, Any]) -> Dict[str, str]:
    """Filter a saved profile down to the keys relevant to one template.

    The kept keys are the template's substitution names (stripped) plus the
    bookkeeping keys ``__target_host``, ``__target_port``,
    ``wake_sound_catalog`` and ``wake_word_choice``. Only keys actually
    present in ``profile`` are returned; falsy values become "".
    """
    bookkeeping = ("__target_host", "__target_port", "wake_sound_catalog", "wake_word_choice")
    wanted = {str(name or "").strip() for name in substitutions}
    wanted.update(bookkeeping)

    kept: Dict[str, str] = {}
    for name in wanted:
        if name and name in profile:
            kept[name] = str(profile.get(name) or "")
    return kept
|
||||
|
||||
|
||||
def _normalize_firmware_profile_update(template_key: str, values: Dict[str, Any], profile_key: str = "") -> Dict[str, str]:
|
||||
ctx = _load_firmware_template_context(template_key, profile_key)
|
||||
spec = ctx["spec"]
|
||||
profile = ctx.get("profile") if isinstance(ctx.get("profile"), dict) else {}
|
||||
substitutions = ctx["substitutions"]
|
||||
normalized: Dict[str, str] = dict(profile)
|
||||
normalized = _firmware_profile_values_for_template(profile, substitutions)
|
||||
fixed_keys = set(spec.get("fixed_keys") or set())
|
||||
identity_key = str(spec.get("identity_key") or "").strip()
|
||||
if identity_key:
|
||||
@@ -1527,6 +1793,12 @@ def _normalize_firmware_profile_update(template_key: str, values: Dict[str, Any]
|
||||
elif key_text not in normalized:
|
||||
normalized[key_text] = "true" if str(default).lower() == "true" else "false"
|
||||
continue
|
||||
if key_text == "wake_word_model_url":
|
||||
if key_text in values:
|
||||
normalized[key_text] = _local_trained_wake_word_url(values.get(key_text))
|
||||
elif key_text not in normalized:
|
||||
normalized[key_text] = ""
|
||||
continue
|
||||
if key_text in values:
|
||||
normalized[key_text] = str(values.get(key_text) or "").strip()
|
||||
elif key_text not in normalized:
|
||||
@@ -1535,6 +1807,11 @@ def _normalize_firmware_profile_update(template_key: str, values: Dict[str, Any]
|
||||
wake_word_choice = str(values.get("wake_word_choice") or "").strip()
|
||||
if wake_word_choice:
|
||||
normalized["wake_word_choice"] = wake_word_choice
|
||||
wake_sound_choice = str(values.get("wake_sound_catalog") or "").strip()
|
||||
if wake_sound_choice:
|
||||
normalized["wake_sound_catalog"] = wake_sound_choice
|
||||
if wake_sound_choice != "__custom__" and "wake_word_triggered_sound_file" in substitutions:
|
||||
normalized["wake_word_triggered_sound_file"] = wake_sound_choice
|
||||
|
||||
target_host = str(values.get("__target_host") or "").strip()
|
||||
target_port = str(values.get("__target_port") or "").strip()
|
||||
@@ -1548,7 +1825,38 @@ def _normalize_firmware_profile_update(template_key: str, values: Dict[str, Any]
|
||||
return normalized
|
||||
|
||||
|
||||
def _load_firmware_template_context(template_key: str) -> Dict[str, Any]:
|
||||
def _local_trained_wake_word_url(value: Any) -> str:
    """Return ``value`` as a stripped string if it looks like a trained wake-word URL.

    A value qualifies when it contains the path fragment
    ``/api/trained_wake_words/``; anything else yields "".
    """
    candidate = str(value or "").strip()
    if "/api/trained_wake_words/" not in candidate:
        return ""
    return candidate
|
||||
|
||||
|
||||
def _selected_trained_wake_word(
|
||||
trained_wake_words: List[Dict[str, Any]],
|
||||
profile: Dict[str, Any],
|
||||
substitutions: Dict[str, Any],
|
||||
) -> Dict[str, Any] | None:
|
||||
if not trained_wake_words:
|
||||
return None
|
||||
|
||||
saved_choice = str(profile.get("wake_word_choice") or "").strip()
|
||||
current_wake_word_name = str(
|
||||
profile.get("wake_word_name") or _template_default_string(substitutions.get("wake_word_name"))
|
||||
).strip()
|
||||
current_wake_word_url = str(profile.get("wake_word_model_url") or "").strip()
|
||||
|
||||
def match(predicate):
|
||||
return next((row for row in trained_wake_words if predicate(row)), None)
|
||||
|
||||
return (
|
||||
match(lambda row: str(row.get("key") or "") == saved_choice)
|
||||
or match(lambda row: str(row.get("json_url") or "") == current_wake_word_url)
|
||||
or match(lambda row: str(row.get("model_url") or "") == current_wake_word_url)
|
||||
or match(lambda row: str(row.get("wake_word_name") or "") == current_wake_word_name)
|
||||
or trained_wake_words[0]
|
||||
)
|
||||
|
||||
|
||||
def _load_firmware_template_context(template_key: str, profile_key: str = "") -> Dict[str, Any]:
|
||||
spec = _firmware_template_spec(template_key)
|
||||
raw_text, source_label = _load_firmware_template_text(spec)
|
||||
parsed = yaml.load(raw_text, Loader=_FirmwareYamlLoader)
|
||||
@@ -1564,12 +1872,12 @@ def _load_firmware_template_context(template_key: str) -> Dict[str, Any]:
|
||||
"template_doc": parsed,
|
||||
"substitutions": dict(substitutions),
|
||||
"sections": _extract_substitution_sections(raw_text),
|
||||
"profile": _load_firmware_profiles().get(str(spec.get("key") or ""), {}),
|
||||
"profile": _load_firmware_profile(str(spec.get("key") or ""), profile_key),
|
||||
}
|
||||
|
||||
|
||||
def _firmware_template_fields(template_key: str, base_url: str = "") -> List[Dict[str, Any]]:
|
||||
ctx = _load_firmware_template_context(template_key)
|
||||
def _firmware_template_fields(template_key: str, base_url: str = "", profile_key: str = "") -> List[Dict[str, Any]]:
|
||||
ctx = _load_firmware_template_context(template_key, profile_key)
|
||||
spec = ctx["spec"]
|
||||
profile = ctx.get("profile") if isinstance(ctx.get("profile"), dict) else {}
|
||||
fields: List[Dict[str, Any]] = []
|
||||
@@ -1579,18 +1887,11 @@ def _firmware_template_fields(template_key: str, base_url: str = "") -> List[Dic
|
||||
fixed_keys.add(identity_key)
|
||||
hidden_keys = {"ha_voice_ip"} | set(spec.get("auto_keys") or set())
|
||||
trained_wake_words = _list_trained_wake_words(base_url)
|
||||
current_wake_word_name = str(
|
||||
profile.get("wake_word_name") or _template_default_string(ctx["substitutions"].get("wake_word_name"))
|
||||
).strip()
|
||||
saved_wake_word_choice = str(profile.get("wake_word_choice") or "").strip()
|
||||
selected_wake_word = next(
|
||||
(row["key"] for row in trained_wake_words if row.get("key") == saved_wake_word_choice),
|
||||
"",
|
||||
) or next(
|
||||
(row["key"] for row in trained_wake_words if row.get("wake_word_name") == current_wake_word_name),
|
||||
"",
|
||||
)
|
||||
wake_sound_catalog = _load_wake_sound_catalog()
|
||||
selected_wake_word_row = _selected_trained_wake_word(trained_wake_words, profile, ctx["substitutions"])
|
||||
selected_wake_word = str(selected_wake_word_row.get("key") or "") if selected_wake_word_row else ""
|
||||
wake_picker_added = False
|
||||
wake_sound_picker_added = False
|
||||
|
||||
for key, raw_value in ctx["substitutions"].items():
|
||||
key_text = str(key or "").strip()
|
||||
@@ -1616,6 +1917,35 @@ def _firmware_template_fields(template_key: str, base_url: str = "") -> List[Dic
|
||||
)
|
||||
wake_picker_added = True
|
||||
|
||||
if key_text == "wake_word_triggered_sound_file" and not wake_sound_picker_added:
|
||||
wake_sound_entries = wake_sound_catalog.get("entries") if isinstance(wake_sound_catalog.get("entries"), list) else []
|
||||
current_sound_url = str(profile.get(key_text) or _template_default_string(raw_value) or "").strip()
|
||||
saved_sound_choice = str(profile.get("wake_sound_catalog") or "").strip()
|
||||
available_sound_urls = {str(row.get("value") or "") for row in wake_sound_entries if isinstance(row, dict)}
|
||||
if saved_sound_choice in available_sound_urls or saved_sound_choice == "__custom__":
|
||||
picker_value = saved_sound_choice
|
||||
else:
|
||||
picker_value = current_sound_url if current_sound_url in available_sound_urls else "__custom__"
|
||||
description = (
|
||||
f"Choose from {len(wake_sound_entries)} prebuilt wake sounds, or leave this on Custom URL and paste your own audio URL below."
|
||||
if wake_sound_entries
|
||||
else "Prebuilt wake-sound catalog is unavailable right now. You can still paste any custom audio URL below."
|
||||
)
|
||||
if wake_sound_catalog.get("warning") and not wake_sound_entries:
|
||||
description = f"{description} {wake_sound_catalog['warning']}".strip()
|
||||
fields.append(
|
||||
{
|
||||
"key": "wake_sound_catalog",
|
||||
"label": "Prebuilt Wake Sound",
|
||||
"type": "wake_sound_select",
|
||||
"value": picker_value,
|
||||
"options": _wake_sound_picker_options(wake_sound_catalog),
|
||||
"description": description,
|
||||
"section": "Wake Sound",
|
||||
}
|
||||
)
|
||||
wake_sound_picker_added = True
|
||||
|
||||
default = _template_default_string(raw_value)
|
||||
saved = str(profile.get(key_text) or "")
|
||||
field_type = "text"
|
||||
@@ -1640,11 +1970,20 @@ def _firmware_template_fields(template_key: str, base_url: str = "") -> List[Dic
|
||||
placeholder = "Your Wi-Fi SSID"
|
||||
description = "Required before build + flash."
|
||||
elif key_text == "wake_word_model_url":
|
||||
placeholder = "https://.../wake_word.json"
|
||||
value = str(selected_wake_word_row.get("json_url") or "") if selected_wake_word_row else ""
|
||||
placeholder = "Train or select a local wake word first"
|
||||
description = "Filled from the local trained wake-word picker."
|
||||
elif key_text == "wake_word_name":
|
||||
if selected_wake_word_row:
|
||||
value = str(selected_wake_word_row.get("wake_word_name") or selected_wake_word_row.get("key") or "")
|
||||
placeholder = "hey_tater"
|
||||
elif key_text == "wake_word_triggered_sound_file":
|
||||
placeholder = "https://.../wake-sound.mp3"
|
||||
description = "Pick a prebuilt wake sound above or paste any custom audio URL."
|
||||
section = ctx["sections"].get(key_text) or "Firmware"
|
||||
if key_text in {"wake_word_name", "wake_word_model_url"}:
|
||||
if key_text == "wake_word_triggered_sound_file":
|
||||
section = "Wake Sound"
|
||||
elif key_text in {"wake_word_name", "wake_word_model_url"}:
|
||||
section = "Micro Wake Word"
|
||||
elif key_text.endswith("_sound_file"):
|
||||
section = "Sounds"
|
||||
@@ -1678,12 +2017,19 @@ def _esphome_pythonpath() -> str:
|
||||
return os.pathsep.join(paths)
|
||||
|
||||
|
||||
def _render_firmware_config(template_key: str, values: Dict[str, Any], host: str, session_id: str) -> tuple[Path, Dict[str, str]]:
|
||||
ctx = _load_firmware_template_context(template_key)
|
||||
def _render_firmware_config(
|
||||
template_key: str,
|
||||
values: Dict[str, Any],
|
||||
host: str,
|
||||
session_id: str,
|
||||
port: Any = None,
|
||||
) -> tuple[Path, Dict[str, str], Path]:
|
||||
profile_key = _firmware_profile_key(template_key, host, port)
|
||||
ctx = _load_firmware_template_context(template_key, profile_key)
|
||||
spec = ctx["spec"]
|
||||
profile = ctx.get("profile") if isinstance(ctx.get("profile"), dict) else {}
|
||||
substitutions = ctx["substitutions"]
|
||||
normalized: Dict[str, str] = dict(profile)
|
||||
normalized = _firmware_profile_values_for_template(profile, substitutions)
|
||||
fixed_keys = set(spec.get("fixed_keys") or set())
|
||||
identity_key = str(spec.get("identity_key") or "").strip()
|
||||
if identity_key:
|
||||
@@ -1702,10 +2048,15 @@ def _render_firmware_config(template_key: str, values: Dict[str, Any], host: str
|
||||
normalized[key_text] = "true" if bool(raw_value) else "false"
|
||||
elif key_text == "ha_voice_ip":
|
||||
normalized[key_text] = host
|
||||
elif key_text == "wake_word_model_url":
|
||||
normalized[key_text] = _local_trained_wake_word_url(raw_value)
|
||||
else:
|
||||
normalized[key_text] = str(raw_value if raw_value is not None else "").strip() or _template_default_string(
|
||||
substitutions.get(key_text)
|
||||
)
|
||||
wake_sound_choice = str(values.get("wake_sound_catalog") or "").strip()
|
||||
if wake_sound_choice and wake_sound_choice != "__custom__" and "wake_word_triggered_sound_file" in substitutions:
|
||||
normalized["wake_word_triggered_sound_file"] = wake_sound_choice
|
||||
|
||||
missing = []
|
||||
if not normalized.get("wifi_ssid"):
|
||||
@@ -1714,22 +2065,36 @@ def _render_firmware_config(template_key: str, values: Dict[str, Any], host: str
|
||||
missing.append("Wi-Fi password")
|
||||
if not host:
|
||||
missing.append("device IP or hostname")
|
||||
if "wake_word_model_url" in substitutions and not normalized.get("wake_word_model_url"):
|
||||
missing.append("local trained wake word")
|
||||
if missing:
|
||||
raise RuntimeError(f"Missing required firmware values: {', '.join(missing)}.")
|
||||
|
||||
config = copy.deepcopy(ctx["template_doc"])
|
||||
config["substitutions"] = {key: str(normalized.get(key, "")) for key in substitutions.keys()}
|
||||
build_path = _firmware_build_cache_path(
|
||||
str(spec.get("key") or template_key),
|
||||
normalized,
|
||||
host,
|
||||
port,
|
||||
str(spec.get("identity_key") or ""),
|
||||
str(spec.get("friendly_key") or ""),
|
||||
)
|
||||
esphome_block = config.get("esphome") if isinstance(config.get("esphome"), dict) else None
|
||||
if isinstance(esphome_block, dict):
|
||||
esphome_block["build_path"] = str(FIRMWARE_CACHE_DIR / "builds" / session_id / str(spec.get("key") or template_key))
|
||||
esphome_block["build_path"] = str(build_path)
|
||||
config["esphome"] = esphome_block
|
||||
|
||||
session_dir = FIRMWARE_CACHE_DIR / "configs" / session_id
|
||||
session_dir.mkdir(parents=True, exist_ok=True)
|
||||
config_path = session_dir / f"{str(spec.get('key') or template_key)}.yaml"
|
||||
config_path = session_dir / f"{build_path.parent.name}__{build_path.name}.yaml"
|
||||
config_path.write_text(yaml.dump(config, Dumper=_FirmwareYamlDumper, sort_keys=False, allow_unicode=True), encoding="utf-8")
|
||||
_save_firmware_profile(str(spec.get("key") or template_key), normalized)
|
||||
return config_path, normalized
|
||||
normalized_host, normalized_port = _firmware_profile_target(host, port)
|
||||
if normalized_host:
|
||||
normalized["__target_host"] = normalized_host
|
||||
normalized["__target_port"] = normalized_port
|
||||
_save_firmware_profile(profile_key or str(spec.get("key") or template_key), normalized)
|
||||
return config_path, normalized, build_path
|
||||
|
||||
|
||||
def _firmware_session_payload(session_id: str) -> Dict[str, Any]:
|
||||
@@ -1779,11 +2144,13 @@ def _firmware_runner_env(*, include_esphome_pythonpath: bool = False) -> Dict[st
|
||||
FIRMWARE_HOME_DIR.mkdir(parents=True, exist_ok=True)
|
||||
FIRMWARE_XDG_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
||||
FIRMWARE_PLATFORMIO_DIR.mkdir(parents=True, exist_ok=True)
|
||||
FIRMWARE_ESPHOME_DATA_DIR.mkdir(parents=True, exist_ok=True)
|
||||
env = os.environ.copy()
|
||||
env.pop("PYTHONPATH", None)
|
||||
env["PYTHONUNBUFFERED"] = "1"
|
||||
env["HOME"] = str(FIRMWARE_HOME_DIR)
|
||||
env["XDG_CACHE_HOME"] = str(FIRMWARE_XDG_CACHE_DIR)
|
||||
env["ESPHOME_DATA_DIR"] = str(FIRMWARE_ESPHOME_DATA_DIR)
|
||||
env["PLATFORMIO_CORE_DIR"] = str(FIRMWARE_PLATFORMIO_DIR)
|
||||
env["PLATFORMIO_CACHE_DIR"] = str(FIRMWARE_PLATFORMIO_DIR / "cache")
|
||||
if include_esphome_pythonpath:
|
||||
@@ -1947,7 +2314,7 @@ def _run_firmware_build_flash_background(session_id: str):
|
||||
return
|
||||
|
||||
try:
|
||||
config_path, _normalized = _render_firmware_config(template_key, values, host, session_id)
|
||||
config_path, normalized, build_path = _render_firmware_config(template_key, values, host, session_id, port)
|
||||
except Exception as exc:
|
||||
_append_firmware_log(session_id, f"✗ Failed to prepare firmware config: {exc}")
|
||||
with FIRMWARE_LOCK:
|
||||
@@ -1974,6 +2341,9 @@ def _run_firmware_build_flash_background(session_id: str):
|
||||
_append_firmware_log(session_id, f"→ Template: {template_key}")
|
||||
_append_firmware_log(session_id, f"→ Device: {host}:{port}")
|
||||
_append_firmware_log(session_id, f"→ Config: {config_path}")
|
||||
_append_firmware_log(session_id, f"→ Build cache: {build_path}")
|
||||
if normalized.get("wake_word_triggered_sound_file"):
|
||||
_append_firmware_log(session_id, f"→ Wake sound: {normalized['wake_word_triggered_sound_file']}")
|
||||
_append_firmware_log(session_id, "→ Running: " + " ".join(command))
|
||||
|
||||
try:
|
||||
@@ -2012,6 +2382,12 @@ def _run_firmware_build_flash_background(session_id: str):
|
||||
"Tip: PlatformIO's ESP-IDF Python environment crashed while installing dependencies. "
|
||||
"Run Clean Build Files once, then retry the flash.",
|
||||
)
|
||||
if "pioarduino/registry" in joined_lines and "ninja-" in joined_lines and "status code '502'" in joined_lines:
|
||||
_append_firmware_log(
|
||||
session_id,
|
||||
"Tip: GitHub returned a 502 while PlatformIO was downloading Ninja. "
|
||||
"This is an upstream package download failure; retry the build in a few minutes.",
|
||||
)
|
||||
_append_firmware_log(session_id, f"✗ Firmware build + flash failed (exit_code={rc})")
|
||||
with FIRMWARE_LOCK:
|
||||
live = FIRMWARE_SESSIONS.get(session_id)
|
||||
@@ -2470,7 +2846,143 @@ def delete_sample(bucket: str, file_name: str):
|
||||
_remove_audio_with_sidecar(path)
|
||||
except FileNotFoundError as e:
|
||||
return JSONResponse({"ok": False, "error": str(e)}, status_code=404)
|
||||
return _samples_payload()
|
||||
return {"ok": True, "deleted_bucket": bucket, "deleted_file": file_name, "message": f"Deleted {file_name}"}
|
||||
|
||||
|
||||
@app.post("/api/samples/{bucket}/{file_name}/vad")
def vad_segments(bucket: str, file_name: str):
    """Suggest a trim window for a stored sample using voice-activity detection.

    Looks the sample up in the ``personal`` or ``negative`` bucket, runs
    ``_detect_speech_segments`` on its bytes and reports at most one segment:
    the first detected one lasting at least 250 ms, widened by the configured
    start/end padding and clamped to the clip bounds.

    Returns a JSON-serialisable dict on success, or a ``JSONResponse`` with
    status 404 (unknown bucket / missing file) or 500 (VAD failure).
    """
    sample_dir = {"personal": PERSONAL_DIR, "negative": NEGATIVE_DIR}.get(bucket)
    if sample_dir is None:
        return JSONResponse({"ok": False, "error": "Unknown sample bucket."}, status_code=404)

    try:
        audio_path = _resolve_audio_path(sample_dir, file_name)
    except FileNotFoundError as exc:
        return JSONResponse({"ok": False, "error": str(exc)}, status_code=404)

    audio_bytes = audio_path.read_bytes()
    try:
        detected = _detect_speech_segments(audio_bytes)
    except Exception as exc:
        return JSONResponse({"ok": False, "error": f"VAD failed: {str(exc)}"}, status_code=500)

    # Segments shorter than 250 ms are ignored; only the first remaining one
    # is reported.
    long_enough = [item for item in detected if (item["end"] - item["start"]) >= 0.25]
    if not long_enough:
        return {"ok": True, "file_name": file_name, "segments": [], "segment_count": 0}

    first = long_enough[0]
    wav_info = _inspect_wav_bytes(audio_bytes) or {}
    duration_s = float(wav_info.get("duration_s") or 0.0)

    # Deterministic padding lets VAD guide trimming without clipping quiet
    # wake-word edges.
    padded_start = max(0.0, round(first["start"] - VAD_SELECTION_PAD_START_S, 3))
    padded_end = round(first["end"] + VAD_SELECTION_PAD_END_S, 3)
    if duration_s > 0:
        # Only clamp when the clip length is actually known.
        padded_end = min(duration_s, padded_end)
    if padded_end <= padded_start:
        # Keep the window non-empty.
        padded_end = padded_start + 0.001

    return {
        "ok": True,
        "file_name": file_name,
        "segments": [{"start": padded_start, "end": padded_end}],
        "segment_count": 1,
    }
|
||||
|
||||
|
||||
def _parse_optional_trim_seconds(raw: str | None, label: str) -> float | None:
    """Convert an optional form value to float seconds.

    Returns ``None`` when the value is missing or empty. Raises ``ValueError``
    naming *label* when the value is present but not a number.
    """
    if not raw:
        return None
    try:
        return float(raw)
    except ValueError:
        raise ValueError(f"Invalid {label}: {raw!r}") from None


@app.post("/api/samples/trim")
async def trim_sample_upload(
    file: UploadFile = File(...),
    bucket: str = Form(...),
    source_file: str = Form(...),
    start_time: str | None = Form(None),
    end_time: str | None = Form(None),
):
    """Replace a stored sample with an uploaded, trimmed version of it.

    The previous audio (and its sidecar, if any) is first copied into
    ``TRIM_HISTORY_DIR`` so the trim can be undone; the backup file name is
    recorded in the sample's sidecar under ``undo_backup_file``.

    Args:
        file: The trimmed audio. Normalised via ``_normalize_audio_to_target_wav``
            when it is not already a WAV accepted by ``_is_target_wav``.
        bucket: ``"personal"`` or ``"negative"``.
        source_file: Name of the existing sample being replaced.
        start_time: Optional trim start, stored in the sidecar as ``trim_start_s``.
        end_time: Optional trim end, stored in the sidecar as ``trim_end_s``.

    Returns:
        A dict with the refreshed sample item, or a ``JSONResponse`` with
        status 400 (empty/invalid upload, malformed times) or 404 (unknown
        bucket, missing source sample).
    """
    bucket_map = {"personal": PERSONAL_DIR, "negative": NEGATIVE_DIR}
    directory = bucket_map.get(bucket)
    if directory is None:
        return JSONResponse({"ok": False, "error": "Unknown sample bucket."}, status_code=404)

    # Validate the trim bounds before anything on disk is touched; parsing
    # them after the overwrite could fail midway and leave the sample replaced
    # without an undo reference in its sidecar.
    try:
        trim_start_s = _parse_optional_trim_seconds(start_time, "start_time")
        trim_end_s = _parse_optional_trim_seconds(end_time, "end_time")
    except ValueError as e:
        return JSONResponse({"ok": False, "error": str(e)}, status_code=400)

    data = await file.read()
    if not data:
        return JSONResponse({"ok": False, "error": "Empty audio file."}, status_code=400)

    # Anything that is not already a target-format WAV gets normalised.
    info = _inspect_wav_bytes(data)
    if not info or not _is_target_wav(info):
        try:
            data = _normalize_audio_to_target_wav(data, file.filename or "trimmed.wav")
        except Exception as e:
            return JSONResponse({"ok": False, "error": f"Audio normalization failed: {e}"}, status_code=400)

    try:
        orig_path = _resolve_audio_path(directory, source_file)
    except FileNotFoundError as e:
        return JSONResponse({"ok": False, "error": str(e)}, status_code=404)

    TRIM_HISTORY_DIR.mkdir(parents=True, exist_ok=True)
    ts = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%S%f")
    # Use the resolved file's own name so the backup always lands directly
    # inside TRIM_HISTORY_DIR, whatever form source_file was submitted in.
    backup_name = f"{ts}_{orig_path.name}"
    backup_path = TRIM_HISTORY_DIR / backup_name
    shutil.copy2(orig_path, backup_path)

    # Preserve the pre-trim sidecar next to the backup so a revert can
    # restore the metadata as well as the audio.
    orig_sidecar = _audio_sidecar_path(orig_path)
    if orig_sidecar.exists():
        shutil.copy2(orig_sidecar, _audio_sidecar_path(backup_path))

    orig_path.write_bytes(data)

    old_sidecar = _load_sidecar_json(orig_path)
    sidecar = {
        **old_sidecar,
        "trimmed": True,
        "source_file": source_file,
        "source_bucket": bucket,
        "trim_start_s": trim_start_s,
        "trim_end_s": trim_end_s,
        "undo_backup_file": backup_name,
    }
    _write_sidecar_json(orig_path, sidecar)

    updated_item = _sample_item_from_path(orig_path, bucket)
    updated_item["trimmed"] = True
    updated_item["source_file"] = source_file
    return {"ok": True, "updated_sample": updated_item, "message": f"Trimmed {source_file}"}
|
||||
|
||||
|
||||
@app.post("/api/samples/revert")
def revert_trim(
    bucket: str = Form(...),
    file_name: str = Form(...),
):
    """Undo the most recent trim of a sample using its recorded backup.

    The backup file name is read from the sample's sidecar
    (``undo_backup_file``). The backed-up audio and sidecar are copied back
    over the sample and the backup files are then deleted.

    Args:
        bucket: ``"personal"`` or ``"negative"``.
        file_name: Name of the sample to restore.

    Returns:
        A dict with the refreshed sample item, or a ``JSONResponse`` with
        status 400 (no or invalid backup reference) or 404 (unknown bucket,
        missing sample, missing backup file).
    """
    bucket_map = {"personal": PERSONAL_DIR, "negative": NEGATIVE_DIR}
    directory = bucket_map.get(bucket)
    if directory is None:
        return JSONResponse({"ok": False, "error": "Unknown sample bucket."}, status_code=404)

    try:
        file_path = _resolve_audio_path(directory, file_name)
    except FileNotFoundError as e:
        return JSONResponse({"ok": False, "error": str(e)}, status_code=404)

    sidecar = _load_sidecar_json(file_path)
    backup_name = sidecar.get("undo_backup_file")
    if not backup_name:
        return JSONResponse({"ok": False, "error": "No trim backup found for this sample."}, status_code=400)
    # The value comes from a file on disk and is later used to copy from and
    # delete a path, so accept only a plain file name inside TRIM_HISTORY_DIR.
    if not isinstance(backup_name, str) or Path(backup_name).name != backup_name:
        return JSONResponse({"ok": False, "error": "Invalid trim backup reference."}, status_code=400)

    backup_path = TRIM_HISTORY_DIR / backup_name
    if not backup_path.exists():
        return JSONResponse({"ok": False, "error": "Trim backup file missing."}, status_code=404)

    shutil.copy2(backup_path, file_path)

    backup_sidecar = _audio_sidecar_path(backup_path)
    current_sidecar = _audio_sidecar_path(file_path)
    if backup_sidecar.exists():
        shutil.copy2(backup_sidecar, current_sidecar)
        backup_sidecar.unlink()
    else:
        # The sample had no sidecar before the trim. Drop the one the trim
        # created, otherwise it keeps pointing at the backup deleted below.
        current_sidecar.unlink(missing_ok=True)

    backup_path.unlink()

    updated_item = _sample_item_from_path(file_path, bucket)
    return {"ok": True, "updated_sample": updated_item, "message": f"Reverted {file_name}"}
|
||||
|
||||
|
||||
@app.post("/api/captured_audio/{file_name}/approve_personal")
|
||||
@@ -2512,28 +3024,30 @@ def firmware_devices():
|
||||
|
||||
|
||||
@app.get("/api/firmware/templates")
|
||||
def firmware_templates(request: Request):
|
||||
def firmware_templates(request: Request, target_host: str = "", target_port: str = ""):
|
||||
templates = []
|
||||
warnings = []
|
||||
base_url = _request_base_url(request)
|
||||
wake_words = _list_trained_wake_words(base_url)
|
||||
profiles = _load_firmware_profiles()
|
||||
selected_host, selected_port = _firmware_profile_target(target_host, target_port)
|
||||
for spec in FIRMWARE_TEMPLATE_SPECS:
|
||||
key = str(spec.get("key") or "")
|
||||
profile = profiles.get(key, {})
|
||||
target_port = str(profile.get("__target_port") or "")
|
||||
if target_port == "6053":
|
||||
target_port = str(FIRMWARE_DEFAULT_OTA_PORT)
|
||||
profile_key = _firmware_profile_key(key, target_host, target_port)
|
||||
profile = _load_firmware_profile(key, profile_key)
|
||||
row_target_host = selected_host or str(profile.get("__target_host") or "")
|
||||
row_target_port = selected_port or str(profile.get("__target_port") or "")
|
||||
if row_target_port == "6053":
|
||||
row_target_port = str(FIRMWARE_DEFAULT_OTA_PORT)
|
||||
row = {
|
||||
"value": key,
|
||||
"label": str(spec.get("label") or key),
|
||||
"source_url": _firmware_raw_url(str(spec.get("path") or "")),
|
||||
"target_host": str(profile.get("__target_host") or ""),
|
||||
"target_port": target_port,
|
||||
"target_host": row_target_host,
|
||||
"target_port": row_target_port,
|
||||
"fields": [],
|
||||
}
|
||||
try:
|
||||
row["fields"] = _firmware_template_fields(key, base_url)
|
||||
row["fields"] = _firmware_template_fields(key, base_url, profile_key)
|
||||
except Exception as exc:
|
||||
warnings.append(f"{row['label']}: {exc}")
|
||||
templates.append(row)
|
||||
@@ -2553,11 +3067,12 @@ def firmware_profile(payload: Dict[str, Any]):
|
||||
template_key = str(body.get("template_key") or "").strip()
|
||||
_firmware_template_spec(template_key)
|
||||
values = body.get("values") if isinstance(body.get("values"), dict) else {}
|
||||
saved = _normalize_firmware_profile_update(template_key, values)
|
||||
_save_firmware_profile(template_key, saved)
|
||||
profile_key = _firmware_profile_key(template_key, values.get("__target_host"), values.get("__target_port"))
|
||||
saved = _normalize_firmware_profile_update(template_key, values, profile_key)
|
||||
_save_firmware_profile(profile_key or template_key, saved)
|
||||
except Exception as e:
|
||||
return JSONResponse({"ok": False, "error": str(e)}, status_code=400)
|
||||
return {"ok": True, "template_key": template_key, "saved_fields": sorted(saved.keys())}
|
||||
return {"ok": True, "template_key": template_key, "profile_key": profile_key or template_key, "saved_fields": sorted(saved.keys())}
|
||||
|
||||
|
||||
@app.get("/api/trained_wake_words/catalog")
|
||||
@@ -2625,7 +3140,7 @@ def firmware_clean():
|
||||
return JSONResponse({"ok": False, "error": f"Wait for active firmware session(s) to finish: {', '.join(active[:3])}."}, status_code=400)
|
||||
|
||||
removed = []
|
||||
for child in ("configs", "builds", "platformio", "home", "cache"):
|
||||
for child in ("configs", "builds", "platformio", "home", "cache", "esphome_data"):
|
||||
path = FIRMWARE_CACHE_DIR / child
|
||||
if path.exists():
|
||||
shutil.rmtree(path, ignore_errors=True)
|
||||
|
||||
Reference in New Issue
Block a user