#!/bin/bash
set -e

# Locate this script (realpath follows symlinks) so the helper tools in cli/
# are found regardless of the directory the script is invoked from.
PROGPATH="$(realpath "$0")"
PROGDIR="$(dirname "${PROGPATH}")"
CLIDIR="${PROGDIR}/cli"

# Long options accepted on the command line. Defined before shell.functions is
# sourced; presumably that file parses "$@" against this list and fills
# POSITIONAL_ARGS / UNKNOWN_ARGS and the option variables — TODO confirm
# against cli/shell.functions.
KNOWN_ARGS=(
    samples
    batch-size
    training-steps
    data-dir
    cleanup-work-dir
    language
)
source "${CLIDIR}/shell.functions"
WAKE_WORD="${POSITIONAL_ARGS[0]}"

# Any unrecognised option forces the usage text to be shown.
if (( ${#UNKNOWN_ARGS[@]} > 0 )) ; then
    printf 'Unknown argument(s): %s\n' "${UNKNOWN_ARGS[*]}" >&2
    HELP=true
fi

# Show usage on stderr and exit 1 when help was requested, an unknown option
# was given, or the required <wake_word> positional argument is missing.
if [ "${HELP}" == "true" ] || [ -z "${WAKE_WORD}" ] ; then
    cat <<EOF >&2
Usage: train_wake_word [ --samples=<samples> ] [ --batch-size=<batch_size> ]
                       [ --training-steps=<steps> ] [ --cleanup-work-dir ]
                       [ --language=<lang> ] [ --data-dir=<dir> ]
                       <wake_word> [ <wake_word_title> ]

Options:
--samples:            The number of samples to generate for the wake word.
                      Default: ${DEFAULT_SAMPLES}

--batch-size:         How many samples should be generated at a time.  The more
                      samples per batch, the more memory is needed.
                      Default: ${DEFAULT_BATCH_SIZE}

--training-steps:     Number of training steps.  More training steps means better
                      detection and false positive rates but also more time to train.
                      Default: ${DEFAULT_TRAINING_STEPS}

--cleanup-work-dir:   Delete the sample, feature and model directories under
                      <data-dir>/work after successful training.
                      Default: false

--language:           Language for TTS voice selection (e.g. "en", "nl").
                      Default: ${DEFAULT_LANGUAGE}

--data-dir:           Directory holding the Python virtual environment (.venv)
                      and the work/ directory used during training.

<wake_word>           The word to train spelled phonetically.
                      Required.

<wake_word_title>     An optional pretty name to save to the json metadata file.
                      Default: The wake word with individual words capitalized
                               and punctuation removed.

EOF
    exit 1
fi

# Make a relative --data-dir absolute before changing into it. Otherwise, after
# the cd below, "${DATA_DIR}/work" and every later "${DATA_DIR}/..." path
# (including the --data-dir handed to the helper tools) would resolve twice,
# e.g. "data" -> "data/data/work". Absolute and empty values are left as-is.
if [[ -n "${DATA_DIR}" && "${DATA_DIR}" != /* ]] ; then
    DATA_DIR="${PWD}/${DATA_DIR}"
fi

# Activate the project's Python virtual environment kept inside the data dir.
# shellcheck source=/dev/null
source "${DATA_DIR}/.venv/bin/activate"

cd "${DATA_DIR}"
# Best effort: a failure here surfaces later when the work dir is used.
mkdir -p "${DATA_DIR}/work" || :

#######################################
# Print the display title for a wake word.
# Arguments:
#   $1 - the wake word
#   $2 - (optional) a user-supplied title
# Outputs:
#   With $2 given: $2 itself, or $1 when $2 is empty.
#   Without $2: $1 with every non-alphanumeric character treated as a word
#   break and the first letter of each word upper-cased,
#   e.g. "hey_jarvis" -> "Hey Jarvis".
#######################################
wake_word_title() {
    if [ $# -ge 2 ] ; then
        printf '%s\n' "${2:-$1}"
        return 0
    fi

    local -a words
    # Unquoted on purpose: split on the spaces substituted for punctuation.
    # Every glob metacharacter is non-alphanumeric and already replaced, so
    # no pathname expansion can happen.
    # shellcheck disable=SC2206
    words=( ${1//[^a-zA-Z0-9]/ } )
    printf '%s\n' "${words[*]^}"
}

# The optional second positional argument is the title. Count the array's
# elements with ${#POSITIONAL_ARGS[@]}; ${#POSITIONAL_ARGS} would be the
# string length of the first element (the wake word), dropping the title.
if [ "${#POSITIONAL_ARGS[@]}" -eq 2 ] ; then
    WAKE_WORD_TITLE="${POSITIONAL_ARGS[1]}"
fi

# A title set above (or already present in the environment) is kept, with an
# empty one falling back to the raw wake word; otherwise derive one.
if [ -v WAKE_WORD_TITLE ] ; then
    WAKE_WORD_TITLE="$(wake_word_title "${WAKE_WORD}" "${WAKE_WORD_TITLE}")"
else
    WAKE_WORD_TITLE="$(wake_word_title "${WAKE_WORD}")"
fi

# Run banner: an 80-column rule of '=' followed by what is about to run.
printf '=%.0s' {1..80}
printf '\n'
printf '%s\n' "===== Running '${WAKE_WORD}(${WAKE_WORD_TITLE})' generation, augmentation and training ====="
"${CLIDIR}/cudainfo"
printf '\n'
# Start of the timed section; the origin for the summary's elapsed times.
START_TS=${EPOCHSECONDS}

# -----------------------------------------------------------------------------
# TensorFlow / XLA environment (known-good, portable)
# -----------------------------------------------------------------------------
#######################################
# Print the compute capability (e.g. "8.6") of the first GPU listed by
# nvidia-smi, with all whitespace stripped and no trailing newline.
# Prints nothing when nvidia-smi is not installed or reports nothing.
#######################################
detect_gpu_compute_capability() {
    command -v nvidia-smi >/dev/null 2>&1 || return 0

    local first_line
    first_line="$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader 2>/dev/null | head -n 1)"
    printf '%s' "${first_line//[[:space:]]/}"
}

GPU_COMPUTE_CAPABILITY="$(detect_gpu_compute_capability)"

# Compute capability 12.x selects the Blackwell workaround applied below.
# NOTE(review): datacenter Blackwell GPUs report 10.x and are not matched
# here — confirm whether they need the same defaults.
if [[ "${GPU_COMPUTE_CAPABILITY}" == 12.* ]] ; then
    IS_BLACKWELL=true
else
    IS_BLACKWELL=false
fi

# NOTE(review): 9 is above TF's highest documented level (3); presumably meant
# to silence all TF C++ logging — confirm.
export TF_CPP_MIN_LOG_LEVEL="${TF_CPP_MIN_LOG_LEVEL:-9}"
export TF_FORCE_GPU_ALLOW_GROWTH="${TF_FORCE_GPU_ALLOW_GROWTH:-true}"

# Default TF_XLA_FLAGS to only the flags known to work with this build (XLA
# auto-JIT off). A non-empty TF_XLA_FLAGS from the caller's environment
# replaces this default entirely; nothing is appended to it, because mixing
# in unsupported flags can cause hard failures.
export TF_XLA_FLAGS="${TF_XLA_FLAGS:---tf_xla_auto_jit=0}"

# IS_BLACKWELL holds "true" or "false". Compare it as a string rather than
# executing it: an empty value would run as an empty command, which succeeds.
if [ "${IS_BLACKWELL}" == "true" ] ; then
    # TF 2.20 + Blackwell is often unstable with cuda_malloc_async.
    unset TF_GPU_ALLOCATOR
    echo "ℹ️  Blackwell detected (compute capability ${GPU_COMPUTE_CAPABILITY}): using compatibility GPU defaults."
else
    export TF_GPU_ALLOCATOR="${TF_GPU_ALLOCATOR:-cuda_malloc_async}"
fi

# Enable driver-side PTX JIT fallback when ptxas/nvlink are unavailable.
# Only applied when the caller has not supplied XLA_FLAGS of their own.
if [ -z "${XLA_FLAGS:-}" ] ; then
    export XLA_FLAGS="--xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found"
fi

# Unconditional settings (caller values are overridden).
export NVIDIA_TF32_OVERRIDE=1
export TF_CUDNN_WORKSPACE_LIMIT_IN_MB=512
export GLOG_minloglevel=2
export GRPC_VERBOSITY=ERROR
# -----------------------------------------------------------------------------

# Step 1: generate the wake-word samples (--language selects the TTS voice).
# Every expansion is quoted so each option value reaches the generator as a
# single argument, even if it is empty or contains spaces.
"${CLIDIR}/wake_word_sample_generator" \
    --samples="${SAMPLES}" \
    --batch-size="${BATCH_SIZE}" \
    --language="${LANGUAGE}" \
    --data-dir="${DATA_DIR}" "${WAKE_WORD}"

POST_GEN_TS=$EPOCHSECONDS

GENERATED_DIR="${DATA_DIR}/work/wake_word_samples"
AUGMENTED_DIR="${DATA_DIR}/work/wake_word_samples_augmented"

# Step 2: re-augment when no augmented output exists yet, or when the first
# generated sample is newer than the augmented testing data. (`-nt` is also
# true when the sample exists but the right-hand file is missing.)
if [[ ! -d "${AUGMENTED_DIR}" || "${GENERATED_DIR}/0.wav" -nt "${AUGMENTED_DIR}/testing/wakeword_mmap/data.ninja" ]] ; then
    AUGMENT=true
else
    AUGMENT=false
fi

if [[ "${AUGMENT}" == "true" ]] ; then
    # Start from an empty directory so no stale augmented data survives.
    rm -rf "${AUGMENTED_DIR}" || :
    mkdir -p "${AUGMENTED_DIR}" || :
    # On failure drop the partial output, so the next run re-augments.
    if ! python -u "${CLIDIR}/wake_word_sample_augmenter" --data-dir="${DATA_DIR}" ; then
        rm -rf "${AUGMENTED_DIR}"
        exit 1
    fi
else
    echo "Augmentation not required"
    echo
fi
POST_AUGMENT_TS=$EPOCHSECONDS

# Step 3: train. The title is passed as the second positional argument.
# Every expansion is quoted so each option value reaches the trainer as a
# single argument, even if it is empty or contains spaces.
"${CLIDIR}/wake_word_sample_trainer" \
    --samples="${SAMPLES}" \
    --training-steps="${TRAINING_STEPS}" \
    --data-dir="${DATA_DIR}" \
    "${WAKE_WORD}" "${WAKE_WORD_TITLE}"

# Optional removal of the per-run work data. This is only reached after
# training succeeded, since `set -e` aborts the script on any earlier failure.
# CLEANUP_WORK_DIR is compared as a string instead of being executed: an unset
# or empty value would run as an empty command, which exits 0, and would
# delete the data without --cleanup-work-dir being given.
if [ "${CLEANUP_WORK_DIR:-false}" == "true" ] ; then
    # ${DATA_DIR:?} aborts instead of deleting /work/... if DATA_DIR is empty.
    rm -rf -- \
      "${DATA_DIR:?}/work/trained_models" \
      "${DATA_DIR:?}/work/wake_word_samples" \
      "${DATA_DIR:?}/work/wake_word_samples_augmented" \
      "${DATA_DIR:?}/work/personal_augmented_features" \
      "${DATA_DIR:?}/work/last_wake_word" || :
fi

END_TS=$EPOCHSECONDS

# 80-column rule framing the summary. The partial rule above "Total" is its
# first 54 characters right-aligned in 80 columns. Built with printf/tr
# rather than spawning a Python interpreter for each line.
SUMMARY_RULE="$(printf '%80s' '' | tr ' ' '=')"

printf '%s\n' "${SUMMARY_RULE}"
printf "%44s\n\n" "Training Summary"
"${CLIDIR}/system_summary"
echo
print_elapsed_time --no-separators "${START_TS}" "${POST_GEN_TS}" "Generate ${SAMPLES} samples, ${BATCH_SIZE}/batch"
print_elapsed_time --no-separators "${POST_GEN_TS}" "${POST_AUGMENT_TS}" "Augment ${SAMPLES} samples"
print_elapsed_time --no-separators "${POST_AUGMENT_TS}" "${END_TS}" "${TRAINING_STEPS} training steps"
printf '%80s\n' "${SUMMARY_RULE:0:54}"
print_elapsed_time --no-separators "${START_TS}" "${END_TS}" "Total"
printf '%s\n' "${SUMMARY_RULE}"
