mirror of
https://github.com/TaterTotterson/microWakeWord-Trainer-Nvidia-Docker.git
synced 2026-06-12 20:10:19 -06:00
model name
This commit is contained in:
@@ -1,10 +1,13 @@
|
||||
# recorder_server.py
|
||||
import os
|
||||
import re
|
||||
import json
|
||||
import shutil
|
||||
import subprocess
|
||||
import threading
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, List, Optional
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
|
||||
from fastapi import FastAPI, UploadFile, File, Form
|
||||
from fastapi.responses import HTMLResponse, JSONResponse
|
||||
@@ -71,7 +74,7 @@ STATE: Dict[str, Any] = {
|
||||
"log_path": None, # path to recorder_training.log
|
||||
"safe_word": None,
|
||||
|
||||
# NEW: prevent UI duplication when UI appends:
|
||||
# prevent UI duplication when UI appends:
|
||||
"last_sent_tail": [], # last tail snapshot (list of lines)
|
||||
"last_log_size": 0, # detect truncation
|
||||
},
|
||||
@@ -258,6 +261,114 @@ def _compute_new_lines(prev_tail: List[str], new_tail: List[str]) -> List[str]:
|
||||
return new_tail
|
||||
|
||||
|
||||
# -------------------- output artifact normalization --------------------
|
||||
|
||||
def _find_latest_output_pair(output_dir: Path) -> Tuple[Optional[Path], Optional[Path]]:
    """
    Locate the newest model artifact in output_dir together with its metadata.

    "Newest" means the largest st_mtime. The metadata file is the .json that
    shares the model's basename when that file exists; otherwise it is the
    newest .json in the directory, or None when there is no .json at all.

    Returns (tflite_path, json_path). Both are None when output_dir does not
    exist or contains no .tflite file.
    """

    def _mtime(candidate: Path) -> float:
        # Sort key: last-modification timestamp of the candidate file.
        return candidate.stat().st_mtime

    if not output_dir.exists():
        return (None, None)

    newest_model = max(output_dir.glob("*.tflite"), key=_mtime, default=None)
    if newest_model is None:
        return (None, None)

    # Preferred metadata: same basename as the model, .json extension.
    sibling_meta = newest_model.with_suffix(".json")
    if sibling_meta.exists():
        return (newest_model, sibling_meta)

    # Fallback: whichever .json was modified most recently (may be None).
    newest_meta = max(output_dir.glob("*.json"), key=_mtime, default=None)
    return (newest_model, newest_meta)
|
||||
|
||||
|
||||
def _deep_replace_strings(obj: Any, old: str, new: str) -> Any:
    """
    Return a copy of obj in which every string value has old swapped for new.

    Dicts and lists are walked recursively and rebuilt; dict keys are kept
    as-is. Any other value (numbers, None, tuples, ...) is returned unchanged.
    """
    if isinstance(obj, dict):
        rebuilt_map = {}
        for key, value in obj.items():
            rebuilt_map[key] = _deep_replace_strings(value, old, new)
        return rebuilt_map

    if isinstance(obj, list):
        rebuilt_items = []
        for item in obj:
            rebuilt_items.append(_deep_replace_strings(item, old, new))
        return rebuilt_items

    # Leaf: only plain strings are rewritten.
    return obj.replace(old, new) if isinstance(obj, str) else obj
