import json
import shutil
from datetime import datetime
from pathlib import Path
from typing import Any, Optional, Tuple


# -------------------- output artifact normalization --------------------

# Top-level JSON keys whose string value is overwritten with the bare name of
# the renamed model file.
_MODEL_FILENAME_KEYS = (
    "model",
    "model_file",
    "model_filename",
    "tflite",
    "tflite_file",
    "tflite_filename",
)


def _newest_matching_file(output_dir: Path, pattern: str) -> Optional[Path]:
    """
    Return the most recently modified entry of output_dir matching pattern.

    Entries that disappear between the directory listing and the stat() call
    are skipped instead of raising. On equal mtimes the first entry yielded by
    glob() wins. Returns None when nothing matches.
    """
    newest: Optional[Path] = None
    newest_mtime = float("-inf")
    for candidate in output_dir.glob(pattern):
        try:
            mtime = candidate.stat().st_mtime
        except OSError:
            # The file vanished (or became unreadable) after it was listed.
            continue
        if mtime > newest_mtime:
            newest, newest_mtime = candidate, mtime
    return newest


def _find_latest_output_pair(output_dir: Path) -> Tuple[Optional[Path], Optional[Path]]:
    """
    Find the most recently modified .tflite and its matching .json (same basename)
    in output_dir. Falls back to the newest .json if an exact match doesn't exist.

    Returns (tflite_path, json_path). Both are None when output_dir is not a
    directory or holds no .tflite; json_path alone is None when the directory
    holds no .json at all.
    """
    if not output_dir.is_dir():
        return (None, None)

    tfl = _newest_matching_file(output_dir, "*.tflite")
    if tfl is None:
        return (None, None)

    exact_js = tfl.with_suffix(".json")
    if exact_js.exists():
        return (tfl, exact_js)

    return (tfl, _newest_matching_file(output_dir, "*.json"))


def _deep_replace_strings(obj: Any, old: str, new: str) -> Any:
    """
    Recursively replace occurrences of old in any string values with new.

    Lists and dicts are rebuilt; dict keys are left untouched; any other type
    is returned as-is. An empty old is treated as "nothing to replace" and obj
    is returned unchanged, because str.replace("", new) would insert new
    between every character.
    """
    if not old:
        return obj
    if isinstance(obj, str):
        return obj.replace(old, new)
    if isinstance(obj, list):
        return [_deep_replace_strings(item, old, new) for item in obj]
    if isinstance(obj, dict):
        return {key: _deep_replace_strings(value, old, new) for key, value in obj.items()}
    return obj


def _patch_model_reference(data: Any, old_name: str, new_name: str) -> Any:
    """
    Return data with references to old_name rewritten to new_name.

    Every string value containing old_name is rewritten first; then, when the
    result is a dict, each key in _MODEL_FILENAME_KEYS that holds a string is
    set to exactly new_name (dropping any directory prefix it carried).
    """
    patched = _deep_replace_strings(data, old_name, new_name)
    if isinstance(patched, dict):
        for key in _MODEL_FILENAME_KEYS:
            if isinstance(patched.get(key), str):
                patched[key] = new_name
    return patched


def _normalize_output_artifacts(safe_word: str, log_path: Path) -> None:
    """
    Rename output artifacts to <safe_word>.tflite / <safe_word>.json
    and patch the JSON so it references the renamed tflite.

    Handles weird trainer names like ____r_.tflite by normalizing post-run.

    The newest .tflite in DATA_DIR / "output" (and the .json paired with it by
    _find_latest_output_pair) is renamed. Files already occupying the target
    names are moved aside to <safe_word>.<timestamp>.bak.<ext> first. Progress
    is reported through _append_train_log.

    log_path is accepted for call-site compatibility and is not used here.
    Errors from the rename or write steps (OSError) propagate to the caller.
    """
    out_dir = DATA_DIR / "output"
    tfl, js = _find_latest_output_pair(out_dir)

    if tfl is None:
        _append_train_log(f"⚠️ No .tflite found in {out_dir}")
        return

    new_tfl = out_dir / f"{safe_word}.tflite"
    new_js = out_dir / f"{safe_word}.json"
    old_tfl_name = tfl.name

    # Already normalized
    if tfl.name == new_tfl.name and js is not None and js.name == new_js.name:
        _append_train_log(f"✅ Output names already normalized: {new_tfl.name}")
        return

    ts = datetime.now().strftime("%Y%m%d_%H%M%S")

    def back_up(target: Path, suffix: str) -> None:
        bk = out_dir / f"{safe_word}.{ts}.bak{suffix}"
        shutil.move(str(target), str(bk))
        _append_train_log(f"↪️ Backed up existing {target.name} → {bk.name}")

    # Avoid clobbering existing target files (back them up)
    if new_tfl.exists() and new_tfl.resolve() != tfl.resolve():
        back_up(new_tfl, ".tflite")
    if new_js.exists() and (js is None or new_js.resolve() != js.resolve()):
        back_up(new_js, ".json")

    # Rename tflite
    if tfl.resolve() != new_tfl.resolve():
        shutil.move(str(tfl), str(new_tfl))
        _append_train_log(f"✅ Renamed model: {old_tfl_name} → {new_tfl.name}")

    if js is None or not js.exists():
        _append_train_log("⚠️ No .json found to patch (model renamed only)")
        return

    # Read the JSON before moving it. A failure here must not block the
    # rename, but it is logged so an unpatched metadata file is not a mystery.
    try:
        data = json.loads(js.read_text(encoding="utf-8"))
    except (OSError, ValueError) as e:
        data = None
        _append_train_log(f"⚠️ Could not read {js.name} as JSON; contents left unpatched: {e!r}")

    if js.resolve() != new_js.resolve():
        shutil.move(str(js), str(new_js))
        _append_train_log(f"✅ Renamed metadata: {js.name} → {new_js.name}")

    if data is not None:
        patched = _patch_model_reference(data, old_tfl_name, new_tfl.name)
        new_js.write_text(json.dumps(patched, indent=2, ensure_ascii=False) + "\n", encoding="utf-8")
        _append_train_log(f"✅ Patched JSON to reference: {new_tfl.name}")


# Call site (inside _run_training_background, right after the exit code is
# stored in STATE["training"]["exit_code"]):
#
#     # Normalize output artifact names on success
#     if rc == 0:
#         _normalize_output_artifacts(safe_word, log_path)