mirror of
https://github.com/TaterTotterson/microWakeWord-Trainer-Nvidia-Docker.git
synced 2026-06-12 20:10:19 -06:00
169 lines
5.8 KiB
Bash
Executable File
169 lines
5.8 KiB
Bash
Executable File
#!/bin/bash
set -e

# Locate this script (following symlinks via realpath) so the helper
# programs in the sibling cli/ directory can be found from any cwd.
PROGPATH="$(realpath "$0")"
PROGDIR="$(dirname "${PROGPATH}")"
CLIDIR="${PROGDIR}/cli"

# Long options this script accepts. NOTE(review): presumably read by the
# argument parser in shell.functions, which (judging by later use) fills
# POSITIONAL_ARGS, UNKNOWN_ARGS, HELP and the option variables — confirm.
KNOWN_ARGS=(
  samples
  batch-size
  training-steps
  data-dir
  cleanup-work-dir
  language
)

# shellcheck source=/dev/null
source "${CLIDIR}/shell.functions"

# The first positional argument is the wake word itself.
WAKE_WORD="${POSITIONAL_ARGS[0]}"
# Report any unrecognised options and force the usage text to be shown.
if (( ${#UNKNOWN_ARGS[@]} > 0 )) ; then
  printf 'Unknown argument(s): %s\n' "${UNKNOWN_ARGS[*]}" >&2
  HELP=true
fi
# Print usage to stderr and exit 1 when help was requested (or forced by an
# unknown argument) or when no wake word was given.
# Every long option in KNOWN_ARGS is listed, including --data-dir.
if [ "${HELP}" == "true" ] || [ -z "${WAKE_WORD}" ] ; then
  cat <<EOF >&2
Usage: train_wake_word [ --samples=<samples> ] [ --batch-size=<batch_size> ]
                       [ --training-steps=<steps> ] [ --cleanup-work-dir ]
                       [ --language=<lang> ] [ --data-dir=<dir> ]
                       <wake_word> [ <wake_word_title> ]

Options:
  --samples: The number of samples to generate for the wake word.
      Default: ${DEFAULT_SAMPLES}

  --batch-size: How many samples should be generated at a time. The more
      samples per batch, the more memory is needed.
      Default: ${DEFAULT_BATCH_SIZE}

  --training-steps: Number of training steps. More training steps means better
      detection and false positive rates but also more time to train.
      Default: ${DEFAULT_TRAINING_STEPS}

  --cleanup-work-dir: After successful training, delete the generated samples,
      augmented features and trained models under <data_dir>/work.
      Default: false

  --language: Language for TTS voice selection (e.g. "en", "nl").
      Default: ${DEFAULT_LANGUAGE}

  --data-dir: Directory holding the Python virtual environment (.venv) and
      the work/ directory used for all intermediate output.

  <wake_word> The word to train spelled phonetically.
      Required.

  <wake_word_title> An optional pretty name to save to the json metadata file.
      Default: The wake word with individual words capitalized
      and punctuation removed.

EOF
  exit 1
fi
# Activate the Python virtual environment stored inside the data directory.
# shellcheck source=/dev/null
. "${DATA_DIR}/.venv/bin/activate"

# Run from the data directory and make sure its work/ subdirectory exists.
cd "${DATA_DIR}"
mkdir -p "${DATA_DIR}/work" || true
# Build a display title from a wake word.
# Every non-alphanumeric character becomes a word separator, and each
# resulting word gets its first letter upper-cased. [[:alnum:]] (rather than
# a-zA-Z0-9) keeps accented letters intact in UTF-8 locales.
# Falls back to the raw wake word when no alphanumeric characters remain.
# Arguments: $1 - wake word
# Outputs:   the title on stdout (no trailing newline)
derive_wake_word_title() {
  local wake_word=$1
  local IFS=' '
  local -a words=()
  read -r -a words <<< "${wake_word//[^[:alnum:]]/ }"
  if [ "${#words[@]}" -eq 0 ] ; then
    printf '%s' "${wake_word}"
  else
    printf '%s' "${words[*]^}"
  fi
}

# Set WAKE_WORD_TITLE from the optional second positional argument.
# Otherwise it is derived from WAKE_WORD.
# A WAKE_WORD_TITLE that is already set (e.g. exported by the caller) is kept.
# An empty explicit title falls back to the raw wake word.
# Globals: POSITIONAL_ARGS, WAKE_WORD (read); WAKE_WORD_TITLE (read/written)
resolve_wake_word_title() {
  # Count array elements: ${#POSITIONAL_ARGS} alone would be the string
  # length of the first element, not the number of arguments.
  if [ "${#POSITIONAL_ARGS[@]}" -eq 2 ] ; then
    WAKE_WORD_TITLE="${POSITIONAL_ARGS[1]}"
  fi

  if [ ! -v WAKE_WORD_TITLE ] ; then
    WAKE_WORD_TITLE="$(derive_wake_word_title "${WAKE_WORD}")"
  elif [ -z "${WAKE_WORD_TITLE}" ] ; then
    WAKE_WORD_TITLE="${WAKE_WORD}"
  fi
}

resolve_wake_word_title
# Opening banner: an 80-column rule of '=' followed by the run description.
printf '=%.0s' {1..80}
printf '\n'
printf '%s\n' "===== Running '${WAKE_WORD}(${WAKE_WORD_TITLE})' generation, augmentation and training ====="

# Show CUDA details before the long-running steps begin.
"${CLIDIR}/cudainfo"
echo

START_TS=${EPOCHSECONDS}
# -----------------------------------------------------------------------------
# TensorFlow / XLA environment (known-good, portable)
# -----------------------------------------------------------------------------

# Print the compute capability (e.g. "12.0") that nvidia-smi reports for the
# first listed GPU, with all whitespace removed.
# Prints nothing when nvidia-smi is not on PATH or produces no output.
# Without pipefail, the pipeline's status is that of the final `tr`, so a
# failing nvidia-smi does not surface as a non-zero return.
detect_gpu_compute_capability() {
  command -v nvidia-smi >/dev/null 2>&1 || return 0

  nvidia-smi --query-gpu=compute_cap --format=csv,noheader 2>/dev/null \
    | head -n 1 \
    | tr -d '[:space:]'
}
# Succeed when the given compute capability belongs to an NVIDIA
# Blackwell-generation GPU. These report 10.x (data-center parts such as
# B200), 11.x (Jetson Thor) or 12.x (GeForce RTX 50-series / GB10).
# Arguments: $1 - compute capability string such as "12.0" (may be empty)
# Returns:   0 for Blackwell, 1 otherwise
is_blackwell_compute_capability() {
  case "$1" in
    10.*|11.*|12.*) return 0 ;;
    *) return 1 ;;
  esac
}

# Empty when no NVIDIA GPU/tooling is present, which leaves IS_BLACKWELL false.
GPU_COMPUTE_CAPABILITY="$(detect_gpu_compute_capability)"
IS_BLACKWELL=false
if is_blackwell_compute_capability "${GPU_COMPUTE_CAPABILITY}" ; then
  IS_BLACKWELL=true
fi
# Quiet TensorFlow's C++ logging and let it grow GPU memory on demand.
# A non-empty value already present in the environment wins.
: "${TF_CPP_MIN_LOG_LEVEL:=9}"
: "${TF_FORCE_GPU_ALLOW_GROWTH:=true}"
export TF_CPP_MIN_LOG_LEVEL TF_FORCE_GPU_ALLOW_GROWTH

# TF_XLA_FLAGS: a non-empty caller-supplied value is used as-is.
# Otherwise XLA auto-jit is disabled. Nothing is ever appended to a
# caller's flags, since unsupported flags can cause hard failures.
: "${TF_XLA_FLAGS:=--tf_xla_auto_jit=0}"
export TF_XLA_FLAGS

# IS_BLACKWELL is expected to be the literal "true" or "false" and is run
# as a command.
if ${IS_BLACKWELL} ; then
  # TF 2.20 + Blackwell is often unstable with cuda_malloc_async, so fall
  # back to TensorFlow's default allocator.
  unset TF_GPU_ALLOCATOR
  printf '%s\n' "ℹ️ Blackwell detected (compute capability ${GPU_COMPUTE_CAPABILITY}): using compatibility GPU defaults."
else
  : "${TF_GPU_ALLOCATOR:=cuda_malloc_async}"
  export TF_GPU_ALLOCATOR
fi

# Allow XLA to fall back to driver-side PTX JIT when ptxas/nvlink are missing,
# unless the caller already provided XLA_FLAGS.
if [ -z "${XLA_FLAGS:-}" ] ; then
  export XLA_FLAGS="--xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found"
fi

# Fixed settings; these always overwrite any value from the environment.
export NVIDIA_TF32_OVERRIDE=1 \
  TF_CUDNN_WORKSPACE_LIMIT_IN_MB=512 \
  GLOG_minloglevel=2 \
  GRPC_VERBOSITY=ERROR
# -----------------------------------------------------------------------------
# Generate the raw wake-word samples. All expansions are quoted so option
# values stay single words (matching --language/--data-dir).
"${CLIDIR}/wake_word_sample_generator" \
  --samples="${SAMPLES}" \
  --batch-size="${BATCH_SIZE}" \
  --language="${LANGUAGE}" \
  --data-dir="${DATA_DIR}" \
  "${WAKE_WORD}"

POST_GEN_TS=${EPOCHSECONDS}
GENERATED_DIR="${DATA_DIR}/work/wake_word_samples"
AUGMENTED_DIR="${DATA_DIR}/work/wake_word_samples_augmented"

# Augmentation runs when there is no augmented output yet, or when the first
# generated sample is newer than the augmented testing data.ninja file.
# `-nt` is also true when that file does not exist.
AUGMENT=false
if [ ! -d "${AUGMENTED_DIR}" ] \
  || [ "${GENERATED_DIR}/0.wav" -nt "${AUGMENTED_DIR}/testing/wakeword_mmap/data.ninja" ] ; then
  AUGMENT=true
fi

if [ "${AUGMENT}" = true ] ; then
  # Start from an empty output directory; discard partial output on failure.
  rm -rf "${AUGMENTED_DIR}" || :
  mkdir -p "${AUGMENTED_DIR}" || :
  if ! python -u "${CLIDIR}/wake_word_sample_augmenter" --data-dir="${DATA_DIR}" ; then
    rm -rf "${AUGMENTED_DIR}"
    exit 1
  fi
else
  printf '%s\n\n' "Augmentation not required"
fi

POST_AUGMENT_TS=${EPOCHSECONDS}
# Run the trainer with the wake word and its display title. Option values are
# quoted so each stays a single argument regardless of content.
"${CLIDIR}/wake_word_sample_trainer" \
  --samples="${SAMPLES}" \
  --training-steps="${TRAINING_STEPS}" \
  --data-dir="${DATA_DIR}" \
  "${WAKE_WORD}" "${WAKE_WORD_TITLE}"
# Best-effort removal of this run's intermediate artifacts under
# <data_dir>/work. rm errors are deliberately ignored.
# The :? expansion aborts the script when DATA_DIR is unset or empty, rather
# than letting rm -rf operate on /work/... at the filesystem root.
# Globals: DATA_DIR (read)
cleanup_work_dir() {
  local work_dir
  work_dir="${DATA_DIR:?DATA_DIR is not set}/work"
  rm -rf -- \
    "${work_dir}/trained_models" \
    "${work_dir}/wake_word_samples" \
    "${work_dir}/wake_word_samples_augmented" \
    "${work_dir}/personal_augmented_features" \
    "${work_dir}/reviewed_negative_features" \
    "${work_dir}/last_wake_word" || :
}

# Only an explicit "true" triggers this destructive step. An unset or empty
# CLEANUP_WORK_DIR is treated as false, the documented default; running an
# empty expansion as a command would otherwise "succeed" and clean up.
if [ "${CLEANUP_WORK_DIR:-false}" == "true" ] ; then
  cleanup_work_dir
fi
END_TS=${EPOCHSECONDS}

# Summary layout:
#   - an 80-column '=' rule and the heading
#   - system info and the per-phase timings
#   - a right-aligned 54-column sub-rule and the total
#   - a closing rule
SUMMARY_RULE="$(printf '=%.0s' {1..80})"
printf '%s\n' "${SUMMARY_RULE}"
printf "%44s\n\n" "Training Summary"
"${CLIDIR}/system_summary"
echo
print_elapsed_time --no-separators "${START_TS}" "${POST_GEN_TS}" "Generate ${SAMPLES} samples, ${BATCH_SIZE}/batch"
print_elapsed_time --no-separators "${POST_GEN_TS}" "${POST_AUGMENT_TS}" "Augment ${SAMPLES} samples"
print_elapsed_time --no-separators "${POST_AUGMENT_TS}" "${END_TS}" "${TRAINING_STEPS} training steps"
printf '%80s\n' "${SUMMARY_RULE:0:54}"
print_elapsed_time --no-separators "${START_TS}" "${END_TS}" "Total"
printf '%s\n' "${SUMMARY_RULE}"