#!/bin/bash
# Train a microWakeWord "mixednet" model for one wake word and copy the
# resulting streaming quantized .tflite model, a JSON metadata file and the
# training logs into a timestamped directory under ${DATA_DIR}/output.
set -e

# Resolve this script's real location so shell.functions is found next to it
# regardless of the caller's working directory or symlinks.
PROGPATH=$(realpath "$0")
PROGDIR=$(dirname "${PROGPATH}")

# Long options accepted by this script.
# NOTE(review): shell.functions appears to parse the command line against
# KNOWN_ARGS and to provide POSITIONAL_ARGS, UNKNOWN_ARGS, HELP and the option
# variables (TRAINING_STEPS, SAMPLES, DATA_DIR, DEFAULT_TRAINING_STEPS). That is
# inferred from their use below; confirm it against shell.functions.
KNOWN_ARGS=( training-steps samples data-dir )
source "${PROGDIR}/shell.functions"
WAKE_WORD="${POSITIONAL_ARGS[0]}"

# Any option not listed in KNOWN_ARGS is reported and forces the usage text.
if [ ${#UNKNOWN_ARGS[@]} -gt 0 ] ; then
    echo "Unknown argument(s): ${UNKNOWN_ARGS[*]}" >&2
    HELP=true
fi

# Show usage and stop on -h/--help, on bad arguments, or when no wake word
# was given.
if [ "${HELP}" == "true" ] || [ -z "${WAKE_WORD}" ] ; then
    cat <<EOF >&2
Usage: $0 [ --data-dir=<dir> ] [ --samples=<samples> ]
          [ --training-steps=<steps> ] <wake_word> [ <wake_word_title> ]

       $0 -h/--help

--data-dir:         Base directory holding the training datasets, the work
                    area, the Python .venv and the output directory.

--samples:          The number of wake word samples that were generated.
                    Used only to build the output directory name.

--training-steps:   Number of training steps.
                    Default: ${DEFAULT_TRAINING_STEPS}

<wake_word>:        The word to train spelled phonetically.
                    Required.

<wake_word_title>:  A pretty name to save to the json metadata file.
                    Default: The wake word with individual words capitalized.
                    An empty title uses the wake word exactly as typed.

EOF
    exit 1
fi

# Every path below is built under DATA_DIR. An empty value would silently
# point them at "/" (e.g. /work, /.venv), so refuse to continue.
if [ -z "${DATA_DIR}" ] ; then
    echo "ERROR: No data directory set. Use --data-dir=<dir>." >&2
    exit 1
fi

WORK_DIR="${DATA_DIR}/work"
TRAINING_DS="${DATA_DIR}/training_datasets"

# default_wake_word_title WORD
# Prints WORD with every run of non-alphanumeric characters treated as a word
# break and each resulting word capitalized, e.g. "hey_jarvis" -> "Hey Jarvis".
default_wake_word_title() {
    local -a words
    read -r -a words <<< "${1//[^a-zA-Z0-9]/ }"
    printf '%s\n' "${words[*]^}"
}

# The optional second positional argument is the title. The element count
# must be taken with ${#arr[@]}; ${#arr} is the string length of element 0.
# Any positional arguments after the title are ignored.
if [ ${#POSITIONAL_ARGS[@]} -ge 2 ] ; then
    WAKE_WORD_TITLE="${POSITIONAL_ARGS[1]}"
fi

if [ ! -v WAKE_WORD_TITLE ] ; then
    WAKE_WORD_TITLE="$(default_wake_word_title "${WAKE_WORD}")"
elif [ -z "$WAKE_WORD_TITLE" ] ; then
    # An explicitly empty title means "use the wake word as typed".
    WAKE_WORD_TITLE="$WAKE_WORD"
fi

# Activate the Python virtual environment under DATA_DIR, so the `python`
# used for training below resolves to the venv's interpreter.
# shellcheck source=/dev/null
source "${DATA_DIR}/.venv/bin/activate"

# check_directories DIR...
# Verifies that every DIR exists and is a directory. Each missing one is
# reported on stderr, so the user sees all of them at once. The script then
# exits with status 1 if any were missing.
check_directories() {
    local d missing=0
    for d in "$@" ; do
        if [ ! -d "$d" ] ; then
            echo "ERROR: Directory $d not found" >&2
            missing=1
        fi
    done
    [ "$missing" -eq 0 ] || exit 1
}

# Fail fast if the augmented positive samples or any negative feature set is
# missing. The paths are quoted so a DATA_DIR containing spaces still works
# (brace expansion still applies to the unquoted {...} suffix).
check_directories "${WORK_DIR}/wake_word_samples_augmented" \
    "${TRAINING_DS}/negative_datasets/"{speech,dinner_party,no_speech,dinner_party_eval}

cd "${WORK_DIR}"

echo "===== Starting ${TRAINING_STEPS} training steps ====="

START_TS=$EPOCHSECONDS

# The config file is written into this directory next. Let a failure abort
# here (set -e) with mkdir's own error, instead of surfacing later as a
# confusing redirection failure.
mkdir -p "${WORK_DIR}/trained_models"
# Write the microWakeWord training configuration. The heredoc is deliberately
# unquoted so ${WORK_DIR}, ${TRAINING_DS} and ${TRAINING_STEPS} are expanded.
# Features: the augmented wake word samples are the only positive set
# (truth: true); the four negative_datasets directories are negatives, with
# dinner_party_eval given sampling_weight 0.0 and truncation_strategy "split".
cat <<EOF >"${WORK_DIR}/trained_models/training_parameters.yaml"
batch_size: 16
clip_duration_ms: 1500
eval_step_interval: 500
features:
- features_dir: ${WORK_DIR}/wake_word_samples_augmented
  penalty_weight: 1.0
  sampling_weight: 2.0
  truncation_strategy: truncate_start
  truth: true
  type: mmap
- features_dir: ${TRAINING_DS}/negative_datasets/speech
  penalty_weight: 1.0
  sampling_weight: 12.0
  truncation_strategy: random
  truth: false
  type: mmap
- features_dir: ${TRAINING_DS}/negative_datasets/dinner_party
  penalty_weight: 1.0
  sampling_weight: 12.0
  truncation_strategy: random
  truth: false
  type: mmap
- features_dir: ${TRAINING_DS}/negative_datasets/no_speech
  penalty_weight: 1.0
  sampling_weight: 5.0
  truncation_strategy: random
  truth: false
  type: mmap
- features_dir: ${TRAINING_DS}/negative_datasets/dinner_party_eval
  penalty_weight: 1.0
  sampling_weight: 0.0
  truncation_strategy: split
  truth: false
  type: mmap
freq_mask_count:
- 0
freq_mask_max_size:
- 0
learning_rates:
- 0.001
maximization_metric: average_viable_recall
minimization_metric: null
negative_class_weight:
- 20
positive_class_weight:
- 1
target_minimization: 0.9
time_mask_count:
- 0
time_mask_max_size:
- 0
train_dir: ${WORK_DIR}/trained_models/wakeword
training_steps:
- ${TRAINING_STEPS}
window_step_ms: 10

EOF

echo "   Wrote training_parameters.yaml"
# Clear the previous run's train_dir (see train_dir above). Any model found
# there after training therefore comes from this run, and no stale checkpoint
# is available to --restore_checkpoint.
rm -rf "${WORK_DIR}/trained_models/wakeword"

# Environment for the TensorFlow training process: silence TF/glog/gRPC
# logging, grow GPU memory on demand with the async CUDA allocator, disable
# XLA auto-jit, allow TF32 and cap the cuDNN workspace.
export \
    TF_CPP_MIN_LOG_LEVEL=9 \
    TF_FORCE_GPU_ALLOW_GROWTH=true \
    TF_GPU_ALLOCATOR=cuda_malloc_async \
    TF_XLA_FLAGS="--tf_xla_auto_jit=0" \
    NVIDIA_TF32_OVERRIDE=1 \
    TF_CUDNN_WORKSPACE_LIMIT_IN_MB=512 \
    GLOG_minloglevel=9 \
    GRPC_VERBOSITY=ERROR

echo "   Loading Tensorflow"

# File-system-safe form of the wake word: each listed shell/path-special
# character becomes "_". It is used for the output directory and file names.
wake_word_filename="${WAKE_WORD//[ \`~\!\$&*\(\)\{\}\[\]\|\;\'\"<>.?\/]/_}"
# One output directory per run: timestamp, wake word, sample count, steps.
OUTPUT_DIR="${DATA_DIR}/output/$(date '+%Y-%m-%d-%H-%M-%S')-${wake_word_filename}-${SAMPLES}-${TRAINING_STEPS}"
mkdir -p "${OUTPUT_DIR}/logs" || true

# Run microwakeword.model_train_eval through a small Python driver read from
# stdin (`python -`). The options after `-` stay in sys.argv, and runpy with
# alter_sys=True hands them to the module as its own command line.
# The heredoc delimiter is quoted so the embedded Python reaches the
# interpreter verbatim, with no $, backtick or backslash processing by the
# shell.
# Output handling: \r progress updates become newlines, "Validation Batch"
# lines are dropped, everything else is saved to logs/training.log, and only
# INFO:absl: lines (minus some noise) are echoed to the console.
# pipefail is not set here, so Python's exit status is not what this pipeline
# returns; success is judged by the model-file check that follows.
python - \
  --training_config="${WORK_DIR}/trained_models/training_parameters.yaml" \
  --train 1 \
  --restore_checkpoint 1 \
  --test_tf_nonstreaming 0 \
  --test_tflite_nonstreaming 0 \
  --test_tflite_nonstreaming_quantized 0 \
  --test_tflite_streaming 0 \
  --test_tflite_streaming_quantized 1 \
  --use_weights "best_weights" \
  mixednet \
  --pointwise_filters "64,64,64,64" \
  --repeat_in_block "1,1,1,1" \
  --mixconv_kernel_sizes "[5], [7,11], [9,15], [23]" \
  --residual_connection "0,0,0,0" \
  --first_conv_filters 32 \
  --first_conv_kernel_size 5 \
  --stride 2 <<'EOF' 2>&1 | tr '\r' '\n' | stdbuf -i0 -o0 sed -r -e "/^Validation Batch/d" |\
        tee "${OUTPUT_DIR}/logs/training.log" | sed -r -e '/^INFO:absl:/!d' \
            -r -e "/None|Sharding|unsupported characters|AUC|fingerprint/d" \
            -r -e 's/INFO:absl:/   /g' \
            -r -e "s/, (recall =|estimated false|average viable recall)/,\n      \1/g"

import sys, os, gc
import runpy
import yaml
print("   Loading Tensorflow")
import tensorflow as tf

print("   GPU memory config")
# Per-device memory growth (belt + suspenders)
for g in tf.config.list_physical_devices("GPU"):
    try:
        tf.config.experimental.set_memory_growth(g, True)
    except Exception:
        pass
print(f"INFO:absl:GPUs: {tf.config.list_physical_devices('GPU')}")
gc.collect()

print()
try:
    runpy.run_module("microwakeword.model_train_eval", run_name="__main__", alter_sys=True)
except Exception as e:
    print(e, file=sys.stderr)
    sys.exit(1)
EOF

# Streaming, internally quantized TFLite model inside the training train_dir.
# NOTE(review): presumably produced because --test_tflite_streaming_quantized
# is 1 above; confirm against microwakeword.
source_path="${WORK_DIR}/trained_models/wakeword/tflite_stream_state_internal_quant/stream_state_internal_quant.tflite"

# The training pipeline's exit status is not checked, so the model file's
# presence is the success test. Point the user at the log that tee actually
# wrote (under OUTPUT_DIR), and report the failure on stderr.
if [ ! -f "${source_path}" ] ; then
    echo "Output model not found! Training didn't complete successfully.  See ${OUTPUT_DIR}/logs/training.log" >&2
    exit 1
fi

# Keep the model summary and the train/validation logs next to the model.
cp "${WORK_DIR}/trained_models/wakeword/model_summary.txt" "${OUTPUT_DIR}/logs/"
cp -a "${WORK_DIR}/trained_models/wakeword/logs/train" "${OUTPUT_DIR}/logs/"
cp -a "${WORK_DIR}/trained_models/wakeword/logs/validation" "${OUTPUT_DIR}/logs/"

echo -e "\n   Training complete!"
echo "   Full log: ${OUTPUT_DIR}/logs/training.log"

# The published model is named after the sanitized wake word.
tflite_filename="${wake_word_filename}.tflite"
tflite_path="${OUTPUT_DIR}/${tflite_filename}"

cp "${source_path}" "${tflite_path}"

# json_escape STRING
# Prints STRING escaped for use inside a JSON string literal: backslash,
# double quote, newline, carriage return and tab are backslash-escaped.
json_escape() {
    local s=$1 bs='\' dq='"'
    s=${s//"$bs"/"$bs$bs"}
    s=${s//"$dq"/"$bs$dq"}
    s=${s//$'\n'/"${bs}n"}
    s=${s//$'\r'/"${bs}r"}
    s=${s//$'\t'/"${bs}t"}
    printf '%s' "$s"
}

# --- Write JSON metadata file with matching model name ---
# The title (user supplied) and the model file name can contain characters
# that are special in JSON strings, so both are escaped before embedding.
json_path="${OUTPUT_DIR}/${wake_word_filename}.json"
cat <<-EOF > "${json_path}"
{
    "type": "micro",
    "wake_word": "$(json_escape "${WAKE_WORD_TITLE}")",
    "author": "Tater Totterson",
    "website": "https://github.com/TaterTotterson/microWakeWord-Trainer-Nvidia-Docker.git",
    "model": "$(json_escape "${tflite_filename}")",
    "trained_languages": ["en"],
    "version": 2,
    "micro": {
        "probability_cutoff": 0.97,
        "sliding_window_size": 5,
        "feature_step_size": 10,
        "tensor_arena_size": 30000,
        "minimum_esphome_version": "2024.7.0"
    }
}
EOF

# Final report: where the model and metadata landed, followed by timing.
printf '%s\n' \
    "Name:     ${WAKE_WORD_TITLE}" \
    "Model:    ${tflite_path}" \
    "Metadata: ${json_path}" \
    ""
END_TS=$EPOCHSECONDS
# print_elapsed_time is not defined in this script; presumably it comes from
# shell.functions.
print_elapsed_time "${START_TS}" "${END_TS}" "Training completed."
printf '\n'

