mirror of
https://github.com/TaterTotterson/microWakeWord-Trainer-Nvidia-Docker.git
synced 2026-06-12 20:10:19 -06:00
Train from the command line
The files in the `cli` directory allow you to train wake words from the command line without needing to use the Jupyter notebook or a web browser. Basically, the logic from the notebook has been placed in separate shell scripts and python files wrapped by 3 high-level scripts that do the following: * setup_python_venv: Creates a Python virtual environment with all the packages needed to train. The venv is created in the container's /data directory and is therefore stored on the host, not in the container's root docker volume. * setup_training_datasets: Downloads, extracts and converts the MIT RIR, FMA, Audioset and Negative training reference datasets. Also stored in /data. * train_wake_word: Generates the wake word samples, augments them with the audio from the training datasets, and finally runs the microwakeword training. The resulting model tflite and json files are placed in the /data/output directory. See the README.md file for much more information.
This commit is contained in:
241
cli/wake_word_sample_trainer
Executable file
241
cli/wake_word_sample_trainer
Executable file
@@ -0,0 +1,241 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
PROGPATH=$(realpath "$0")
|
||||
PROGDIR=$(dirname "${PROGPATH}")
|
||||
|
||||
KNOWN_ARGS=( training-steps samples data-dir )
|
||||
source "${PROGDIR}/shell.functions"
|
||||
WAKE_WORD="${POSITIONAL_ARGS[0]}"
|
||||
|
||||
if [ ${#UNKNOWN_ARGS[@]} -gt 0 ] ; then
|
||||
echo "Unknown argument(s): ${UNKNOWN_ARGS[*]}" >&2
|
||||
HELP=true
|
||||
fi
|
||||
|
||||
if [ "${HELP}" == "true" ] || [ -z "${WAKE_WORD}" ] ; then
|
||||
cat <<EOF >&2
|
||||
Usage: $0 [ --samples=<samples> ] [ --training-steps=<steps> ]
|
||||
<wake_word> [ <wake_word_title> ]
|
||||
|
||||
$0 -h/--help
|
||||
|
||||
--samples: The number of samples to generate for the wake word.
|
||||
Used only to generate output file names.
|
||||
|
||||
--training-steps: Number of training steps.
|
||||
Default: ${DEFAULT_TRAINING_STEPS}
|
||||
|
||||
<wake_word>: The word to train spelled phonetically.
|
||||
Required.
|
||||
|
||||
<wake_word_title>: A pretty name to save to the json metadata file.
|
||||
Default: The wake word with individual words capitalized.
|
||||
|
||||
EOF
|
||||
exit 1
|
||||
fi
|
||||
|
||||
WORK_DIR="${DATA_DIR}/work"
|
||||
TRAINING_DS="${DATA_DIR}/training_datasets"
|
||||
|
||||
[ ${#POSITIONAL_ARGS} -eq 2 ] && WAKE_WORD_TITLE="${POSITIONAL_ARGS[1]}"
|
||||
|
||||
if [ ! -v WAKE_WORD_TITLE ] ; then
|
||||
declare -a WWNA=( ${WAKE_WORD//[^a-zA-Z0-9]/ } )
|
||||
WAKE_WORD_TITLE="${WWNA[*]^}"
|
||||
elif [ -z "$WAKE_WORD_TITLE" ] ; then
|
||||
WAKE_WORD_TITLE="$WAKE_WORD"
|
||||
fi
|
||||
|
||||
# shellcheck source=/dev/null
|
||||
source "${DATA_DIR}/.venv/bin/activate"
|
||||
|
||||
check_directories() {
|
||||
for d in "$@" ; do
|
||||
[ -d "$d" ] || { echo "ERROR: Directory $d not found" >&2 ; exit 1 ; }
|
||||
done
|
||||
}
|
||||
|
||||
check_directories ${WORK_DIR}/wake_word_samples_augmented \
|
||||
${TRAINING_DS}/negative_datasets/{speech,dinner_party,no_speech,dinner_party_eval}
|
||||
|
||||
cd "${WORK_DIR}"
|
||||
|
||||
echo "===== Starting ${TRAINING_STEPS} training steps ====="
|
||||
|
||||
START_TS=$EPOCHSECONDS
|
||||
|
||||
mkdir -p "${WORK_DIR}/trained_models" || :
|
||||
cat <<EOF >"${WORK_DIR}/trained_models/training_parameters.yaml"
|
||||
batch_size: 16
|
||||
clip_duration_ms: 1500
|
||||
eval_step_interval: 500
|
||||
features:
|
||||
- features_dir: ${WORK_DIR}/wake_word_samples_augmented
|
||||
penalty_weight: 1.0
|
||||
sampling_weight: 2.0
|
||||
truncation_strategy: truncate_start
|
||||
truth: true
|
||||
type: mmap
|
||||
- features_dir: ${TRAINING_DS}/negative_datasets/speech
|
||||
penalty_weight: 1.0
|
||||
sampling_weight: 12.0
|
||||
truncation_strategy: random
|
||||
truth: false
|
||||
type: mmap
|
||||
- features_dir: ${TRAINING_DS}/negative_datasets/dinner_party
|
||||
penalty_weight: 1.0
|
||||
sampling_weight: 12.0
|
||||
truncation_strategy: random
|
||||
truth: false
|
||||
type: mmap
|
||||
- features_dir: ${TRAINING_DS}/negative_datasets/no_speech
|
||||
penalty_weight: 1.0
|
||||
sampling_weight: 5.0
|
||||
truncation_strategy: random
|
||||
truth: false
|
||||
type: mmap
|
||||
- features_dir: ${TRAINING_DS}/negative_datasets/dinner_party_eval
|
||||
penalty_weight: 1.0
|
||||
sampling_weight: 0.0
|
||||
truncation_strategy: split
|
||||
truth: false
|
||||
type: mmap
|
||||
freq_mask_count:
|
||||
- 0
|
||||
freq_mask_max_size:
|
||||
- 0
|
||||
learning_rates:
|
||||
- 0.001
|
||||
maximization_metric: average_viable_recall
|
||||
minimization_metric: null
|
||||
negative_class_weight:
|
||||
- 20
|
||||
positive_class_weight:
|
||||
- 1
|
||||
target_minimization: 0.9
|
||||
time_mask_count:
|
||||
- 0
|
||||
time_mask_max_size:
|
||||
- 0
|
||||
train_dir: ${WORK_DIR}/trained_models/wakeword
|
||||
training_steps:
|
||||
- ${TRAINING_STEPS}
|
||||
window_step_ms: 10
|
||||
|
||||
EOF
|
||||
|
||||
echo " Wrote training_parameters.yaml"
|
||||
rm -rf "${WORK_DIR}/trained_models/wakeword"
|
||||
|
||||
export TF_CPP_MIN_LOG_LEVEL=9
|
||||
export TF_FORCE_GPU_ALLOW_GROWTH=true
|
||||
export TF_GPU_ALLOCATOR=cuda_malloc_async
|
||||
export TF_XLA_FLAGS="--tf_xla_auto_jit=0"
|
||||
export NVIDIA_TF32_OVERRIDE=1
|
||||
export TF_CUDNN_WORKSPACE_LIMIT_IN_MB=512
|
||||
export GLOG_minloglevel=9
|
||||
export GRPC_VERBOSITY=ERROR
|
||||
|
||||
echo " Loading Tensorflow"
|
||||
|
||||
wake_word_filename="${WAKE_WORD//[ \`~\!\$&*\(\)\{\}\[\]\|\;\'\"<>.?\/]/_}"
|
||||
OUTPUT_DIR="${DATA_DIR}/output/$(date +'%Y-%m-%d-%H-%M-%S')-${wake_word_filename}-${SAMPLES}-${TRAINING_STEPS}"
|
||||
mkdir -p "${OUTPUT_DIR}/logs" || :
|
||||
|
||||
python - \
|
||||
--training_config="${WORK_DIR}/trained_models/training_parameters.yaml" \
|
||||
--train 1 \
|
||||
--restore_checkpoint 1 \
|
||||
--test_tf_nonstreaming 0 \
|
||||
--test_tflite_nonstreaming 0 \
|
||||
--test_tflite_nonstreaming_quantized 0 \
|
||||
--test_tflite_streaming 0 \
|
||||
--test_tflite_streaming_quantized 1 \
|
||||
--use_weights "best_weights" \
|
||||
mixednet \
|
||||
--pointwise_filters "64,64,64,64" \
|
||||
--repeat_in_block "1,1,1,1" \
|
||||
--mixconv_kernel_sizes "[5], [7,11], [9,15], [23]" \
|
||||
--residual_connection "0,0,0,0" \
|
||||
--first_conv_filters 32 \
|
||||
--first_conv_kernel_size 5 \
|
||||
--stride 2 <<EOF 2>&1 | tr '\r' '\n' | stdbuf -i0 -o0 sed -r -e "/^Validation Batch/d" |\
|
||||
tee "${OUTPUT_DIR}/logs/training.log" | sed -r -e '/^INFO:absl:/!d' \
|
||||
-r -e "/None|Sharding|unsupported characters|AUC|fingerprint/d" \
|
||||
-r -e 's/INFO:absl:/ /g' \
|
||||
-r -e "s/, (recall =|estimated false|average viable recall)/,\n \1/g"
|
||||
|
||||
import sys, os, gc
|
||||
import runpy
|
||||
import yaml
|
||||
print(" Loading Tensorflow")
|
||||
import tensorflow as tf
|
||||
|
||||
print(" GPU memory config")
|
||||
# Per-device memory growth (belt + suspenders)
|
||||
for g in tf.config.list_physical_devices("GPU"):
|
||||
try:
|
||||
tf.config.experimental.set_memory_growth(g, True)
|
||||
except Exception:
|
||||
pass
|
||||
print(f"INFO:absl:GPUs: {tf.config.list_physical_devices('GPU')}")
|
||||
gc.collect()
|
||||
|
||||
print()
|
||||
try:
|
||||
runpy.run_module("microwakeword.model_train_eval", run_name="__main__", alter_sys=True)
|
||||
except Exception as e:
|
||||
print(e, file=sys.stderr)
|
||||
sys.exit(1)
|
||||
EOF
|
||||
|
||||
source_path="${WORK_DIR}/trained_models/wakeword/tflite_stream_state_internal_quant/stream_state_internal_quant.tflite"
|
||||
|
||||
if [ ! -f "${source_path}" ] ; then
|
||||
echo "Output model not found! Training didn't complete successfully. See ${WORK_DIR}/training.log"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
cp "${WORK_DIR}/trained_models/wakeword/model_summary.txt" "${OUTPUT_DIR}/logs/"
|
||||
cp -a "${WORK_DIR}/trained_models/wakeword/logs/train" "${OUTPUT_DIR}/logs/"
|
||||
cp -a "${WORK_DIR}/trained_models/wakeword/logs/validation" "${OUTPUT_DIR}/logs/"
|
||||
|
||||
echo -e "\n Training complete!"
|
||||
echo " Full log: ${OUTPUT_DIR}/logs/training.log"
|
||||
|
||||
tflite_filename="${wake_word_filename}.tflite"
|
||||
tflite_path="${OUTPUT_DIR}/${tflite_filename}"
|
||||
|
||||
cp "${source_path}" "${tflite_path}"
|
||||
|
||||
# --- Write JSON metadata file with matching model name ---
|
||||
json_path="${OUTPUT_DIR}/${wake_word_filename}.json"
|
||||
cat <<-EOF > "${json_path}"
|
||||
{
|
||||
"type": "micro",
|
||||
"wake_word": "${WAKE_WORD_TITLE}",
|
||||
"author": "Tater Totterson",
|
||||
"website": "https://github.com/TaterTotterson/microWakeWord-Trainer-Nvidia-Docker.git",
|
||||
"model": "${tflite_filename}",
|
||||
"trained_languages": ["en"],
|
||||
"version": 2,
|
||||
"micro": {
|
||||
"probability_cutoff": 0.97,
|
||||
"sliding_window_size": 5,
|
||||
"feature_step_size": 10,
|
||||
"tensor_arena_size": 30000,
|
||||
"minimum_esphome_version": "2024.7.0"
|
||||
}
|
||||
}
|
||||
EOF
|
||||
|
||||
echo "Name: ${WAKE_WORD_TITLE}"
|
||||
echo "Model: ${tflite_path}"
|
||||
echo "Metadata: ${json_path}"
|
||||
echo
|
||||
END_TS=$EPOCHSECONDS
|
||||
print_elapsed_time "${START_TS}" "${END_TS}" "Training completed."
|
||||
echo
|
||||
|
||||
Reference in New Issue
Block a user