#!/bin/bash
# Drives wake word sample generation, augmentation and training by invoking
# the helper programs that live in the same directory as this script.
set -e

# Absolute path of this script and the directory that contains it; the helper
# programs are resolved relative to PROGDIR.
PROGPATH="$(realpath "$0")"
PROGDIR="$(dirname "${PROGPATH}")"

# Long option names accepted by this script.
# NOTE(review): presumably consumed by shell.functions, which appears to parse
# the command line and populate POSITIONAL_ARGS, UNKNOWN_ARGS, HELP and the
# option variables used further down -- confirm against that file.
KNOWN_ARGS=(
    samples
    batch-size
    training-steps
    data-dir
    cleanup-work-dir
)
source "${PROGDIR}/shell.functions"

# The first positional argument is the wake word.
WAKE_WORD="${POSITIONAL_ARGS[0]}"

# Report options that were not recognised and flag that help is wanted.
if (( ${#UNKNOWN_ARGS[@]} > 0 )) ; then
    printf '%s\n' "Unknown argument(s): ${UNKNOWN_ARGS[*]}" >&2
    HELP=true
fi

# Print the usage text on stderr and exit with status 1 when HELP is "true"
# or when no wake word was supplied.
# NOTE(review): the DEFAULT_* values interpolated below are not defined in this
# file; presumably shell.functions provides them -- confirm.
if [[ "${HELP}" == "true" || -z "${WAKE_WORD}" ]] ; then
    cat >&2 <<EOF
Usage: train_wake_word [ --samples=<samples> ] [ --batch-size=<batch_size> ]
                       [ --training-steps=<steps> ] [ --cleanup-work-dir ]
                       <wake_word> [ <wake_word_title> ]

Options:
--samples:            The number of samples to generate for the wake word.
                      Default: ${DEFAULT_SAMPLES}

--batch-size:         How many samples should be generated at a time.  The more
                      samples per batch, the more memory is needed.
                      Default: ${DEFAULT_BATCH_SIZE}

--training-steps:     Number of training steps.  More training steps means better
                      detection and false positive rates but also more time to train.
                      Default: ${DEFAULT_TRAINING_STEPS}

--cleanup-work-dir:   Delete the /data/work directory after successful training.
                      Default: false

<wake_word>           The word to train spelled phonetically.
                      Required.

<wake_word_title>     An optional pretty name to save to the json metadata file.
                      Default: The wake word with individual words capitalized
                               and punctuation removed.

EOF
    exit 1
fi

# Activate the Python virtual environment stored under the data directory.
# shellcheck source=/dev/null
. "${DATA_DIR}/.venv/bin/activate"

# Work from the data directory and make sure its work/ subdirectory exists;
# a failing mkdir is deliberately ignored.
cd "${DATA_DIR}"
mkdir -p "${DATA_DIR}/work" || true

#######################################
# Decide the value of WAKE_WORD_TITLE.
#
# Precedence:
#   1. The second positional argument, when exactly two were given.
#   2. A WAKE_WORD_TITLE that is already set (left untouched).
#   3. A title derived from WAKE_WORD: every non-alphanumeric character is
#      treated as a word separator and each word gets its first letter
#      upper-cased.
# A title that is set but empty falls back to WAKE_WORD verbatim.
#
# Globals: POSITIONAL_ARGS (read), WAKE_WORD (read), WAKE_WORD_TITLE (written)
# Arguments: none
#######################################
resolve_wake_word_title() {
    # Count the array ELEMENTS.  "${#POSITIONAL_ARGS}" without [@] is the
    # string length of element 0, i.e. the length of the wake word itself.
    if [ "${#POSITIONAL_ARGS[@]}" -eq 2 ] ; then
        WAKE_WORD_TITLE="${POSITIONAL_ARGS[1]}"
    fi

    if [ ! -v WAKE_WORD_TITLE ] ; then
        local -a words=()
        # Split on whitespace after mapping every separator to a space.
        read -r -a words <<< "${WAKE_WORD//[^a-zA-Z0-9]/ }"
        # "${words[*]^}" upper-cases the first character of each word and
        # joins the words with single spaces.
        WAKE_WORD_TITLE="${words[*]^}"
    elif [ -z "${WAKE_WORD_TITLE}" ] ; then
        WAKE_WORD_TITLE="${WAKE_WORD}"
    fi
}

resolve_wake_word_title

# Banner: an 80-column rule of '=' characters, then the run description.
printf '=%.0s' {1..80}
printf '\n'
printf '%s\n' "===== Running '${WAKE_WORD}(${WAKE_WORD_TITLE})' generation, augmentation and training ====="
# Helper program located next to this script.
# NOTE(review): judging by its name it reports CUDA/GPU details -- confirm.
"${PROGDIR}/cudainfo"
printf '\n'
# Timestamp taken just before sample generation starts; the summary at the
# end of the script reports elapsed times relative to it.
START_TS=$EPOCHSECONDS

# Environment exported to the child processes started below.
# NOTE(review): these are TensorFlow / NVIDIA / glog / gRPC settings whose
# exact semantics are defined by those libraries -- consult their
# documentation before changing any value.
export TF_CPP_MIN_LOG_LEVEL=9
export TF_FORCE_GPU_ALLOW_GROWTH=true
export TF_GPU_ALLOCATOR=cuda_malloc_async
export TF_XLA_FLAGS="--tf_xla_auto_jit=0"
export NVIDIA_TF32_OVERRIDE=1
export TF_CUDNN_WORKSPACE_LIMIT_IN_MB=512
export GLOG_minloglevel=2
export GRPC_VERBOSITY=ERROR


# Run the sample generator.  Every expansion is quoted so each option reaches
# the generator as exactly one argument, with no word splitting or globbing.
"${PROGDIR}/wake_word_sample_generator" \
    --samples="${SAMPLES}" \
    --batch-size="${BATCH_SIZE}" \
    --data-dir="${DATA_DIR}" "${WAKE_WORD}"

# Timestamp taken once generation has finished.
POST_GEN_TS=$EPOCHSECONDS

# Wake word with spaces turned into underscores and dots removed.
# NOTE(review): nothing in the visible script reads `ww` afterwards -- confirm
# whether it is still needed.
ww=${WAKE_WORD// /_}
ww=${ww//./}

GENERATED_DIR="${DATA_DIR}/work/wake_word_samples"
AUGMENTED_DIR="${DATA_DIR}/work/wake_word_samples_augmented"

# Augmentation is needed when the augmented directory does not exist, or when
# the first generated sample is newer than the augmented data.ninja file
# (-nt is also true when that file is missing).
AUGMENT=false
if [[ ! -d "${AUGMENTED_DIR}" \
      || "${GENERATED_DIR}/0.wav" -nt "${AUGMENTED_DIR}/testing/wakeword_mmap/data.ninja" ]] ; then
    AUGMENT=true
fi

if [[ "${AUGMENT}" == "true" ]] ; then
    # Start from an empty directory; failures of rm/mkdir are ignored here.
    rm -rf "${AUGMENTED_DIR}" || true
    mkdir -p "${AUGMENTED_DIR}" || true
    if ! "${PROGDIR}/wake_word_sample_augmenter" --data-dir="${DATA_DIR}" ; then
        # Remove the partial output so the next run augments again.
        rm -rf "${AUGMENTED_DIR}"
        exit 1
    fi
else
    printf '%s\n\n' "Augmentation not required"
fi

# Timestamp taken after the augmentation step, just before training starts.
POST_AUGMENT_TS=$EPOCHSECONDS

# cleanup_requested: succeed only when CLEANUP_WORK_DIR is exactly "true".
# The value is compared as a string instead of being executed as a command:
# an unset or empty variable therefore means "no cleanup" (executing an empty
# expansion would succeed and trigger the deletion), and an arbitrary value is
# never run as a command.
cleanup_requested() {
    [ "${CLEANUP_WORK_DIR:-false}" == "true" ]
}

# Run the trainer.  Expansions are quoted so each option is passed as exactly
# one argument.  Under `set -e` a trainer failure aborts the script before the
# cleanup below.
"${PROGDIR}/wake_word_sample_trainer" \
    --samples="${SAMPLES}" \
    --training-steps="${TRAINING_STEPS}" \
    --data-dir="${DATA_DIR}" \
    "${WAKE_WORD}" "${WAKE_WORD_TITLE}"

# Optionally remove the intermediate artifacts; a failing rm is ignored.
if cleanup_requested ; then
    rm -rf "${DATA_DIR}/work/trained_models" "${DATA_DIR}/work/wake_word_samples" \
    "${DATA_DIR}/work/wake_word_samples_augmented" "${DATA_DIR}/work/last_wake_word" || :
fi
# Timestamp taken once training (and optional cleanup) has finished.
END_TS=$EPOCHSECONDS

# summary_rule <width> [<total_width>]
# Print <width> '=' characters right-aligned in a field of <total_width>
# columns (default: <width>), followed by a newline.  Pure shell, so drawing
# a separator does not require launching a Python interpreter.
summary_rule() {
    local width=$1
    local total=${2:-$1}
    local bar
    # Build <width> spaces, then turn each one into '='.
    printf -v bar '%*s' "${width}" ''
    printf '%*s\n' "${total}" "${bar// /=}"
}

# Final report: system information followed by the time spent in each phase.
summary_rule 80
printf "%44s\n\n" "Training Summary"
"${PROGDIR}/system_summary"
echo
print_elapsed_time --no-separators "${START_TS}" "${POST_GEN_TS}" "Generate ${SAMPLES} samples, ${BATCH_SIZE}/batch"
print_elapsed_time --no-separators "${POST_GEN_TS}" "${POST_AUGMENT_TS}" "Augment ${SAMPLES} samples"
print_elapsed_time --no-separators "${POST_AUGMENT_TS}" "${END_TS}" "${TRAINING_STEPS} training steps"
# Short rule (54 '=' right-aligned in 80 columns) above the total line.
summary_rule 54 80
print_elapsed_time --no-separators "${START_TS}" "${END_TS}" "Total"
summary_rule 80
