mirror of
https://github.com/TaterTotterson/microWakeWord-Trainer-Nvidia-Docker.git
synced 2026-06-12 20:10:19 -06:00
model name
This commit is contained in:
@@ -1,10 +1,13 @@
|
|||||||
# recorder_server.py
|
# recorder_server.py
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
import json
|
||||||
|
import shutil
|
||||||
import subprocess
|
import subprocess
|
||||||
import threading
|
import threading
|
||||||
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Any, List, Optional
|
from typing import Dict, Any, List, Optional, Tuple
|
||||||
|
|
||||||
from fastapi import FastAPI, UploadFile, File, Form
|
from fastapi import FastAPI, UploadFile, File, Form
|
||||||
from fastapi.responses import HTMLResponse, JSONResponse
|
from fastapi.responses import HTMLResponse, JSONResponse
|
||||||
@@ -71,7 +74,7 @@ STATE: Dict[str, Any] = {
|
|||||||
"log_path": None, # path to recorder_training.log
|
"log_path": None, # path to recorder_training.log
|
||||||
"safe_word": None,
|
"safe_word": None,
|
||||||
|
|
||||||
# NEW: prevent UI duplication when UI appends:
|
# prevent UI duplication when UI appends:
|
||||||
"last_sent_tail": [], # last tail snapshot (list of lines)
|
"last_sent_tail": [], # last tail snapshot (list of lines)
|
||||||
"last_log_size": 0, # detect truncation
|
"last_log_size": 0, # detect truncation
|
||||||
},
|
},
|
||||||
@@ -258,6 +261,114 @@ def _compute_new_lines(prev_tail: List[str], new_tail: List[str]) -> List[str]:
|
|||||||
return new_tail
|
return new_tail
|
||||||
|
|
||||||
|
|
||||||
|
# -------------------- output artifact normalization --------------------
|
||||||
|
|
||||||
|
def _find_latest_output_pair(output_dir: Path) -> Tuple[Optional[Path], Optional[Path]]:
|
||||||
|
"""
|
||||||
|
Find the most recently modified .tflite and its matching .json (same basename)
|
||||||
|
in output_dir. Falls back to newest .json if an exact match doesn't exist.
|
||||||
|
Returns (tflite_path, json_path) or (None, None).
|
||||||
|
"""
|
||||||
|
if not output_dir.exists():
|
||||||
|
return (None, None)
|
||||||
|
|
||||||
|
tflites = sorted(output_dir.glob("*.tflite"), key=lambda p: p.stat().st_mtime, reverse=True)
|
||||||
|
if not tflites:
|
||||||
|
return (None, None)
|
||||||
|
|
||||||
|
tfl = tflites[0]
|
||||||
|
js = tfl.with_suffix(".json")
|
||||||
|
if js.exists():
|
||||||
|
return (tfl, js)
|
||||||
|
|
||||||
|
jsons = sorted(output_dir.glob("*.json"), key=lambda p: p.stat().st_mtime, reverse=True)
|
||||||
|
return (tfl, jsons[0] if jsons else None)
|
||||||
|
|
||||||
|
|
||||||
|
def _deep_replace_strings(obj: Any, old: str, new: str) -> Any:
|
||||||
|
"""
|
||||||
|
Recursively replace occurrences of old in any string values with new.
|
||||||
|
"""
|
||||||
|
if isinstance(obj, str):
|
||||||
|
return obj.replace(old, new)
|
||||||
|
if isinstance(obj, list):
|
||||||
|
return [_deep_replace_strings(x, old, new) for x in obj]
|
||||||
|
if isinstance(obj, dict):
|
||||||
|
return {k: _deep_replace_strings(v, old, new) for k, v in obj.items()}
|
||||||
|
return obj
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_output_artifacts(safe_word: str, log_path: Path) -> None:
    """
    Rename output artifacts to <safe_word>.tflite / <safe_word>.json
    and patch the JSON so it references the renamed tflite.

    Handles weird trainer names like ____r_.tflite by normalizing post-run.

    Steps, all inside DATA_DIR / "output":
      1. Pick the newest model/metadata pair via _find_latest_output_pair.
      2. Move any pre-existing <safe_word>.* target aside to a timestamped
         ".bak" file so it is not overwritten.
      3. Rename the model, then the metadata.
      4. Rewrite the metadata so string values mentioning the old model file
         name (and a set of common top-level keys) point at the new name.

    Progress and warnings are reported through _append_train_log.

    Args:
        safe_word: Base name (without extension) for the normalized files.
        log_path: Not used by this function; kept so existing callers keep
            working.
    """
    out_dir = DATA_DIR / "output"
    tfl, js = _find_latest_output_pair(out_dir)

    if not tfl:
        _append_train_log(f"⚠️ No .tflite found in {out_dir}")
        return

    new_tfl = out_dir / f"{safe_word}.tflite"
    new_js = out_dir / f"{safe_word}.json"
    old_tfl_name = tfl.name

    # Already normalized
    if tfl.name == new_tfl.name and (js and js.name == new_js.name):
        _append_train_log(f"✅ Output names already normalized: {new_tfl.name}")
        return

    # One timestamp per call so both backups of a pair share the same stamp.
    ts = datetime.now().strftime("%Y%m%d_%H%M%S")

    def backup_if_exists(p: Path, suffix: str) -> None:
        if p.exists():
            bk = out_dir / f"{safe_word}.{ts}.bak{suffix}"
            shutil.move(str(p), str(bk))
            _append_train_log(f"↪️ Backed up existing {p.name} → {bk.name}")

    # Avoid clobbering existing target files (back them up), unless the
    # target already *is* the file we are about to normalize.
    if new_tfl.exists() and new_tfl.resolve() != tfl.resolve():
        backup_if_exists(new_tfl, ".tflite")
    if new_js.exists() and (not js or new_js.resolve() != js.resolve()):
        backup_if_exists(new_js, ".json")

    # Rename tflite
    if tfl.resolve() != new_tfl.resolve():
        new_tfl.parent.mkdir(parents=True, exist_ok=True)
        shutil.move(str(tfl), str(new_tfl))
        _append_train_log(f"✅ Renamed model: {old_tfl_name} → {new_tfl.name}")

    # Rename + patch json if present
    if not (js and js.exists()):
        _append_train_log("⚠️ No .json found to patch (model renamed only)")
        return

    # Read the JSON before moving it. A missing/unreadable/invalid file must
    # not abort the rename, but it is reported instead of silently skipped.
    try:
        data = json.loads(js.read_text(encoding="utf-8"))
    except (OSError, ValueError) as exc:
        data = None
        _append_train_log(f"⚠️ Could not read {js.name} as JSON ({exc!r}); leaving it unpatched")

    if js.resolve() != new_js.resolve():
        shutil.move(str(js), str(new_js))
        _append_train_log(f"✅ Renamed metadata: {js.name} → {new_js.name}")

    if data is None:
        return

    patched = _deep_replace_strings(data, old_tfl_name, new_tfl.name)

    # Patch common keys if present (only when they already hold a string,
    # so nested/structured values are left alone).
    if isinstance(patched, dict):
        for key in ("model", "model_file", "model_filename", "tflite", "tflite_file", "tflite_filename"):
            if isinstance(patched.get(key), str):
                patched[key] = new_tfl.name

    new_js.write_text(json.dumps(patched, indent=2, ensure_ascii=False) + "\n", encoding="utf-8")
    _append_train_log(f"✅ Patched JSON to reference: {new_tfl.name}")
|
||||||
|
|
||||||
|
|
||||||
|
# -------------------- training worker --------------------
|
||||||
|
|
||||||
def _run_training_background(safe_word: str, allow_no_personal: bool):
|
def _run_training_background(safe_word: str, allow_no_personal: bool):
|
||||||
with STATE_LOCK:
|
with STATE_LOCK:
|
||||||
raw_phrase = STATE.get("raw_phrase") or ""
|
raw_phrase = STATE.get("raw_phrase") or ""
|
||||||
@@ -324,6 +435,10 @@ def _run_training_background(safe_word: str, allow_no_personal: bool):
|
|||||||
with STATE_LOCK:
|
with STATE_LOCK:
|
||||||
STATE["training"]["exit_code"] = rc
|
STATE["training"]["exit_code"] = rc
|
||||||
|
|
||||||
|
# Normalize output artifact names on success
|
||||||
|
if rc == 0:
|
||||||
|
_normalize_output_artifacts(safe_word, log_path)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
_append_train_log(f"✗ Training crashed: {e!r}")
|
_append_train_log(f"✗ Training crashed: {e!r}")
|
||||||
with STATE_LOCK:
|
with STATE_LOCK:
|
||||||
|
|||||||
Reference in New Issue
Block a user