|
||||
|
||||
|
||||
def _normalize_output_artifacts(safe_word: str, log_path: Path) -> None:
    """
    Rename output artifacts to <safe_word>.tflite / <safe_word>.json
    and patch the JSON so it references the renamed tflite.

    Handles weird trainer names like ____r_.tflite by normalizing post-run.

    Args:
        safe_word: Basename (without extension) given to both artifacts.
        log_path: Not read by this function; kept so existing callers
            continue to work unchanged.

    Progress and warnings are reported through _append_train_log. Files that
    already occupy the target names (and are not the new artifacts themselves)
    are moved to timestamped backups instead of being overwritten.
    """
    out_dir = DATA_DIR / "output"
    tfl, js = _find_latest_output_pair(out_dir)

    if not tfl:
        _append_train_log(f"⚠️ No .tflite found in {out_dir}")
        return

    new_tfl = out_dir / f"{safe_word}.tflite"
    new_js = out_dir / f"{safe_word}.json"
    # Remember the trainer's name: it is what the JSON refers to before patching.
    old_tfl_name = tfl.name

    # Already normalized
    if tfl.name == new_tfl.name and (js and js.name == new_js.name):
        _append_train_log(f"✅ Output names already normalized: {new_tfl.name}")
        return

    ts = datetime.now().strftime("%Y%m%d_%H%M%S")

    def backup_if_exists(p: Path, suffix: str) -> None:
        # Move p aside to <safe_word>.<timestamp>.bak<suffix> in the same directory.
        if p.exists():
            bk = out_dir / f"{safe_word}.{ts}.bak{suffix}"
            shutil.move(str(p), str(bk))
            _append_train_log(f"↪️ Backed up existing {p.name} → {bk.name}")

    # Avoid clobbering existing target files (back them up)
    if new_tfl.exists() and new_tfl.resolve() != tfl.resolve():
        backup_if_exists(new_tfl, ".tflite")
    if new_js.exists() and (not js or new_js.resolve() != js.resolve()):
        backup_if_exists(new_js, ".json")

    # Rename tflite
    if tfl.resolve() != new_tfl.resolve():
        new_tfl.parent.mkdir(parents=True, exist_ok=True)
        shutil.move(str(tfl), str(new_tfl))
        _append_train_log(f"✅ Renamed model: {old_tfl_name} → {new_tfl.name}")

    # Rename + patch json if present
    if js and js.exists():
        # Read JSON before the move so the content is captured under its old name.
        try:
            data = json.loads(js.read_text(encoding="utf-8"))
        except Exception as exc:
            # Best effort: the file is still renamed below, but say so instead of
            # silently leaving metadata that references the old model name.
            data = None
            _append_train_log(
                f"⚠️ Could not read/parse {js.name} ({exc!r}); metadata not patched"
            )

        if js.resolve() != new_js.resolve():
            shutil.move(str(js), str(new_js))
            _append_train_log(f"✅ Renamed metadata: {js.name} → {new_js.name}")

        if data is not None:
            patched = _deep_replace_strings(data, old_tfl_name, new_tfl.name)

            # Patch common keys if present (only meaningful for a JSON object,
            # and only when the existing value is a string).
            if isinstance(patched, dict):
                for key in ("model", "model_file", "model_filename", "tflite", "tflite_file", "tflite_filename"):
                    if isinstance(patched.get(key), str):
                        patched[key] = new_tfl.name

            new_js.write_text(json.dumps(patched, indent=2, ensure_ascii=False) + "\n", encoding="utf-8")
            _append_train_log(f"✅ Patched JSON to reference: {new_tfl.name}")
    else:
        _append_train_log("⚠️ No .json found to patch (model renamed only)")
|
||||
|
||||
|
||||
# -------------------- training worker --------------------
|
||||
|
||||
def _run_training_background(safe_word: str, allow_no_personal: bool):
|
||||
with STATE_LOCK:
|
||||
raw_phrase = STATE.get("raw_phrase") or ""
|
||||
@@ -324,6 +435,10 @@ def _run_training_background(safe_word: str, allow_no_personal: bool):
|
||||
with STATE_LOCK:
|
||||
STATE["training"]["exit_code"] = rc
|
||||
|
||||
# Normalize output artifact names on success
|
||||
if rc == 0:
|
||||
_normalize_output_artifacts(safe_word, log_path)
|
||||
|
||||
except Exception as e:
|
||||
_append_train_log(f"✗ Training crashed: {e!r}")
|
||||
with STATE_LOCK:
|
||||
|
||||
Reference in New Issue
Block a user