Files
microWakeWord-Trainer-Nvidi…/train_wake_word
2026-03-09 19:48:35 -05:00

167 lines
5.7 KiB
Bash
Executable File
Raw Blame History

This file contains invisible Unicode characters
This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/bin/bash
# train_wake_word: runs sample generation, augmentation and training for one
# wake word by calling the helper programs in the "cli" directory that sits
# next to this script.
set -e

# Locate this script so the helper programs can be found regardless of the
# caller's working directory.
PROGPATH="$(realpath "$0")"
PROGDIR="$(dirname "${PROGPATH}")"
CLIDIR="${PROGDIR}/cli"

# Long option names this script accepts. The array is defined before
# shell.functions is sourced; presumably that file parses the command line
# against it and fills POSITIONAL_ARGS / UNKNOWN_ARGS (verify there).
KNOWN_ARGS=(
  samples
  batch-size
  training-steps
  data-dir
  cleanup-work-dir
  language
)
source "${CLIDIR}/shell.functions"

# The first positional argument is the wake word itself.
WAKE_WORD="${POSITIONAL_ARGS[0]}"

# Report any unrecognised options and force the usage text to be shown.
if (( ${#UNKNOWN_ARGS[@]} > 0 )); then
  printf 'Unknown argument(s): %s\n' "${UNKNOWN_ARGS[*]}" >&2
  HELP=true
fi
# Print the usage text on stderr and exit with status 1 when help was
# requested (HELP is "true") or when no wake word was supplied.
if [[ "${HELP}" == "true" || -z "${WAKE_WORD}" ]]; then
  cat >&2 <<EOF
Usage: train_wake_word [ --samples=<samples> ] [ --batch-size=<batch_size> ]
[ --training-steps=<steps> ] [ --cleanup-work-dir ]
[ --language=<lang> ]
<wake_word> [ <wake_word_title> ]
Options:
--samples: The number of samples to generate for the wake word.
Default: ${DEFAULT_SAMPLES}
--batch-size: How many samples should be generated at a time. The more
samples per batch, the more memory is needed.
Default: ${DEFAULT_BATCH_SIZE}
--training-steps: Number of training steps. More training steps means better
detection and false positive rates but also more time to train.
Default: ${DEFAULT_TRAINING_STEPS}
--cleanup-work-dir: Delete the /data/work directory after successful training.
Default: false
--language: Language for TTS voice selection (e.g. "en", "nl").
Default: ${DEFAULT_LANGUAGE}
<wake_word> The word to train spelled phonetically.
Required.
<wake_word_title> An optional pretty name to save to the json metadata file.
Default: The wake word with individual words capitalized
and punctuation removed.
EOF
  exit 1
fi
# Activate the Python virtual environment stored under DATA_DIR, then work
# from DATA_DIR. DATA_DIR is not assigned in this file; presumably it comes
# from shell.functions / the --data-dir option (verify there).
# shellcheck source=/dev/null
. "${DATA_DIR}/.venv/bin/activate"
cd "${DATA_DIR}"
# Best effort: a failure to create the work directory is deliberately ignored.
mkdir -p "${DATA_DIR}/work" || true
#######################################
# Decide the display title for the wake word and store it in WAKE_WORD_TITLE.
#
# The previous code tested ${#POSITIONAL_ARGS}, which is the string length of
# the first array element rather than the number of positional arguments, so
# an explicit title was only honoured for two-character wake words. The
# argument count is used here instead.
#
# Globals:   WAKE_WORD_TITLE (read and written)
# Arguments: $1 - the wake word
#            $2 - optional explicit title
# Returns:   0
#######################################
resolve_wake_word_title() {
  local wake_word=$1
  local -a words=()

  # An explicit title on the command line always wins.
  if (( $# >= 2 )); then
    WAKE_WORD_TITLE=$2
  fi

  if [[ ! -v WAKE_WORD_TITLE ]]; then
    # No title anywhere: treat every non-alphanumeric character as a word
    # separator, capitalise the first letter of each word and join the words
    # with spaces.
    read -r -a words <<< "${wake_word//[^a-zA-Z0-9]/ }"
    WAKE_WORD_TITLE="${words[*]^}"
  elif [[ -z "${WAKE_WORD_TITLE}" ]]; then
    # A title that is set but empty falls back to the wake word verbatim.
    WAKE_WORD_TITLE="${wake_word}"
  fi
}

resolve_wake_word_title "${WAKE_WORD}" "${POSITIONAL_ARGS[@]:1}"
# Opening banner: an 80-column rule of "=" followed by the run description.
printf '%*s\n' 80 '' | tr ' ' '='
printf "===== Running '%s(%s)' generation, augmentation and training =====\n" \
  "${WAKE_WORD}" "${WAKE_WORD_TITLE}"
"${CLIDIR}/cudainfo"
printf '\n'
# Start of the whole run; used for the elapsed-time summary at the end.
START_TS=${EPOCHSECONDS}
# -----------------------------------------------------------------------------
# TensorFlow / XLA environment (known-good, portable)
# -----------------------------------------------------------------------------
#######################################
# Print the compute capability reported by nvidia-smi for the first GPU it
# lists (for example "12.0"), with all whitespace removed and no trailing
# newline. Prints nothing when nvidia-smi is not available or produces no
# output; nvidia-smi's own error output is discarded.
# Arguments: none
# Outputs:   the compute capability on stdout
# Returns:   0
#######################################
detect_gpu_compute_capability() {
  local first_gpu

  command -v nvidia-smi >/dev/null 2>&1 || return 0

  first_gpu="$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader 2>/dev/null | head -n 1)"
  printf '%s' "${first_gpu//[[:space:]]/}"
}
GPU_COMPUTE_CAPABILITY="$(detect_gpu_compute_capability)"

# A compute capability of 12.x is treated as a Blackwell GPU.
if [[ "${GPU_COMPUTE_CAPABILITY}" == 12.* ]]; then
  IS_BLACKWELL=true
else
  IS_BLACKWELL=false
fi

# Defaults that apply only when the caller has not already provided a
# non-empty value in the environment.
# NOTE(review): an earlier comment here said TF_XLA_FLAGS is hard-set and that
# user-supplied flags must not be used because they can cause hard failures,
# but the code does honour a caller-provided TF_XLA_FLAGS. Confirm which of
# the two is intended.
: "${TF_CPP_MIN_LOG_LEVEL:=9}"
: "${TF_FORCE_GPU_ALLOW_GROWTH:=true}"
: "${TF_XLA_FLAGS:=--tf_xla_auto_jit=0}"
export TF_CPP_MIN_LOG_LEVEL TF_FORCE_GPU_ALLOW_GROWTH TF_XLA_FLAGS

if [[ "${IS_BLACKWELL}" == "true" ]]; then
  # TF 2.20 + Blackwell is often unstable with cuda_malloc_async, so no
  # allocator is selected at all on these GPUs.
  unset TF_GPU_ALLOCATOR
  # Keep a caller-provided XLA_FLAGS; otherwise install the fallback flag.
  if [[ -z "${XLA_FLAGS:-}" ]]; then
    export XLA_FLAGS="--xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found"
  fi
  echo " Blackwell detected (compute capability ${GPU_COMPUTE_CAPABILITY}): using compatibility GPU defaults."
else
  # Other GPUs: default to cuda_malloc_async unless the caller chose an
  # allocator, and always drop any XLA_FLAGS from the environment.
  export TF_GPU_ALLOCATOR="${TF_GPU_ALLOCATOR:-cuda_malloc_async}"
  unset XLA_FLAGS
fi

# Fixed settings, exported unconditionally (caller values are overwritten).
export NVIDIA_TF32_OVERRIDE=1
export TF_CUDNN_WORKSPACE_LIMIT_IN_MB=512
export GLOG_minloglevel=2
export GRPC_VERBOSITY=ERROR
# -----------------------------------------------------------------------------
# Step 1: generate the wake word samples with the sample generator helper.
# Every expansion is quoted so that each option reaches the helper as exactly
# one argument, whatever the variable contains.
"${CLIDIR}/wake_word_sample_generator" \
  --samples="${SAMPLES}" \
  --batch-size="${BATCH_SIZE}" \
  --language="${LANGUAGE}" \
  --data-dir="${DATA_DIR}" \
  "${WAKE_WORD}"
# End of generation; used for the elapsed-time summary at the end.
POST_GEN_TS=$EPOCHSECONDS
# Step 2: augment the generated samples, but only when needed.
AUGMENT=false
GENERATED_DIR="${DATA_DIR}/work/wake_word_samples"
AUGMENTED_DIR="${DATA_DIR}/work/wake_word_samples_augmented"

# Augmentation is needed when the augmented directory does not exist, or when
# the first generated sample (0.wav) is newer than the augmented
# testing/wakeword_mmap/data.ninja file (-nt is also true when that file is
# missing but 0.wav exists).
if [[ ! -d "${AUGMENTED_DIR}" ]]; then
  AUGMENT=true
elif [[ "${GENERATED_DIR}/0.wav" -nt "${AUGMENTED_DIR}/testing/wakeword_mmap/data.ninja" ]]; then
  AUGMENT=true
fi

if [[ "${AUGMENT}" == "true" ]]; then
  # Start from an empty augmented directory (both steps are best effort).
  rm -rf "${AUGMENTED_DIR}" || true
  mkdir -p "${AUGMENTED_DIR}" || true
  if ! python -u "${CLIDIR}/wake_word_sample_augmenter" --data-dir="${DATA_DIR}"; then
    # Do not leave a partial result behind that a later run would accept.
    rm -rf "${AUGMENTED_DIR}"
    exit 1
  fi
else
  echo "Augmentation not required"
  echo
fi
# End of augmentation / start of training; used for the elapsed-time summary.
POST_AUGMENT_TS=$EPOCHSECONDS
# Step 3: train the model with the trainer helper. Every expansion is quoted
# so that each option reaches the helper as exactly one argument.
"${CLIDIR}/wake_word_sample_trainer" \
  --samples="${SAMPLES}" \
  --training-steps="${TRAINING_STEPS}" \
  --data-dir="${DATA_DIR}" \
  "${WAKE_WORD}" "${WAKE_WORD_TITLE}"
#######################################
# Delete the intermediate training artefacts below ${DATA_DIR}/work, but only
# when CLEANUP_WORK_DIR is exactly "true". The work directory itself and any
# other entries in it are left in place.
#
# The flag is compared as a string instead of being executed as a command:
# an empty or unset value used to evaluate as success and trigger the
# deletion, and any other value was run as a program.
#
# Globals:   CLEANUP_WORK_DIR (read), DATA_DIR (read; must be non-empty when
#            cleanup is requested, otherwise the shell aborts)
# Arguments: none
# Returns:   0 (removal failures are deliberately ignored)
#######################################
cleanup_work_dir() {
  [[ "${CLEANUP_WORK_DIR:-false}" == "true" ]] || return 0

  local work_dir="${DATA_DIR:?DATA_DIR must be set for cleanup}/work"
  rm -rf -- \
    "${work_dir}/trained_models" \
    "${work_dir}/wake_word_samples" \
    "${work_dir}/wake_word_samples_augmented" \
    "${work_dir}/personal_augmented_features" \
    "${work_dir}/last_wake_word" || :
}

cleanup_work_dir
# End of the whole run; used for the elapsed-time summary.
END_TS=$EPOCHSECONDS
#######################################
# Print a horizontal rule made of "=" characters followed by a newline.
# Uses shell builtins only, so no Python interpreter is started just to draw
# a line.
# Arguments: $1 - number of "=" characters
#            $2 - optional field width; the rule is right-aligned in it
#                 (default: $1, i.e. no padding)
# Outputs:   the rule on stdout
#######################################
summary_rule() {
  local length=$1
  local width=${2:-$1}
  local bar

  # Build <length> spaces, then turn every space into "=".
  printf -v bar '%*s' "${length}" ''
  printf '%*s\n' "${width}" "${bar// /=}"
}

# Final report: system information followed by the time spent in each phase.
summary_rule 80
printf "%44s\n\n" "Training Summary"
"${CLIDIR}/system_summary"
echo
# print_elapsed_time is not defined in this file; presumably it comes from
# shell.functions (verify there).
print_elapsed_time --no-separators "${START_TS}" "${POST_GEN_TS}" "Generate ${SAMPLES} samples, ${BATCH_SIZE}/batch"
print_elapsed_time --no-separators "${POST_GEN_TS}" "${POST_AUGMENT_TS}" "Augment ${SAMPLES} samples"
print_elapsed_time --no-separators "${POST_AUGMENT_TS}" "${END_TS}" "${TRAINING_STEPS} training steps"
# Shorter rule (54 characters, right-aligned in 80 columns) above the total.
summary_rule 54 80
print_elapsed_time --no-separators "${START_TS}" "${END_TS}" "Total"
summary_rule 80