mirror of
https://github.com/TaterTotterson/microWakeWord-Trainer-Nvidia-Docker.git
synced 2026-06-12 20:10:19 -06:00
firmware url fixes
This commit is contained in:
@@ -1488,15 +1488,56 @@ def _load_firmware_profiles() -> Dict[str, Dict[str, str]]:
|
||||
return {}
|
||||
|
||||
|
||||
def _save_firmware_profile(template_key: str, values: Dict[str, str]) -> None:
|
||||
def _save_firmware_profile(profile_key: str, values: Dict[str, str]) -> None:
|
||||
FIRMWARE_PROFILE_FILE.parent.mkdir(parents=True, exist_ok=True)
|
||||
profiles = _load_firmware_profiles()
|
||||
profiles[template_key] = {str(key): str(value) for key, value in values.items() if str(key)}
|
||||
profiles[profile_key] = {str(key): str(value) for key, value in values.items() if str(key)}
|
||||
FIRMWARE_PROFILE_FILE.write_text(json.dumps(profiles, indent=2, sort_keys=True), encoding="utf-8")
|
||||
|
||||
|
||||
def _normalize_firmware_profile_update(template_key: str, values: Dict[str, Any]) -> Dict[str, str]:
|
||||
ctx = _load_firmware_template_context(template_key)
|
||||
def _firmware_profile_target(raw_host: Any = "", raw_port: Any = "") -> tuple[str, str]:
|
||||
host = str(raw_host or "").strip()
|
||||
port = str(raw_port or "").strip()
|
||||
if "://" in host:
|
||||
parsed = urlparse(host)
|
||||
host = parsed.hostname or ""
|
||||
if not port and parsed.port:
|
||||
port = str(parsed.port)
|
||||
host = host.strip().strip("/")
|
||||
if host.count(":") == 1 and not port:
|
||||
maybe_host, maybe_port = host.rsplit(":", 1)
|
||||
if maybe_port.isdigit():
|
||||
host = maybe_host
|
||||
port = maybe_port
|
||||
host = host.strip("[]").strip().lower()
|
||||
if not host:
|
||||
return "", ""
|
||||
with contextlib.suppress(Exception):
|
||||
parsed_port = int(port or FIRMWARE_DEFAULT_OTA_PORT)
|
||||
if parsed_port == 6053:
|
||||
parsed_port = FIRMWARE_DEFAULT_OTA_PORT
|
||||
port = str(parsed_port)
|
||||
if not port:
|
||||
port = str(FIRMWARE_DEFAULT_OTA_PORT)
|
||||
return host, port
|
||||
|
||||
|
||||
def _firmware_profile_key_for_target(raw_host: Any = "", raw_port: Any = "") -> str:
|
||||
host, port = _firmware_profile_target(raw_host, raw_port)
|
||||
return f"device:{host}:{port}" if host else ""
|
||||
|
||||
|
||||
def _load_firmware_profile(template_key: str, profile_key: str = "") -> Dict[str, str]:
|
||||
profiles = _load_firmware_profiles()
|
||||
profile = profiles.get(profile_key) if profile_key else None
|
||||
if isinstance(profile, dict):
|
||||
return dict(profile)
|
||||
legacy = profiles.get(template_key)
|
||||
return dict(legacy) if isinstance(legacy, dict) else {}
|
||||
|
||||
|
||||
def _normalize_firmware_profile_update(template_key: str, values: Dict[str, Any], profile_key: str = "") -> Dict[str, str]:
|
||||
ctx = _load_firmware_template_context(template_key, profile_key)
|
||||
spec = ctx["spec"]
|
||||
profile = ctx.get("profile") if isinstance(ctx.get("profile"), dict) else {}
|
||||
substitutions = ctx["substitutions"]
|
||||
@@ -1527,6 +1568,12 @@ def _normalize_firmware_profile_update(template_key: str, values: Dict[str, Any]
|
||||
elif key_text not in normalized:
|
||||
normalized[key_text] = "true" if str(default).lower() == "true" else "false"
|
||||
continue
|
||||
if key_text == "wake_word_model_url":
|
||||
if key_text in values:
|
||||
normalized[key_text] = _local_trained_wake_word_url(values.get(key_text))
|
||||
elif key_text not in normalized:
|
||||
normalized[key_text] = ""
|
||||
continue
|
||||
if key_text in values:
|
||||
normalized[key_text] = str(values.get(key_text) or "").strip()
|
||||
elif key_text not in normalized:
|
||||
@@ -1548,7 +1595,38 @@ def _normalize_firmware_profile_update(template_key: str, values: Dict[str, Any]
|
||||
return normalized
|
||||
|
||||
|
||||
def _load_firmware_template_context(template_key: str) -> Dict[str, Any]:
|
||||
def _local_trained_wake_word_url(value: Any) -> str:
|
||||
url = str(value or "").strip()
|
||||
return url if "/api/trained_wake_words/" in url else ""
|
||||
|
||||
|
||||
def _selected_trained_wake_word(
|
||||
trained_wake_words: List[Dict[str, Any]],
|
||||
profile: Dict[str, Any],
|
||||
substitutions: Dict[str, Any],
|
||||
) -> Dict[str, Any] | None:
|
||||
if not trained_wake_words:
|
||||
return None
|
||||
|
||||
saved_choice = str(profile.get("wake_word_choice") or "").strip()
|
||||
current_wake_word_name = str(
|
||||
profile.get("wake_word_name") or _template_default_string(substitutions.get("wake_word_name"))
|
||||
).strip()
|
||||
current_wake_word_url = str(profile.get("wake_word_model_url") or "").strip()
|
||||
|
||||
def match(predicate):
|
||||
return next((row for row in trained_wake_words if predicate(row)), None)
|
||||
|
||||
return (
|
||||
match(lambda row: str(row.get("key") or "") == saved_choice)
|
||||
or match(lambda row: str(row.get("json_url") or "") == current_wake_word_url)
|
||||
or match(lambda row: str(row.get("model_url") or "") == current_wake_word_url)
|
||||
or match(lambda row: str(row.get("wake_word_name") or "") == current_wake_word_name)
|
||||
or trained_wake_words[0]
|
||||
)
|
||||
|
||||
|
||||
def _load_firmware_template_context(template_key: str, profile_key: str = "") -> Dict[str, Any]:
|
||||
spec = _firmware_template_spec(template_key)
|
||||
raw_text, source_label = _load_firmware_template_text(spec)
|
||||
parsed = yaml.load(raw_text, Loader=_FirmwareYamlLoader)
|
||||
@@ -1564,12 +1642,12 @@ def _load_firmware_template_context(template_key: str) -> Dict[str, Any]:
|
||||
"template_doc": parsed,
|
||||
"substitutions": dict(substitutions),
|
||||
"sections": _extract_substitution_sections(raw_text),
|
||||
"profile": _load_firmware_profiles().get(str(spec.get("key") or ""), {}),
|
||||
"profile": _load_firmware_profile(str(spec.get("key") or ""), profile_key),
|
||||
}
|
||||
|
||||
|
||||
def _firmware_template_fields(template_key: str, base_url: str = "") -> List[Dict[str, Any]]:
|
||||
ctx = _load_firmware_template_context(template_key)
|
||||
def _firmware_template_fields(template_key: str, base_url: str = "", profile_key: str = "") -> List[Dict[str, Any]]:
|
||||
ctx = _load_firmware_template_context(template_key, profile_key)
|
||||
spec = ctx["spec"]
|
||||
profile = ctx.get("profile") if isinstance(ctx.get("profile"), dict) else {}
|
||||
fields: List[Dict[str, Any]] = []
|
||||
@@ -1579,17 +1657,8 @@ def _firmware_template_fields(template_key: str, base_url: str = "") -> List[Dic
|
||||
fixed_keys.add(identity_key)
|
||||
hidden_keys = {"ha_voice_ip"} | set(spec.get("auto_keys") or set())
|
||||
trained_wake_words = _list_trained_wake_words(base_url)
|
||||
current_wake_word_name = str(
|
||||
profile.get("wake_word_name") or _template_default_string(ctx["substitutions"].get("wake_word_name"))
|
||||
).strip()
|
||||
saved_wake_word_choice = str(profile.get("wake_word_choice") or "").strip()
|
||||
selected_wake_word = next(
|
||||
(row["key"] for row in trained_wake_words if row.get("key") == saved_wake_word_choice),
|
||||
"",
|
||||
) or next(
|
||||
(row["key"] for row in trained_wake_words if row.get("wake_word_name") == current_wake_word_name),
|
||||
"",
|
||||
)
|
||||
selected_wake_word_row = _selected_trained_wake_word(trained_wake_words, profile, ctx["substitutions"])
|
||||
selected_wake_word = str(selected_wake_word_row.get("key") or "") if selected_wake_word_row else ""
|
||||
wake_picker_added = False
|
||||
|
||||
for key, raw_value in ctx["substitutions"].items():
|
||||
@@ -1640,8 +1709,12 @@ def _firmware_template_fields(template_key: str, base_url: str = "") -> List[Dic
|
||||
placeholder = "Your Wi-Fi SSID"
|
||||
description = "Required before build + flash."
|
||||
elif key_text == "wake_word_model_url":
|
||||
placeholder = "https://.../wake_word.json"
|
||||
value = str(selected_wake_word_row.get("json_url") or "") if selected_wake_word_row else ""
|
||||
placeholder = "Train or select a local wake word first"
|
||||
description = "Filled from the local trained wake-word picker."
|
||||
elif key_text == "wake_word_name":
|
||||
if selected_wake_word_row:
|
||||
value = str(selected_wake_word_row.get("wake_word_name") or selected_wake_word_row.get("key") or "")
|
||||
placeholder = "hey_tater"
|
||||
section = ctx["sections"].get(key_text) or "Firmware"
|
||||
if key_text in {"wake_word_name", "wake_word_model_url"}:
|
||||
@@ -1678,8 +1751,15 @@ def _esphome_pythonpath() -> str:
|
||||
return os.pathsep.join(paths)
|
||||
|
||||
|
||||
def _render_firmware_config(template_key: str, values: Dict[str, Any], host: str, session_id: str) -> tuple[Path, Dict[str, str]]:
|
||||
ctx = _load_firmware_template_context(template_key)
|
||||
def _render_firmware_config(
|
||||
template_key: str,
|
||||
values: Dict[str, Any],
|
||||
host: str,
|
||||
session_id: str,
|
||||
port: Any = None,
|
||||
) -> tuple[Path, Dict[str, str]]:
|
||||
profile_key = _firmware_profile_key_for_target(host, port)
|
||||
ctx = _load_firmware_template_context(template_key, profile_key)
|
||||
spec = ctx["spec"]
|
||||
profile = ctx.get("profile") if isinstance(ctx.get("profile"), dict) else {}
|
||||
substitutions = ctx["substitutions"]
|
||||
@@ -1702,6 +1782,8 @@ def _render_firmware_config(template_key: str, values: Dict[str, Any], host: str
|
||||
normalized[key_text] = "true" if bool(raw_value) else "false"
|
||||
elif key_text == "ha_voice_ip":
|
||||
normalized[key_text] = host
|
||||
elif key_text == "wake_word_model_url":
|
||||
normalized[key_text] = _local_trained_wake_word_url(raw_value)
|
||||
else:
|
||||
normalized[key_text] = str(raw_value if raw_value is not None else "").strip() or _template_default_string(
|
||||
substitutions.get(key_text)
|
||||
@@ -1714,6 +1796,8 @@ def _render_firmware_config(template_key: str, values: Dict[str, Any], host: str
|
||||
missing.append("Wi-Fi password")
|
||||
if not host:
|
||||
missing.append("device IP or hostname")
|
||||
if "wake_word_model_url" in substitutions and not normalized.get("wake_word_model_url"):
|
||||
missing.append("local trained wake word")
|
||||
if missing:
|
||||
raise RuntimeError(f"Missing required firmware values: {', '.join(missing)}.")
|
||||
|
||||
@@ -1728,7 +1812,11 @@ def _render_firmware_config(template_key: str, values: Dict[str, Any], host: str
|
||||
session_dir.mkdir(parents=True, exist_ok=True)
|
||||
config_path = session_dir / f"{str(spec.get('key') or template_key)}.yaml"
|
||||
config_path.write_text(yaml.dump(config, Dumper=_FirmwareYamlDumper, sort_keys=False, allow_unicode=True), encoding="utf-8")
|
||||
_save_firmware_profile(str(spec.get("key") or template_key), normalized)
|
||||
normalized_host, normalized_port = _firmware_profile_target(host, port)
|
||||
if normalized_host:
|
||||
normalized["__target_host"] = normalized_host
|
||||
normalized["__target_port"] = normalized_port
|
||||
_save_firmware_profile(profile_key or str(spec.get("key") or template_key), normalized)
|
||||
return config_path, normalized
|
||||
|
||||
|
||||
@@ -1947,7 +2035,7 @@ def _run_firmware_build_flash_background(session_id: str):
|
||||
return
|
||||
|
||||
try:
|
||||
config_path, _normalized = _render_firmware_config(template_key, values, host, session_id)
|
||||
config_path, _normalized = _render_firmware_config(template_key, values, host, session_id, port)
|
||||
except Exception as exc:
|
||||
_append_firmware_log(session_id, f"✗ Failed to prepare firmware config: {exc}")
|
||||
with FIRMWARE_LOCK:
|
||||
@@ -2518,28 +2606,30 @@ def firmware_devices():
|
||||
|
||||
|
||||
@app.get("/api/firmware/templates")
|
||||
def firmware_templates(request: Request):
|
||||
def firmware_templates(request: Request, target_host: str = "", target_port: str = ""):
|
||||
templates = []
|
||||
warnings = []
|
||||
base_url = _request_base_url(request)
|
||||
wake_words = _list_trained_wake_words(base_url)
|
||||
profiles = _load_firmware_profiles()
|
||||
profile_key = _firmware_profile_key_for_target(target_host, target_port)
|
||||
selected_host, selected_port = _firmware_profile_target(target_host, target_port)
|
||||
for spec in FIRMWARE_TEMPLATE_SPECS:
|
||||
key = str(spec.get("key") or "")
|
||||
profile = profiles.get(key, {})
|
||||
target_port = str(profile.get("__target_port") or "")
|
||||
if target_port == "6053":
|
||||
target_port = str(FIRMWARE_DEFAULT_OTA_PORT)
|
||||
profile = _load_firmware_profile(key, profile_key)
|
||||
row_target_host = selected_host or str(profile.get("__target_host") or "")
|
||||
row_target_port = selected_port or str(profile.get("__target_port") or "")
|
||||
if row_target_port == "6053":
|
||||
row_target_port = str(FIRMWARE_DEFAULT_OTA_PORT)
|
||||
row = {
|
||||
"value": key,
|
||||
"label": str(spec.get("label") or key),
|
||||
"source_url": _firmware_raw_url(str(spec.get("path") or "")),
|
||||
"target_host": str(profile.get("__target_host") or ""),
|
||||
"target_port": target_port,
|
||||
"target_host": row_target_host,
|
||||
"target_port": row_target_port,
|
||||
"fields": [],
|
||||
}
|
||||
try:
|
||||
row["fields"] = _firmware_template_fields(key, base_url)
|
||||
row["fields"] = _firmware_template_fields(key, base_url, profile_key)
|
||||
except Exception as exc:
|
||||
warnings.append(f"{row['label']}: {exc}")
|
||||
templates.append(row)
|
||||
@@ -2559,11 +2649,12 @@ def firmware_profile(payload: Dict[str, Any]):
|
||||
template_key = str(body.get("template_key") or "").strip()
|
||||
_firmware_template_spec(template_key)
|
||||
values = body.get("values") if isinstance(body.get("values"), dict) else {}
|
||||
saved = _normalize_firmware_profile_update(template_key, values)
|
||||
_save_firmware_profile(template_key, saved)
|
||||
profile_key = _firmware_profile_key_for_target(values.get("__target_host"), values.get("__target_port"))
|
||||
saved = _normalize_firmware_profile_update(template_key, values, profile_key)
|
||||
_save_firmware_profile(profile_key or template_key, saved)
|
||||
except Exception as e:
|
||||
return JSONResponse({"ok": False, "error": str(e)}, status_code=400)
|
||||
return {"ok": True, "template_key": template_key, "saved_fields": sorted(saved.keys())}
|
||||
return {"ok": True, "template_key": template_key, "profile_key": profile_key or template_key, "saved_fields": sorted(saved.keys())}
|
||||
|
||||
|
||||
@app.get("/api/trained_wake_words/catalog")
|
||||
|
||||
Reference in New Issue
Block a user