diff --git a/static/index.html b/static/index.html index 86244bd..990b8b8 100644 --- a/static/index.html +++ b/static/index.html @@ -1037,9 +1037,9 @@
+ -
@@ -1418,6 +1418,7 @@ activeView: "trainer", }; let firmwareProfileSaveTimer = null; + let firmwareProfileReloadTimer = null; function setPill(el, text, cls) { el.className = "pill " + (cls || ""); @@ -1838,7 +1839,10 @@ setPill($("firmwareDetectStatus"), `${list.length} detected`, "ok"); if (!($("firmwareHost").value || "").trim()) { $("firmwareDeviceSelect").value = "0"; - applySelectedFirmwareDevice(); + applySelectedFirmwareDevice().catch((error) => { + setPill($("firmwareStatus"), "Device settings failed", "warn"); + console.warn("Device settings load failed", error); + }); } } @@ -1850,15 +1854,20 @@ return data; } - function applySelectedFirmwareDevice() { + async function applySelectedFirmwareDevice() { const indexText = $("firmwareDeviceSelect").value; if (indexText === "") return; const device = uiState.firmware.devices[Number(indexText)]; if (!device) return; + await flushFirmwareProfileSave(); $("firmwareHost").value = device.host || ""; $("firmwarePort").value = device.port || 3232; + setPill($("firmwareStatus"), "Loading device settings...", "warn"); + const data = await refreshFirmwareTemplates(); + if (!Array.isArray(data?.warnings) || !data.warnings.length) { + setPill($("firmwareStatus"), "Device settings loaded", "ok"); + } syncButtons(); - scheduleFirmwareProfileSave(); } function selectedFirmwareTemplate() { @@ -1922,6 +1931,7 @@
`).join(""); + syncRenderedWakeWordSelection(); } function renderFirmwareField(field) { @@ -1987,7 +1997,14 @@ `; } - function applyWakeWordSelection(select) { + function syncRenderedWakeWordSelection() { + const select = document.querySelector("select[data-wake-word-select]"); + if (select && select.value) { + applyWakeWordSelection(select, { silent: true }); + } + } + + function applyWakeWordSelection(select, options = {}) { const option = select?.selectedOptions?.[0]; if (!option || !select.value) return; const wakeWordName = option.dataset.wakeWordName || ""; @@ -2000,13 +2017,25 @@ if (urlInput && wakeWordUrl) { urlInput.value = wakeWordUrl; } - setPill($("firmwareStatus"), "Wake word selected", "ok"); - scheduleFirmwareProfileSave(); + if (!options.silent) { + setPill($("firmwareStatus"), "Wake word selected", "ok"); + scheduleFirmwareProfileSave(); + } syncButtons(); } + function firmwareTemplateQuery() { + const params = new URLSearchParams(); + const host = ($("firmwareHost").value || "").trim(); + const port = ($("firmwarePort").value || "3232").trim(); + if (host) params.set("target_host", host); + if (port) params.set("target_port", port); + const query = params.toString(); + return query ? `?${query}` : ""; + } + async function refreshFirmwareTemplates() { - const data = await api("/api/firmware/templates", { method: "GET" }); + const data = await api(`/api/firmware/templates${firmwareTemplateQuery()}`, { method: "GET" }); renderFirmwareTemplates(data); if (Array.isArray(data.warnings) && data.warnings.length) { setPill($("firmwareStatus"), "Template warning", "warn"); @@ -2062,11 +2091,33 @@ }, 550); } + function scheduleFirmwareProfileReload() { + if (firmwareProfileSaveTimer) { + clearTimeout(firmwareProfileSaveTimer); + firmwareProfileSaveTimer = null; + } + if (firmwareProfileReloadTimer) { + clearTimeout(firmwareProfileReloadTimer); + } + firmwareProfileReloadTimer = setTimeout(async () => { + firmwareProfileReloadTimer = null; + try { + await refreshFirmwareTemplates(); + } catch (error) { + setPill($("firmwareStatus"), "Device settings failed", "warn"); + } + }, 650); + } + function flushFirmwareProfileSave(templateKey = null) { if (firmwareProfileSaveTimer) { clearTimeout(firmwareProfileSaveTimer); firmwareProfileSaveTimer = null; } + if (firmwareProfileReloadTimer) { + clearTimeout(firmwareProfileReloadTimer); + firmwareProfileReloadTimer = null; + } return saveFirmwareProfileNow({ templateKey, quiet: true }).catch(() => null); } @@ -2500,11 +2551,11 @@ $("firmwareHost").addEventListener("input", () => { syncButtons(); - scheduleFirmwareProfileSave(); + scheduleFirmwareProfileReload(); }); $("firmwarePort").addEventListener("input", () => { syncButtons(); - scheduleFirmwareProfileSave(); + scheduleFirmwareProfileReload(); }); $("firmwareTemplate").addEventListener("change", () => { flushFirmwareProfileSave(uiState.firmware.activeTemplateKey); @@ -2526,7 +2577,12 @@ syncButtons(); scheduleFirmwareProfileSave(); }); - $("firmwareDeviceSelect").addEventListener("change", applySelectedFirmwareDevice); + $("firmwareDeviceSelect").addEventListener("change", () => { + applySelectedFirmwareDevice().catch((error) => { + setPill($("firmwareStatus"), "Device settings failed", "warn"); + alert("Device settings failed: " + error.message); + }); + }); $("refreshFirmwareBtn").addEventListener("click", async () => { try { await refreshFirmwareDevices(); diff --git a/trainer_server.py b/trainer_server.py index ffc6dc0..88d3a8f 100644 --- a/trainer_server.py +++ b/trainer_server.py @@ -1488,15 +1488,56 @@ def _load_firmware_profiles() -> Dict[str, Dict[str, str]]: return {} -def _save_firmware_profile(template_key: str, values: Dict[str, str]) -> None: +def _save_firmware_profile(profile_key: str, values: Dict[str, str]) -> None: FIRMWARE_PROFILE_FILE.parent.mkdir(parents=True, exist_ok=True) profiles = _load_firmware_profiles() - profiles[template_key] = {str(key): str(value) for key, value in values.items() if str(key)} + profiles[profile_key] = {str(key): str(value) for key, value in values.items() if str(key)} FIRMWARE_PROFILE_FILE.write_text(json.dumps(profiles, indent=2, sort_keys=True), encoding="utf-8") -def _normalize_firmware_profile_update(template_key: str, values: Dict[str, Any]) -> Dict[str, str]: - ctx = _load_firmware_template_context(template_key) +def _firmware_profile_target(raw_host: Any = "", raw_port: Any = "") -> tuple[str, str]: + host = str(raw_host or "").strip() + port = str(raw_port or "").strip() + if "://" in host: + parsed = urlparse(host) + host = parsed.hostname or "" + if not port and parsed.port: + port = str(parsed.port) + host = host.strip().strip("/") + if host.count(":") == 1 and not port: + maybe_host, maybe_port = host.rsplit(":", 1) + if maybe_port.isdigit(): + host = maybe_host + port = maybe_port + host = host.strip("[]").strip().lower() + if not host: + return "", "" + with contextlib.suppress(Exception): + parsed_port = int(port or FIRMWARE_DEFAULT_OTA_PORT) + if parsed_port == 6053: + parsed_port = FIRMWARE_DEFAULT_OTA_PORT + port = str(parsed_port) + if not port: + port = str(FIRMWARE_DEFAULT_OTA_PORT) + return host, port + + +def _firmware_profile_key_for_target(raw_host: Any = "", raw_port: Any = "") -> str: + host, port = _firmware_profile_target(raw_host, raw_port) + return f"device:{host}:{port}" if host else "" + + +def _load_firmware_profile(template_key: str, profile_key: str = "") -> Dict[str, str]: + profiles = _load_firmware_profiles() + profile = profiles.get(profile_key) if profile_key else None + if isinstance(profile, dict): + return dict(profile) + legacy = profiles.get(template_key) + return dict(legacy) if isinstance(legacy, dict) else {} + + +def _normalize_firmware_profile_update(template_key: str, values: Dict[str, Any], profile_key: str = "") -> Dict[str, str]: + ctx = _load_firmware_template_context(template_key, profile_key) spec = ctx["spec"] profile = ctx.get("profile") if isinstance(ctx.get("profile"), dict) else {} substitutions = ctx["substitutions"] @@ -1527,6 +1568,12 @@ def _normalize_firmware_profile_update(template_key: str, values: Dict[str, Any] elif key_text not in normalized: normalized[key_text] = "true" if str(default).lower() == "true" else "false" continue + if key_text == "wake_word_model_url": + if key_text in values: + normalized[key_text] = _local_trained_wake_word_url(values.get(key_text)) + elif key_text not in normalized: + normalized[key_text] = "" + continue if key_text in values: normalized[key_text] = str(values.get(key_text) or "").strip() elif key_text not in normalized: @@ -1548,7 +1595,38 @@ def _normalize_firmware_profile_update(template_key: str, values: Dict[str, Any] return normalized -def _load_firmware_template_context(template_key: str) -> Dict[str, Any]: +def _local_trained_wake_word_url(value: Any) -> str: + url = str(value or "").strip() + return url if "/api/trained_wake_words/" in url else "" + + +def _selected_trained_wake_word( + trained_wake_words: List[Dict[str, Any]], + profile: Dict[str, Any], + substitutions: Dict[str, Any], +) -> Dict[str, Any] | None: + if not trained_wake_words: + return None + + saved_choice = str(profile.get("wake_word_choice") or "").strip() + current_wake_word_name = str( + profile.get("wake_word_name") or _template_default_string(substitutions.get("wake_word_name")) + ).strip() + current_wake_word_url = str(profile.get("wake_word_model_url") or "").strip() + + def match(predicate): + return next((row for row in trained_wake_words if predicate(row)), None) + + return ( + match(lambda row: str(row.get("key") or "") == saved_choice) + or match(lambda row: str(row.get("json_url") or "") == current_wake_word_url) + or match(lambda row: str(row.get("model_url") or "") == current_wake_word_url) + or match(lambda row: str(row.get("wake_word_name") or "") == current_wake_word_name) + or trained_wake_words[0] + ) + + +def _load_firmware_template_context(template_key: str, profile_key: str = "") -> Dict[str, Any]: spec = _firmware_template_spec(template_key) raw_text, source_label = _load_firmware_template_text(spec) parsed = yaml.load(raw_text, Loader=_FirmwareYamlLoader) @@ -1564,12 +1642,12 @@ def _load_firmware_template_context(template_key: str) -> Dict[str, Any]: "template_doc": parsed, "substitutions": dict(substitutions), "sections": _extract_substitution_sections(raw_text), - "profile": _load_firmware_profiles().get(str(spec.get("key") or ""), {}), + "profile": _load_firmware_profile(str(spec.get("key") or ""), profile_key), } -def _firmware_template_fields(template_key: str, base_url: str = "") -> List[Dict[str, Any]]: - ctx = _load_firmware_template_context(template_key) +def _firmware_template_fields(template_key: str, base_url: str = "", profile_key: str = "") -> List[Dict[str, Any]]: + ctx = _load_firmware_template_context(template_key, profile_key) spec = ctx["spec"] profile = ctx.get("profile") if isinstance(ctx.get("profile"), dict) else {} fields: List[Dict[str, Any]] = [] @@ -1579,17 +1657,8 @@ def _firmware_template_fields(template_key: str, base_url: str = "") -> List[Dic fixed_keys.add(identity_key) hidden_keys = {"ha_voice_ip"} | set(spec.get("auto_keys") or set()) trained_wake_words = _list_trained_wake_words(base_url) - current_wake_word_name = str( - profile.get("wake_word_name") or _template_default_string(ctx["substitutions"].get("wake_word_name")) - ).strip() - saved_wake_word_choice = str(profile.get("wake_word_choice") or "").strip() - selected_wake_word = next( - (row["key"] for row in trained_wake_words if row.get("key") == saved_wake_word_choice), - "", - ) or next( - (row["key"] for row in trained_wake_words if row.get("wake_word_name") == current_wake_word_name), - "", - ) + selected_wake_word_row = _selected_trained_wake_word(trained_wake_words, profile, ctx["substitutions"]) + selected_wake_word = str(selected_wake_word_row.get("key") or "") if selected_wake_word_row else "" wake_picker_added = False for key, raw_value in ctx["substitutions"].items(): @@ -1640,8 +1709,12 @@ def _firmware_template_fields(template_key: str, base_url: str = "") -> List[Dic placeholder = "Your Wi-Fi SSID" description = "Required before build + flash." elif key_text == "wake_word_model_url": - placeholder = "https://.../wake_word.json" + value = str(selected_wake_word_row.get("json_url") or "") if selected_wake_word_row else "" + placeholder = "Train or select a local wake word first" + description = "Filled from the local trained wake-word picker." elif key_text == "wake_word_name": + if selected_wake_word_row: + value = str(selected_wake_word_row.get("wake_word_name") or selected_wake_word_row.get("key") or "") placeholder = "hey_tater" section = ctx["sections"].get(key_text) or "Firmware" if key_text in {"wake_word_name", "wake_word_model_url"}: @@ -1678,8 +1751,15 @@ def _esphome_pythonpath() -> str: return os.pathsep.join(paths) -def _render_firmware_config(template_key: str, values: Dict[str, Any], host: str, session_id: str) -> tuple[Path, Dict[str, str]]: - ctx = _load_firmware_template_context(template_key) +def _render_firmware_config( + template_key: str, + values: Dict[str, Any], + host: str, + session_id: str, + port: Any = None, +) -> tuple[Path, Dict[str, str]]: + profile_key = _firmware_profile_key_for_target(host, port) + ctx = _load_firmware_template_context(template_key, profile_key) spec = ctx["spec"] profile = ctx.get("profile") if isinstance(ctx.get("profile"), dict) else {} substitutions = ctx["substitutions"] @@ -1702,6 +1782,8 @@ def _render_firmware_config(template_key: str, values: Dict[str, Any], host: str normalized[key_text] = "true" if bool(raw_value) else "false" elif key_text == "ha_voice_ip": normalized[key_text] = host + elif key_text == "wake_word_model_url": + normalized[key_text] = _local_trained_wake_word_url(raw_value) else: normalized[key_text] = str(raw_value if raw_value is not None else "").strip() or _template_default_string( substitutions.get(key_text) @@ -1714,6 +1796,8 @@ def _render_firmware_config(template_key: str, values: Dict[str, Any], host: str missing.append("Wi-Fi password") if not host: missing.append("device IP or hostname") + if "wake_word_model_url" in substitutions and not normalized.get("wake_word_model_url"): + missing.append("local trained wake word") if missing: raise RuntimeError(f"Missing required firmware values: {', '.join(missing)}.") @@ -1728,7 +1812,11 @@ def _render_firmware_config(template_key: str, values: Dict[str, Any], host: str session_dir.mkdir(parents=True, exist_ok=True) config_path = session_dir / f"{str(spec.get('key') or template_key)}.yaml" config_path.write_text(yaml.dump(config, Dumper=_FirmwareYamlDumper, sort_keys=False, allow_unicode=True), encoding="utf-8") - _save_firmware_profile(str(spec.get("key") or template_key), normalized) + normalized_host, normalized_port = _firmware_profile_target(host, port) + if normalized_host: + normalized["__target_host"] = normalized_host + normalized["__target_port"] = normalized_port + _save_firmware_profile(profile_key or str(spec.get("key") or template_key), normalized) return config_path, normalized @@ -1947,7 +2035,7 @@ def _run_firmware_build_flash_background(session_id: str): return try: - config_path, _normalized = _render_firmware_config(template_key, values, host, session_id) + config_path, _normalized = _render_firmware_config(template_key, values, host, session_id, port) except Exception as exc: _append_firmware_log(session_id, f"✗ Failed to prepare firmware config: {exc}") with FIRMWARE_LOCK: @@ -2518,28 +2606,30 @@ def firmware_devices(): @app.get("/api/firmware/templates") -def firmware_templates(request: Request): +def firmware_templates(request: Request, target_host: str = "", target_port: str = ""): templates = [] warnings = [] base_url = _request_base_url(request) wake_words = _list_trained_wake_words(base_url) - profiles = _load_firmware_profiles() + profile_key = _firmware_profile_key_for_target(target_host, target_port) + selected_host, selected_port = _firmware_profile_target(target_host, target_port) for spec in FIRMWARE_TEMPLATE_SPECS: key = str(spec.get("key") or "") - profile = profiles.get(key, {}) - target_port = str(profile.get("__target_port") or "") - if target_port == "6053": - target_port = str(FIRMWARE_DEFAULT_OTA_PORT) + profile = _load_firmware_profile(key, profile_key) + row_target_host = selected_host or str(profile.get("__target_host") or "") + row_target_port = selected_port or str(profile.get("__target_port") or "") + if row_target_port == "6053": + row_target_port = str(FIRMWARE_DEFAULT_OTA_PORT) row = { "value": key, "label": str(spec.get("label") or key), "source_url": _firmware_raw_url(str(spec.get("path") or "")), - "target_host": str(profile.get("__target_host") or ""), - "target_port": target_port, + "target_host": row_target_host, + "target_port": row_target_port, "fields": [], } try: - row["fields"] = _firmware_template_fields(key, base_url) + row["fields"] = _firmware_template_fields(key, base_url, profile_key) except Exception as exc: warnings.append(f"{row['label']}: {exc}") templates.append(row) @@ -2559,11 +2649,12 @@ def firmware_profile(payload: Dict[str, Any]): template_key = str(body.get("template_key") or "").strip() _firmware_template_spec(template_key) values = body.get("values") if isinstance(body.get("values"), dict) else {} - saved = _normalize_firmware_profile_update(template_key, values) - _save_firmware_profile(template_key, saved) + profile_key = _firmware_profile_key_for_target(values.get("__target_host"), values.get("__target_port")) + saved = _normalize_firmware_profile_update(template_key, values, profile_key) + _save_firmware_profile(profile_key or template_key, saved) except Exception as e: return JSONResponse({"ok": False, "error": str(e)}, status_code=400) - return {"ok": True, "template_key": template_key, "saved_fields": sorted(saved.keys())} + return {"ok": True, "template_key": template_key, "profile_key": profile_key or template_key, "saved_fields": sorted(saved.keys())} @app.get("/api/trained_wake_words/catalog")