#!/bin/bash
set -e

# Resolve the directory holding this script so the helper library next to it
# can be sourced regardless of the caller's working directory.
PROGPATH="$(realpath "$0")"
PROGDIR="$(dirname "${PROGPATH}")"

# KNOWN_ARGS is set before sourcing; presumably shell.functions reads it while
# parsing the command line (verify there). After the source, this script reads
# POSITIONAL_ARGS, UNKNOWN_ARGS and HELP.
KNOWN_ARGS=( training-steps samples data-dir )
source "${PROGDIR}/shell.functions"
WAKE_WORD="${POSITIONAL_ARGS[0]}"

# Report unrecognised options and force the usage text to be shown.
if (( ${#UNKNOWN_ARGS[@]} > 0 )); then
    printf '%s\n' "Unknown argument(s): ${UNKNOWN_ARGS[*]}" >&2
    HELP=true
fi

# Print the usage text on stderr and stop (exit status 1) when help was
# requested, an unknown option was seen, or no wake word was supplied.
if [[ "${HELP}" == "true" || -z "${WAKE_WORD}" ]]; then
    cat >&2 <<EOF
Usage: $0 [ --samples=<samples> ] [ --training-steps=<steps> ]
          <wake_word> [ <wake_word_title> ]

       $0 -h/--help

--samples:          The number of samples to generate for the wake word.
                    Used only to generate output file names.

--training-steps:   Number of training steps.
                    Default: ${DEFAULT_TRAINING_STEPS}

<wake_word>:        The word to train spelled phonetically.
                    Required.

<wake_word_title>:  A pretty name to save to the json metadata file.
                    Default: The wake word with individual words capitalized.

EOF
    exit 1
fi

WORK_DIR="${DATA_DIR}/work"
TRAINING_DS="${DATA_DIR}/training_datasets"

#######################################
# Decide the display title that is later written to the JSON metadata.
# Globals:
#   POSITIONAL_ARGS (read)    [0] is the wake word, optional [1] is the title
#   WAKE_WORD       (read)
#   WAKE_WORD_TITLE (read, written)
# Precedence:
#   1. the second positional argument, when one was given;
#   2. a WAKE_WORD_TITLE that is already set (e.g. inherited from the
#      environment);
#   3. the wake word with every non-alphanumeric character treated as a word
#      separator and each resulting word capitalised.
# A title that is set but empty falls back to the raw wake word.
#######################################
resolve_wake_word_title() {
    local -a title_words

    # Count array ELEMENTS. "${#POSITIONAL_ARGS}" without [@] is the string
    # length of element 0, which made the explicit title apply only when the
    # wake word happened to be exactly two characters long.
    if [ "${#POSITIONAL_ARGS[@]}" -ge 2 ] ; then
        WAKE_WORD_TITLE="${POSITIONAL_ARGS[1]}"
    fi

    if [ ! -v WAKE_WORD_TITLE ] ; then
        # Deliberately unquoted: word splitting turns the space-separated
        # result into one array element per word. Only alphanumerics and
        # spaces remain after the substitution, so no globbing can occur.
        title_words=( ${WAKE_WORD//[^a-zA-Z0-9]/ } )
        WAKE_WORD_TITLE="${title_words[*]^}"
    elif [ -z "${WAKE_WORD_TITLE}" ] ; then
        WAKE_WORD_TITLE="${WAKE_WORD}"
    fi
}

resolve_wake_word_title

# Activate the virtual environment under DATA_DIR; presumably this puts the
# project's Python interpreter first on PATH for the python calls below.
# shellcheck source=/dev/null
source "${DATA_DIR}/.venv/bin/activate"

# Keep copies so we can restore/preserve across retries and fallback.
# Both are captured after activation and before this script sets its own
# defaults; an unset variable is recorded as the empty string.
ORIG_XLA_FLAGS="${XLA_FLAGS:-}"
ORIG_TF_XLA_FLAGS="${TF_XLA_FLAGS:-}"

# Normalise a user-supplied switch value to the literal "true" or "false".
# Arguments: $1 - value to interpret (may be empty or missing)
# Outputs:   "true" when $1 is 1/true/yes/on in any letter case,
#            "false" for anything else.
normalize_bool() {
  local flag="${1,,}"

  if [[ "${flag}" == "1" || "${flag}" == "true" \
        || "${flag}" == "yes" || "${flag}" == "on" ]]; then
    echo "true"
  else
    echo "false"
  fi
}

# Print the compute capability reported by nvidia-smi for the first GPU it
# lists, with all whitespace removed and no trailing newline.
# Prints nothing when nvidia-smi is not available or produces no output.
# Always returns 0.
detect_gpu_compute_capability() {
  command -v nvidia-smi >/dev/null 2>&1 || return 0

  local first_row=""
  # Only the first output line is wanted; read still fills the variable when
  # that line lacks a trailing newline, hence the "|| true".
  IFS= read -r first_row \
    < <(nvidia-smi --query-gpu=compute_cap --format=csv,noheader 2>/dev/null) \
    || true

  printf '%s' "${first_row//[[:space:]]/}"
}

# This script treats any compute capability starting with "12." as Blackwell.
GPU_COMPUTE_CAPABILITY="$(detect_gpu_compute_capability)"
if [[ "${GPU_COMPUTE_CAPABILITY}" == 12.* ]]; then
  IS_BLACKWELL="true"
else
  IS_BLACKWELL="false"
fi

# CPU fallback is on unless MWW_ALLOW_CPU_FALLBACK is set to a value that
# normalize_bool does not recognise as true.
ALLOW_CPU_FALLBACK_DEFAULT="true"
ALLOW_CPU_FALLBACK="$(normalize_bool "${MWW_ALLOW_CPU_FALLBACK:-${ALLOW_CPU_FALLBACK_DEFAULT}}")"

if [[ "${IS_BLACKWELL}" == "true" ]]; then
  echo "ℹ️  Blackwell GPU detected (compute capability ${GPU_COMPUTE_CAPABILITY})."
  echo "ℹ️  Using GPU compatibility retries; CPU fallback is ${ALLOW_CPU_FALLBACK} (override with MWW_ALLOW_CPU_FALLBACK=true|false)."
fi

# Enable driver-side PTX JIT fallback when ptxas/nvlink are unavailable.
# A non-empty XLA_FLAGS supplied by the caller is left untouched.
if [[ -n "${XLA_FLAGS:-}" ]]; then
  echo "ℹ️  Using user-provided XLA_FLAGS=${XLA_FLAGS}"
else
  XLA_FLAGS="--xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found"
  export XLA_FLAGS
  echo "ℹ️  Setting XLA_FLAGS=${XLA_FLAGS}"
fi

# Abort the script unless every argument names an existing directory.
# Arguments: any number of directory paths
# Outputs:   an error naming the first missing directory, on stderr
# Exits:     with status 1 on the first missing directory (this terminates
#            the calling shell, not just the function)
check_directories() {
    # Loop variable is local so it no longer leaks into (or clobbers) a
    # global of the same name.
    local dir
    for dir in "$@" ; do
        if [ ! -d "${dir}" ] ; then
            echo "ERROR: Directory ${dir} not found" >&2
            exit 1
        fi
    done
}

# Fail fast when a required input directory is missing. The variable parts
# are quoted so a data directory containing whitespace or glob characters
# stays a single argument; the brace list is kept outside the quotes so it
# still expands into the four negative-dataset paths.
check_directories "${WORK_DIR}/wake_word_samples_augmented" \
    "${TRAINING_DS}/negative_datasets/"{speech,dinner_party,no_speech,dinner_party_eval}

# Personal features are optional, but if present they MUST have /training
PERSONAL_FEATURES_DIR="${WORK_DIR}/personal_augmented_features"
if [[ ! -d "${PERSONAL_FEATURES_DIR}/training" ]]; then
  HAS_PERSONAL="false"
  echo "ℹ️  No personal features found at ${PERSONAL_FEATURES_DIR}/training (continuing without personal weighting)"
else
  HAS_PERSONAL="true"
  echo "✅ Found personal features: ${PERSONAL_FEATURES_DIR}/training (will weight sampling_weight=3.0)"
fi

# Reviewed false-positive features are optional hard negatives.
REVIEWED_NEGATIVE_FEATURES_DIR="${WORK_DIR}/reviewed_negative_features"
if [[ ! -d "${REVIEWED_NEGATIVE_FEATURES_DIR}/training" ]]; then
  HAS_REVIEWED_NEGATIVE="false"
  echo "ℹ️  No reviewed negative features found at ${REVIEWED_NEGATIVE_FEATURES_DIR}/training (continuing with stock negatives)"
else
  HAS_REVIEWED_NEGATIVE="true"
  echo "✅ Found reviewed negative features: ${REVIEWED_NEGATIVE_FEATURES_DIR}/training (will weight as hard negatives)"
fi

# Everything from here on runs inside the work directory.
cd "${WORK_DIR}"

printf '%s\n' "===== Starting ${TRAINING_STEPS} training steps ====="
# Start of the timed section; compared with END_TS at the end of the script.
START_TS="${EPOCHSECONDS}"

# Best effort: a failure here is ignored at this point.
mkdir -p "${WORK_DIR}/trained_models" || true

# We write a YAML with a marker, then splice personal feature block in if it exists.
YAML_PATH="${WORK_DIR}/trained_models/training_parameters.yaml"

#######################################
# Write the training configuration file.
# Globals:   WORK_DIR, TRAINING_DS, TRAINING_STEPS (read)
# Arguments: $1 - path of the YAML file to create/overwrite
# Outputs:   the YAML file; it still contains the two literal lines
#            __PERSONAL_FEATURE_MARKER__ and
#            __REVIEWED_NEGATIVE_FEATURE_MARKER__ for the later splice step.
# Returns:   status of the write
#
# Paths and the step count are expanded directly by the here-document rather
# than substituted afterwards with sed, so characters that are special in a
# sed replacement ('|', '&', '\') can no longer corrupt the generated file.
# Expanded values are not re-scanned by the shell.
#######################################
write_training_yaml() {
  local out_path="$1"
  local neg_root="${TRAINING_DS}/negative_datasets"

  cat <<EOF > "${out_path}"
batch_size: 16
clip_duration_ms: 1500
eval_step_interval: 500
features:
- features_dir: ${WORK_DIR}/wake_word_samples_augmented
  penalty_weight: 1.0
  sampling_weight: 2.0
  truncation_strategy: truncate_start
  truth: true
  type: mmap
__PERSONAL_FEATURE_MARKER__
__REVIEWED_NEGATIVE_FEATURE_MARKER__
- features_dir: ${neg_root}/speech
  penalty_weight: 1.0
  sampling_weight: 12.0
  truncation_strategy: random
  truth: false
  type: mmap
- features_dir: ${neg_root}/dinner_party
  penalty_weight: 1.0
  sampling_weight: 12.0
  truncation_strategy: random
  truth: false
  type: mmap
- features_dir: ${neg_root}/no_speech
  penalty_weight: 1.0
  sampling_weight: 5.0
  truncation_strategy: random
  truth: false
  type: mmap
- features_dir: ${neg_root}/dinner_party_eval
  penalty_weight: 1.0
  sampling_weight: 0.0
  truncation_strategy: split
  truth: false
  type: mmap
freq_mask_count:
- 0
freq_mask_max_size:
- 0
learning_rates:
- 0.001
maximization_metric: average_viable_recall
minimization_metric: null
negative_class_weight:
- 20
positive_class_weight:
- 1
target_minimization: 0.9
time_mask_count:
- 0
time_mask_max_size:
- 0
train_dir: ${WORK_DIR}/trained_models/wakeword
training_steps:
- ${TRAINING_STEPS}
window_step_ms: 10
EOF
}

write_training_yaml "${YAML_PATH}"

# Insert/remove personal block
if [ "${HAS_PERSONAL}" = "true" ]; then
  # Insert directly after the wakeword feature block
  personal_block="$(cat <<EOF
- features_dir: ${PERSONAL_FEATURES_DIR}
  penalty_weight: 1.0
  sampling_weight: 3.0
  truncation_strategy: truncate_start
  truth: true
  type: mmap
EOF
)"
  perl -0777 -i -pe "s#__PERSONAL_FEATURE_MARKER__#${personal_block}#g" "${YAML_PATH}"
else
  sed -i -e "/__PERSONAL_FEATURE_MARKER__/d" "${YAML_PATH}"
fi

# Insert/remove reviewed hard-negative block
if [ "${HAS_REVIEWED_NEGATIVE}" = "true" ]; then
  reviewed_negative_block="$(cat <<EOF
- features_dir: ${REVIEWED_NEGATIVE_FEATURES_DIR}
  penalty_weight: 1.25
  sampling_weight: 8.0
  truncation_strategy: random
  truth: false
  type: mmap
EOF
)"
  perl -0777 -i -pe "s#__REVIEWED_NEGATIVE_FEATURE_MARKER__#${reviewed_negative_block}#g" "${YAML_PATH}"
else
  sed -i -e "/__REVIEWED_NEGATIVE_FEATURE_MARKER__/d" "${YAML_PATH}"
fi

printf '%s\n' "   Wrote training_parameters.yaml"
# Remove the directory that the generated YAML names as train_dir, so training
# starts without leftovers from a previous run of this script.
rm -rf "${WORK_DIR}/trained_models/wakeword"

# Build a file-name-safe slug from the wake word: lower-case it, collapse
# each run of characters outside a-z/0-9 into one underscore, and trim
# leading/trailing underscores. An empty result becomes "wakeword".
wake_word_filename="$(
  echo "${WAKE_WORD}" \
    | tr '[:upper:]' '[:lower:]' \
    | sed -E -e 's/[^a-z0-9]+/_/g' -e 's/^_+//' -e 's/_+$//'
)"
if [[ -z "${wake_word_filename}" ]]; then
  wake_word_filename="wakeword"
fi

# One output directory per run: timestamp, slug, sample count, step count.
run_stamp="$(date +'%Y-%m-%d-%H-%M-%S')"
OUTPUT_DIR="${DATA_DIR}/output/${run_stamp}-${wake_word_filename}-${SAMPLES}-${TRAINING_STEPS}"
TRAIN_LOG="${OUTPUT_DIR}/logs/training.log"
mkdir -p "${OUTPUT_DIR}/logs" || true

# Command line handed to the Python interpreter by run_attempt, assembled in
# groups. Element order is significant and must not change.

# Module to run and the generated training configuration.
TRAIN_ARGS=( -m microwakeword.model_train_eval )
TRAIN_ARGS+=( --training_config "${WORK_DIR}/trained_models/training_parameters.yaml" )

# Train/restore switches.
TRAIN_ARGS+=( --train 1 --restore_checkpoint 1 )

# Evaluation switches: only the quantized streaming TFLite test is enabled.
TRAIN_ARGS+=( --test_tf_nonstreaming 0 )
TRAIN_ARGS+=( --test_tflite_nonstreaming 0 )
TRAIN_ARGS+=( --test_tflite_nonstreaming_quantized 0 )
TRAIN_ARGS+=( --test_tflite_streaming 0 )
TRAIN_ARGS+=( --test_tflite_streaming_quantized 1 )
TRAIN_ARGS+=( --use_weights best_weights )

# "mixednet" is a bare positional word; the options after it are presumably
# specific to that model (confirm against the microwakeword CLI).
TRAIN_ARGS+=( mixednet )
TRAIN_ARGS+=( --pointwise_filters "64,64,64,64" )
TRAIN_ARGS+=( --repeat_in_block "1,1,1,1" )
TRAIN_ARGS+=( --mixconv_kernel_sizes "[5], [7,11], [9,15], [23]" )
TRAIN_ARGS+=( --residual_connection "0,0,0,0" )
TRAIN_ARGS+=( --first_conv_filters 32 --first_conv_kernel_size 5 --stride 2 )

# Lower-case substrings, one per line, that is_gpu_runtime_failure searches
# for in the lower-cased training log. Any single hit classifies the failure
# as GPU/runtime related. Every line of this list becomes a marker, so do not
# put comments or blank lines inside it.
mapfile -t GPU_FALLBACK_MARKERS <<'EOF'
resourceexhaustederror
resource exhausted
oom
out of memory
cuda_error_out_of_memory
cuda_error_invalid_handle
culaunchkernel
no ptx compilation provider is available
couldn't find a suitable version of ptxas
couldn't find a suitable version of nvlink
failed to allocate
cudnn
cublas
internalerror: cuda
failed call to cuinit
dst tensor is not initialized
failed copying input tensor
_eagerconst
EOF

run_attempt() {
  local label="$1"
  shift
  echo
  echo "================================================================================"
  echo "===== ${label} ====="
  echo "================================================================================"
  echo "→ ${PYTHON_BIN:-python} ${TRAIN_ARGS[*]}"
  echo

  "${PYTHON_BIN:-python}" "${TRAIN_ARGS[@]}" 2>&1 \
    | tr '\r' '\n' \
    | stdbuf -i0 -o0 sed -r -e "/^Validation Batch/d" \
    | tee "${TRAIN_LOG}" \
    | sed -r -e "/^Validation Batch/d" -e "s/^INFO:absl:/   /g"

  return ${PIPESTATUS[0]}
}

#######################################
# Decide whether the most recent training log looks like a GPU/runtime failure.
# Globals:   TRAIN_LOG (read), GPU_FALLBACK_MARKERS (read)
# Returns:   0 when the lower-cased log contains any marker, or contains both
#            "device:gpu:0" and "internalerror"; 1 otherwise (including when
#            the log cannot be read).
#######################################
is_gpu_runtime_failure() {
  local log_text marker
  local -a marker_opts=()

  log_text="$(tr '[:upper:]' '[:lower:]' < "${TRAIN_LOG}" || true)"

  # One fixed-string grep covering every marker at once.
  for marker in "${GPU_FALLBACK_MARKERS[@]}"; do
    marker_opts+=( -e "${marker}" )
  done
  if (( ${#marker_opts[@]} > 0 )) \
      && grep -qF "${marker_opts[@]}" <<<"${log_text}"; then
    return 0
  fi

  # Catch unlisted TF GPU runtime failures (common on newer architectures).
  if grep -qF "device:gpu:0" <<<"${log_text}" \
      && grep -qF "internalerror" <<<"${log_text}"; then
    return 0
  fi

  return 1
}

# --------- ENV (keep compatible; DO NOT add unsupported XLA flags) ----------
# Each variable keeps a non-empty value supplied by the caller; otherwise the
# default below is applied. All of them are exported for the training process.
: "${TF_CPP_MIN_LOG_LEVEL:=2}"
: "${TF_XLA_FLAGS:=--tf_xla_auto_jit=0}"
: "${NVIDIA_TF32_OVERRIDE:=1}"
: "${TF_FORCE_GPU_ALLOW_GROWTH:=true}"
export TF_CPP_MIN_LOG_LEVEL TF_XLA_FLAGS NVIDIA_TF32_OVERRIDE TF_FORCE_GPU_ALLOW_GROWTH

if [[ "${IS_BLACKWELL}" != "true" ]]; then
  : "${TF_GPU_ALLOCATOR:=cuda_malloc_async}"
  export TF_GPU_ALLOCATOR
else
  # TF 2.20 + Blackwell is often unstable with cuda_malloc_async.
  unset TF_GPU_ALLOCATOR
fi

# GPU attempt ladder. TRAINING_DONE becomes "true" as soon as one attempt
# succeeds; the block after this one handles the case where it is still
# "false". Attempts 2 and 3 run only on Blackwell GPUs.
TRAINING_DONE="false"

echo "🏋️ Starting model training and TFLite export (this is the longest stage)…"
if run_attempt "Attempt 1/3: GPU training (default runtime profile)" ; then
  echo "✅ Training complete (GPU path)."
  TRAINING_DONE="true"
else
  echo "⚠️  GPU attempt failed. Checking whether this looks like a GPU/OOM/runtime failure…"

  # A failure without GPU/runtime markers in the log will not be cured by a
  # retry, so stop immediately.
  if ! is_gpu_runtime_failure; then
    echo "❌ Training failed (does not look GPU/OOM/runtime). See: ${TRAIN_LOG}" >&2
    exit 1
  fi

  if [ "${IS_BLACKWELL}" = "true" ]; then
    echo "↪️  Retrying on GPU with Blackwell compatibility profile (BFC allocator + driver PTX fallback)."

    # NOTE(review): on Blackwell the environment prepared before attempt 1
    # already has TF_GPU_ALLOCATOR unset and TF_XLA_FLAGS at this same value,
    # so this retry appears to run with identical settings — confirm intended.
    unset TF_GPU_ALLOCATOR
    export TF_XLA_FLAGS="${ORIG_TF_XLA_FLAGS:---tf_xla_auto_jit=0}"
    if run_attempt "Attempt 2/3: GPU training (Blackwell compatibility profile)" ; then
      echo "✅ Training complete (GPU Blackwell compatibility profile)."
      TRAINING_DONE="true"
    else
      # run_attempt overwrote TRAIN_LOG, so this inspects attempt 2's output.
      if is_gpu_runtime_failure; then
        echo "↪️  Retrying on GPU with minimal runtime knobs (no TF_XLA_FLAGS)."

        unset TF_GPU_ALLOCATOR
        unset TF_XLA_FLAGS
        if run_attempt "Attempt 3/3: GPU training (Blackwell minimal runtime profile)" ; then
          echo "✅ Training complete (GPU Blackwell minimal profile)."
          TRAINING_DONE="true"
        fi
      fi
    fi
  fi
fi

# Reached with TRAINING_DONE != "true" when every GPU attempt that was made
# has failed. TRAIN_LOG holds the output of the last attempt that ran.
if [ "${TRAINING_DONE}" != "true" ]; then
  # Only GPU/runtime-looking failures qualify for the CPU fallback.
  if ! is_gpu_runtime_failure; then
    echo "❌ Training failed (does not look GPU/OOM/runtime). See: ${TRAIN_LOG}" >&2
    exit 1
  fi

  if [ "${ALLOW_CPU_FALLBACK}" = "true" ]; then
    echo "↪️  Detected GPU runtime failure markers. Falling back to CPU (MWW_ALLOW_CPU_FALLBACK=true)."

    # An empty CUDA_VISIBLE_DEVICES hides all GPUs from the training process.
    export CUDA_VISIBLE_DEVICES=""
    unset TF_GPU_ALLOCATOR
    unset TF_XLA_FLAGS

    # CPU attempt should not inherit GPU-specific XLA flags.
    # Restore the value captured at start-up, or drop the variable entirely
    # when none had been provided.
    if [ -n "${ORIG_XLA_FLAGS}" ]; then
      export XLA_FLAGS="${ORIG_XLA_FLAGS}"
    else
      unset XLA_FLAGS
    fi

    if run_attempt "CPU fallback: training (CUDA_VISIBLE_DEVICES='')" ; then
      echo "✅ Training complete (CPU fallback)."
    else
      echo "❌ Training failed on both GPU retries and CPU fallback. See: ${TRAIN_LOG}" >&2
      exit 1
    fi
  else
    echo "❌ GPU training failed after compatibility retries. CPU fallback is disabled." >&2
    echo "   To allow CPU fallback, set MWW_ALLOW_CPU_FALLBACK=true." >&2
    echo "   See: ${TRAIN_LOG}" >&2
    exit 1
  fi
fi

# Location of the quantized streaming model expected after training, and of
# the calibration file written next to it.
source_path="${WORK_DIR}/trained_models/wakeword/tflite_stream_state_internal_quant/stream_state_internal_quant.tflite"
calibration_path="${WORK_DIR}/trained_models/wakeword/tflite_stream_state_internal_quant/detection_calibration.json"

# A successful exit status is not enough: without the model file there is
# nothing to package. Reported on stderr like the script's other fatal errors.
if [ ! -f "${source_path}" ] ; then
    echo "Output model not found! Training didn't complete successfully.  See ${TRAIN_LOG}" >&2
    exit 1
fi

echo "🎯 Calibrating detector settings for on-device use…"
# Calibration is optional: on failure any partial calibration file is removed
# and the script carries on.
calibrate_cmd=(
    "${PYTHON_BIN:-python}" "${PROGDIR}/calibrate_detector.py"
    --training-config "${WORK_DIR}/trained_models/wakeword/training_config.yaml"
    --model "${source_path}"
    --output "${calibration_path}"
)
if ! "${calibrate_cmd[@]}"; then
    echo "⚠️  Detector calibration failed; packaging with default detector settings."
    rm -f "${calibration_path}" || true
else
    echo "✅ Detector calibration complete."
fi

# Best-effort copies of the training summary and log directories; a missing
# source is not treated as an error.
cp "${WORK_DIR}/trained_models/wakeword/model_summary.txt" "${OUTPUT_DIR}/logs/" || true
for log_kind in train validation; do
    cp -a "${WORK_DIR}/trained_models/wakeword/logs/${log_kind}" "${OUTPUT_DIR}/logs/" || true
done

printf '\n   Training complete!\n'
printf '%s\n' "   Full log: ${TRAIN_LOG}"

# Publish the model under the wake-word slug. This copy is not guarded, so a
# failure aborts the script.
tflite_filename="${wake_word_filename}.tflite"
tflite_path="${OUTPUT_DIR}/${tflite_filename}"
cp "${source_path}" "${tflite_path}"

# Write the JSON metadata file that accompanies the .tflite model.
# Inputs reach the embedded Python program through exported environment
# variables, so no shell value is interpolated into Python source (the
# here-document delimiter is quoted).
json_path="${OUTPUT_DIR}/${wake_word_filename}.json"
export WAKE_WORD_TITLE LANGUAGE JSON_PATH="${json_path}" TFLITE_FILENAME="${tflite_filename}" CALIBRATION_PATH="${calibration_path}"
echo "📦 Packaging final model artifacts…"
# The program starts from probability_cutoff=0.97 / sliding_window_size=5 and
# overrides them from the calibration file when that file exists and parses;
# a read or parse failure prints a warning and keeps the defaults.
# NOTE(review): LANGUAGE is also the name of the GNU gettext locale variable;
# confirm where this script expects it to be set.
"${PYTHON_BIN:-python}" - <<'PY'
import json
import os
from pathlib import Path

json_path = Path(os.environ["JSON_PATH"])
calibration_path = Path(os.environ.get("CALIBRATION_PATH", ""))
language = (os.environ.get("LANGUAGE", "en") or "en").strip().lower()
probability_cutoff = 0.97
sliding_window_size = 5

if calibration_path.exists():
    try:
        calibration = json.loads(calibration_path.read_text(encoding="utf-8"))
        probability_cutoff = float(calibration.get("probability_cutoff", probability_cutoff))
        sliding_window_size = int(calibration.get("sliding_window_size", sliding_window_size))
        print(
            f"🎯 Using calibrated detector settings: "
            f"cutoff={probability_cutoff:.2f}, window={sliding_window_size}"
        )
    except Exception as exc:
        print(f"⚠️ Failed to read detector calibration ({exc}); using defaults.")

meta = {
    "type": "micro",
    "wake_word": os.environ["WAKE_WORD_TITLE"],
    "author": "Tater Totterson",
    "website": "https://github.com/TaterTotterson/microWakeWord-Trainer-Nvidia-Docker.git",
    "model": os.environ["TFLITE_FILENAME"],
    "trained_languages": [language],
    "version": 2,
    "micro": {
        "probability_cutoff": round(probability_cutoff, 2),
        "sliding_window_size": sliding_window_size,
        "feature_step_size": 10,
        "tensor_arena_size": 30000,
        "minimum_esphome_version": "2024.7.0",
    },
}
json_path.write_text(json.dumps(meta, indent=4) + "\n", encoding="utf-8")
PY

# Final summary: produced artifacts plus the elapsed time since START_TS.
printf 'Name:     %s\n' "${WAKE_WORD_TITLE}"
printf 'Model:    %s\n' "${tflite_path}"
printf 'Metadata: %s\n\n' "${json_path}"
END_TS="${EPOCHSECONDS}"
print_elapsed_time "${START_TS}" "${END_TS}" "Training completed."
printf '\n'
