mirror of
https://github.com/TaterTotterson/microWakeWord-Trainer-Nvidia-Docker.git
synced 2026-06-12 20:10:19 -06:00
Automatic Calibration
This commit is contained in:
@@ -317,6 +317,7 @@ fi
|
||||
|
||||
TRAINING_DONE="false"
|
||||
|
||||
echo "🏋️ Starting model training and TFLite export (this is the longest stage)…"
|
||||
if run_attempt "Attempt 1/3: GPU training (default runtime profile)" ; then
|
||||
echo "✅ Training complete (GPU path)."
|
||||
TRAINING_DONE="true"
|
||||
@@ -386,12 +387,24 @@ if [ "${TRAINING_DONE}" != "true" ]; then
|
||||
fi
|
||||
|
||||
source_path="${WORK_DIR}/trained_models/wakeword/tflite_stream_state_internal_quant/stream_state_internal_quant.tflite"
|
||||
calibration_path="${WORK_DIR}/trained_models/wakeword/tflite_stream_state_internal_quant/detection_calibration.json"
|
||||
|
||||
if [ ! -f "${source_path}" ] ; then
|
||||
echo "Output model not found! Training didn't complete successfully. See ${TRAIN_LOG}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "🎯 Calibrating detector settings for on-device use…"
|
||||
if "${PYTHON_BIN:-python}" "${PROGDIR}/calibrate_detector.py" \
|
||||
--training-config "${WORK_DIR}/trained_models/wakeword/training_config.yaml" \
|
||||
--model "${source_path}" \
|
||||
--output "${calibration_path}"; then
|
||||
echo "✅ Detector calibration complete."
|
||||
else
|
||||
echo "⚠️ Detector calibration failed; packaging with default detector settings."
|
||||
rm -f "${calibration_path}" || :
|
||||
fi
|
||||
|
||||
cp "${WORK_DIR}/trained_models/wakeword/model_summary.txt" "${OUTPUT_DIR}/logs/" || :
|
||||
cp -a "${WORK_DIR}/trained_models/wakeword/logs/train" "${OUTPUT_DIR}/logs/" || :
|
||||
cp -a "${WORK_DIR}/trained_models/wakeword/logs/validation" "${OUTPUT_DIR}/logs/" || :
|
||||
@@ -404,24 +417,49 @@ tflite_path="${OUTPUT_DIR}/${tflite_filename}"
|
||||
cp "${source_path}" "${tflite_path}"
|
||||
|
||||
json_path="${OUTPUT_DIR}/${wake_word_filename}.json"
|
||||
cat <<-EOF > "${json_path}"
|
||||
{
|
||||
export WAKE_WORD_TITLE LANGUAGE JSON_PATH="${json_path}" TFLITE_FILENAME="${tflite_filename}" CALIBRATION_PATH="${calibration_path}"
|
||||
echo "📦 Packaging final model artifacts…"
|
||||
"${PYTHON_BIN:-python}" - <<'PY'
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
json_path = Path(os.environ["JSON_PATH"])
|
||||
calibration_path = Path(os.environ.get("CALIBRATION_PATH", ""))
|
||||
language = (os.environ.get("LANGUAGE", "en") or "en").strip().lower()
|
||||
probability_cutoff = 0.97
|
||||
sliding_window_size = 5
|
||||
|
||||
if calibration_path.exists():
|
||||
try:
|
||||
calibration = json.loads(calibration_path.read_text(encoding="utf-8"))
|
||||
probability_cutoff = float(calibration.get("probability_cutoff", probability_cutoff))
|
||||
sliding_window_size = int(calibration.get("sliding_window_size", sliding_window_size))
|
||||
print(
|
||||
f"🎯 Using calibrated detector settings: "
|
||||
f"cutoff={probability_cutoff:.2f}, window={sliding_window_size}"
|
||||
)
|
||||
except Exception as exc:
|
||||
print(f"⚠️ Failed to read detector calibration ({exc}); using defaults.")
|
||||
|
||||
meta = {
|
||||
"type": "micro",
|
||||
"wake_word": "${WAKE_WORD_TITLE}",
|
||||
"wake_word": os.environ["WAKE_WORD_TITLE"],
|
||||
"author": "Tater Totterson",
|
||||
"website": "https://github.com/TaterTotterson/microWakeWord-Trainer-Nvidia-Docker.git",
|
||||
"model": "${tflite_filename}",
|
||||
"trained_languages": ["en"],
|
||||
"model": os.environ["TFLITE_FILENAME"],
|
||||
"trained_languages": [language],
|
||||
"version": 2,
|
||||
"micro": {
|
||||
"probability_cutoff": 0.97,
|
||||
"sliding_window_size": 5,
|
||||
"probability_cutoff": round(probability_cutoff, 2),
|
||||
"sliding_window_size": sliding_window_size,
|
||||
"feature_step_size": 10,
|
||||
"tensor_arena_size": 30000,
|
||||
"minimum_esphome_version": "2024.7.0"
|
||||
}
|
||||
"minimum_esphome_version": "2024.7.0",
|
||||
},
|
||||
}
|
||||
EOF
|
||||
json_path.write_text(json.dumps(meta, indent=4) + "\n", encoding="utf-8")
|
||||
PY
|
||||
|
||||
echo "Name: ${WAKE_WORD_TITLE}"
|
||||
echo "Model: ${tflite_path}"
|
||||
|
||||
Reference in New Issue
Block a user