diff --git a/trainer_server.py b/trainer_server.py index 5c9fe6b..c82b0eb 100644 --- a/trainer_server.py +++ b/trainer_server.py @@ -1619,6 +1619,14 @@ def _firmware_profile_key_for_target(raw_host: Any = "", raw_port: Any = "") -> return f"device:{host}:{port}" if host else "" +def _firmware_profile_key(template_key: Any = "", raw_host: Any = "", raw_port: Any = "") -> str: + target_key = _firmware_profile_key_for_target(raw_host, raw_port) + template = str(template_key or "").strip().lower() + if target_key and template: + return f"{target_key}:template:{template}" + return target_key or template + + def _firmware_cache_slug(*parts: Any) -> str: raw = "__".join(str(part or "").strip() for part in parts if str(part or "").strip()) slug = re.sub(r"[^A-Za-z0-9_.-]+", "_", raw).strip("._-") @@ -1656,16 +1664,31 @@ def _load_firmware_profile(template_key: str, profile_key: str = "") -> Dict[str profile = profiles.get(profile_key) if profile_key else None if isinstance(profile, dict): return dict(profile) + if profile_key and ":template:" in profile_key: + legacy_device_key = profile_key.split(":template:", 1)[0] + legacy_device = profiles.get(legacy_device_key) + if isinstance(legacy_device, dict): + return dict(legacy_device) legacy = profiles.get(template_key) return dict(legacy) if isinstance(legacy, dict) else {} +def _firmware_profile_values_for_template(profile: Dict[str, Any], substitutions: Dict[str, Any]) -> Dict[str, str]: + keep_keys = {str(key or "").strip() for key in substitutions.keys()} + keep_keys.update({"__target_host", "__target_port", "wake_sound_catalog", "wake_word_choice"}) + return { + key: str(profile.get(key) or "") + for key in keep_keys + if key and key in profile + } + + def _normalize_firmware_profile_update(template_key: str, values: Dict[str, Any], profile_key: str = "") -> Dict[str, str]: ctx = _load_firmware_template_context(template_key, profile_key) spec = ctx["spec"] profile = ctx.get("profile") if isinstance(ctx.get("profile"), dict) else {} substitutions = ctx["substitutions"] - normalized: Dict[str, str] = dict(profile) + normalized = _firmware_profile_values_for_template(profile, substitutions) fixed_keys = set(spec.get("fixed_keys") or set()) identity_key = str(spec.get("identity_key") or "").strip() if identity_key: @@ -1923,12 +1946,12 @@ def _render_firmware_config( session_id: str, port: Any = None, ) -> tuple[Path, Dict[str, str], Path]: - profile_key = _firmware_profile_key_for_target(host, port) + profile_key = _firmware_profile_key(template_key, host, port) ctx = _load_firmware_template_context(template_key, profile_key) spec = ctx["spec"] profile = ctx.get("profile") if isinstance(ctx.get("profile"), dict) else {} substitutions = ctx["substitutions"] - normalized: Dict[str, str] = dict(profile) + normalized = _firmware_profile_values_for_template(profile, substitutions) fixed_keys = set(spec.get("fixed_keys") or set()) identity_key = str(spec.get("identity_key") or "").strip() if identity_key: @@ -2792,10 +2815,10 @@ def firmware_templates(request: Request, target_host: str = "", target_port: str warnings = [] base_url = _request_base_url(request) wake_words = _list_trained_wake_words(base_url) - profile_key = _firmware_profile_key_for_target(target_host, target_port) selected_host, selected_port = _firmware_profile_target(target_host, target_port) for spec in FIRMWARE_TEMPLATE_SPECS: key = str(spec.get("key") or "") + profile_key = _firmware_profile_key(key, target_host, target_port) profile = _load_firmware_profile(key, profile_key) row_target_host = selected_host or str(profile.get("__target_host") or "") row_target_port = selected_port or str(profile.get("__target_port") or "") @@ -2830,7 +2853,7 @@ def firmware_profile(payload: Dict[str, Any]): template_key = str(body.get("template_key") or "").strip() _firmware_template_spec(template_key) values = body.get("values") if isinstance(body.get("values"), dict) else {} - profile_key = _firmware_profile_key_for_target(values.get("__target_host"), values.get("__target_port")) + profile_key = _firmware_profile_key(template_key, values.get("__target_host"), values.get("__target_port")) saved = _normalize_firmware_profile_update(template_key, values, profile_key) _save_firmware_profile(profile_key or template_key, saved) except Exception as e: