From 94903783cb00eff108e3f94655d397b5c935e5dc Mon Sep 17 00:00:00 2001
From: MasterPhooey
Date: Mon, 9 Mar 2026 19:48:35 -0500
Subject: [PATCH] blackwell/wham & chime datasets

---
 cli/setup_chime                | 142 +++++++++++++++++++++++++++++++++
 cli/setup_python_venv          |  69 ++++++++++++++--
 cli/setup_training_datasets    |  10 +++
 cli/setup_wham                 | 142 +++++++++++++++++++++++++++++++++
 cli/wake_word_sample_augmenter |  31 ++++++-
 cli/wake_word_sample_trainer   | 131 +++++++++++++++++++++++-------
 train_wake_word                |  34 ++++++--
 7 files changed, 517 insertions(+), 42 deletions(-)
 create mode 100755 cli/setup_chime
 create mode 100755 cli/setup_wham

diff --git a/cli/setup_chime b/cli/setup_chime
new file mode 100755
index 0000000..a1fe6b0
--- /dev/null
+++ b/cli/setup_chime
@@ -0,0 +1,142 @@
+#!/bin/bash
+set -euo pipefail
+
+PROGPATH=$(realpath "$0")
+PROGDIR=$(dirname "${PROGPATH}")
+source "${PROGDIR}/shell.functions"
+
+if [ "${HELP}" == "true" ] ; then
+    cat <<EOF >&2
+Usage: $0 [ --cleanup-archives ] [ --cleanup-intermediate-files ] [ --data-dir=<data_dir> ]
+
+    --cleanup-archives
+                   : Automatically clean up any downloaded archives after
+                   : extraction.
+    --cleanup-intermediate-files
+                   : Automatically clean up intermediate extracted files
+                   : after conversion to 16k.
+    --data-dir=<data_dir>
+                   : Path to the data directory.
+                   : Default: ${DATA_DIR}
+
+EOF
+    exit 1
+fi
+
+mkdir -p "${DATA_DIR}/training_datasets/downloads" || :
+cd "${DATA_DIR}/training_datasets"
+
+echo "***** Checking CHiME-Home *****"
+
+AUDIO_URL="https://archive.org/download/chime-home/chime_home.tar.gz"
+AUDIO_TARFILE="chime_home.tar.gz"
+AUDIO_TAR="./downloads/${AUDIO_TARFILE}"
+AUDIO_DIR="./chime"
+mkdir -p "${AUDIO_DIR}" || :
+AUDIO16K_DIR="./chime_16k"
+mkdir -p "${AUDIO16K_DIR}" || :
+AUDIO_FILECOUNT="./downloads/chime_filecount"
+AUDIO_IN_GLOB="*.48kHz.wav"
+
+declare -A filecounts=( [${AUDIO_TARFILE}]=0 )
+get_filecounts filecounts "${AUDIO_FILECOUNT}"
+
+converter() {
+    source "${DATA_DIR}/.venv/bin/activate"
+    # The script is read from stdin; the quoted delimiter stops the shell expanding it.
+    python - "${AUDIO_DIR}" "${AUDIO16K_DIR}" <<-'EOF'
+import sys
+from pathlib import Path
+import numpy as np
+import scipy.io.wavfile
+import librosa
+from tqdm import tqdm
+
+def write_wav(dst: Path, data: np.ndarray, sr: int):
+    dst.parent.mkdir(parents=True, exist_ok=True)
+    x = np.clip(data, -1.0, 1.0)
+    scipy.io.wavfile.write(dst, sr, (x * 32767).astype(np.int16))
+
+def flatten_name(root: Path, src: Path) -> str:
+    rel = src.relative_to(root)
+    return "__".join(rel.parts)
+
+chime_in = Path(sys.argv[1]).resolve()
+chime_out = Path(sys.argv[2]).resolve()
+
+wavs = list(chime_in.rglob("*.48kHz.wav"))
+print(f" WAV files: {len(wavs)}")
+print(" Converting CHiME -> 16k mono WAV")
+
+bad = []
+ok = 0
+skipped = 0
+for p in tqdm(wavs, desc=" CHiME -> WAV (resample 16k mono)"):
+    try:
+        out_name = flatten_name(chime_in, p)
+        outfile = chime_out / out_name
+        if outfile.exists():
+            skipped += 1
+            continue
+        y, _ = librosa.load(p, sr=16000, mono=True)
+        if y.size == 0:
+            raise ValueError("empty audio")
+        write_wav(outfile, y, 16000)
+        ok += 1
+    except Exception as e:
+        bad.append(f"{p}:{e}")
+
+if bad:
+    (chime_out / "chime_corrupted_files.log").write_text("\n".join(bad))
+print(f" CHiME complete ({ok} ok, {skipped} skipped, {len(bad)} failed)")
+EOF
+}
+
+expected_filecount=${filecounts[${AUDIO_TARFILE}]}
+actual_filecount=$(find "${AUDIO16K_DIR}" -name '*.wav' 2>/dev/null | wc -l) || :
+write_filecount=false
+
+if [ "${actual_filecount}" -ne 0 ] && [ "${actual_filecount}" -eq "${expected_filecount}" ] ; then
+    echo " Existing ${AUDIO16K_DIR} valid"
+else
+    actual_filecount=$(find "${AUDIO_DIR}" -name "${AUDIO_IN_GLOB}" 2>/dev/null | wc -l) || :
+    if [ "${actual_filecount}" -eq 0 ] || [ "${actual_filecount}" -ne "${expected_filecount}" ] ; then
+        if [ ! -f "${AUDIO_TAR}" ] ; then
+            echo " Downloading ${AUDIO_TARFILE}"
+            # Stage as .part so an interrupted download is never reused as the archive.
+            curl -sfL "${AUDIO_URL}" -o "${AUDIO_TAR}.part"
+            mv "${AUDIO_TAR}.part" "${AUDIO_TAR}"
+        fi
+        rm -rf "${AUDIO_DIR}" || :
+        mkdir -p "${AUDIO_DIR}" || :
+        echo " Untarring ${AUDIO_TARFILE}"
+        tar -xzf "${AUDIO_TAR}" -C "${AUDIO_DIR}"
+    fi
+
+    if "${CLEANUP_ARCHIVES}" && [ -f "${AUDIO_TAR}" ] ; then
+        echo " Cleaning up ${AUDIO_TARFILE}"
+        rm -rf "${AUDIO_TAR}"
+    fi
+
+    converter
+    actual_filecount=$(find "${AUDIO16K_DIR}" -name "*.wav" 2>/dev/null | wc -l) || :
+    filecounts[${AUDIO_TARFILE}]="${actual_filecount}"
+    write_filecount=true
+fi
+
+if ${write_filecount} ; then
+    write_filecounts filecounts "${AUDIO_FILECOUNT}"
+fi
+
+if "${CLEANUP_ARCHIVES}" && [ -f "${AUDIO_TAR}" ] ; then
+    echo " Cleaning up ${AUDIO_TARFILE}"
+    rm -rf "${AUDIO_TAR}"
+fi
+
+if "${CLEANUP_INTERMEDIATE_FILES}" && [ -d "${AUDIO_DIR}" ] ; then
+    echo " Cleaning up ${AUDIO_DIR}"
+    rm -rf "${AUDIO_DIR}"
+fi
+
+echo " CHiME complete"
+exit 0
diff --git a/cli/setup_python_venv b/cli/setup_python_venv
index 1386497..38e6d14 100755
--- a/cli/setup_python_venv
+++ b/cli/setup_python_venv
@@ -24,6 +24,13 @@ Options:
 
 --verbose: Print the detailed "pip install" output.
 
+Environment overrides:
+MWW_TF_SPEC: Full TensorFlow package spec (e.g. "tf-nightly[and-cuda]"
+             or "tensorflow[and-cuda]==2.20.0").
+MWW_TENSORBOARD_SPEC: Comma-separated TensorBoard package specs.
+                      Example: "tensorboard==2.20.0,tensorboard-data-server==0.7.2"
+MWW_KERAS_SPEC: Keras package spec to install explicitly.
+
 EOF
     exit 1
 fi
@@ -46,6 +53,24 @@ cd "${DATA_DIR}"
 
 "${GPU}" || export CUDA_VISIBLE_DEVICES=-1
 
+detect_gpu_compute_capability() {
+    if command -v nvidia-smi >/dev/null 2>&1 ; then
+        nvidia-smi --query-gpu=compute_cap --format=csv,noheader 2>/dev/null \
+            | head -n 1 \
+            | tr -d '[:space:]'
+    fi
+}
+
+GPU_COMPUTE_CAPABILITY=""
+IS_BLACKWELL=false
+if ${GPU} ; then
+    GPU_COMPUTE_CAPABILITY="$(detect_gpu_compute_capability || true)"
+    case "${GPU_COMPUTE_CAPABILITY}" in
+        12.*) IS_BLACKWELL=true ;;
+    esac
+    ${IS_BLACKWELL} && echo " Blackwell GPU detected (compute capability ${GPU_COMPUTE_CAPABILITY})"
+fi
+
 VENV="${DATA_DIR}/.venv"
 
 [ -n "${VIRTUAL_ENV}" ] && deactivate
@@ -127,9 +152,34 @@ echo " ===== Installing common requirements ====="
 pip_install -r "${ROOTDIR}/requirements.txt"
 
 ${GPU} && tfgpu='[and-cuda]' || tfgpu=""
-echo " ===== Installing Tensorflow${tfgpu} ====="
-pip_install ai_edge_litert "tensorflow${tfgpu}==2.20.0" "tensorboard==2.20.0" \
-    "tensorboard-data-server==0.7.2"
+declare -a default_tensorboard_specs=()
+
+if ${GPU} && ${IS_BLACKWELL} ; then
+    # Blackwell path: prefer nightly TF while upstream stable wheels catch up.
+    DEFAULT_TF_SPEC="tf-nightly${tfgpu}"
+    # Let tf-nightly resolve a compatible TensorBoard dependency by default.
+    default_tensorboard_specs=()
+else
+    DEFAULT_TF_SPEC="tensorflow${tfgpu}==2.20.0"
+    default_tensorboard_specs=( "tensorboard==2.20.0" "tensorboard-data-server==0.7.2" )
+fi
+
+TF_SPEC="${MWW_TF_SPEC:-${DEFAULT_TF_SPEC}}"
+declare -a tf_install_specs=( ai_edge_litert "${TF_SPEC}" )
+# Explicit TensorBoard specs win; the default pins only accompany the default TF spec.
+if [ -n "${MWW_TENSORBOARD_SPEC:-}" ] ; then
+    IFS=',' read -r -a user_tb_specs <<< "${MWW_TENSORBOARD_SPEC}"
+    for tb_spec in "${user_tb_specs[@]}" ; do
+        tb_spec="${tb_spec#"${tb_spec%%[![:space:]]*}"}"
+        tb_spec="${tb_spec%"${tb_spec##*[![:space:]]}"}"
+        [ -n "${tb_spec}" ] && tf_install_specs+=( "${tb_spec}" )
+    done
+elif [ -z "${MWW_TF_SPEC:-}" ] ; then
+    tf_install_specs+=( "${default_tensorboard_specs[@]}" )
+fi
+
+echo " ===== Installing TensorFlow stack (${TF_SPEC}) ====="
+pip_install "${tf_install_specs[@]}"
 
 ${GPU} && torchgpu='--index-url https://download.pytorch.org/whl/cu129' || torchgpu=""
 echo " ===== Installing torch and torchaudio ${torchgpu:+[cuda]} ====="
@@ -203,8 +253,15 @@ echo " ===== Installing onnxruntime${onnxgpu} ====="
 pip_install "onnxruntime${onnxgpu}>=1.16.0"
 
 echo " ===== Installing keras ====="
-# keras 3.13 has "issues" so we need to back down to 3.12.
-pip_install "keras==3.12.0"
+# Default: keep the known-good pin with stable TF 2.20.
+# For tf-nightly/custom TF specs, skip this pin unless explicitly requested.
+if [ -n "${MWW_KERAS_SPEC:-}" ] ; then
+    pip_install "${MWW_KERAS_SPEC}"
+elif [ -n "${MWW_TF_SPEC:-}" ] || [[ "${TF_SPEC}" == tf-nightly* ]] ; then
+    echo " Skipping explicit keras pin for ${TF_SPEC} (set MWW_KERAS_SPEC to force one)."
+else
+    pip_install "keras==3.12.0"
+fi
 
 # -----------------------------------------------------------------------------
 # Optional CUDA data dir (GPU-only)
@@ -240,4 +297,4 @@ END_TS=$EPOCHSECONDS
 
 echo "Run 'source ${VENV}/bin/activate' to activate the new virtualenv in the current shell."
 
-print_elapsed_time "${START_TS}" "${END_TS}" "Python package installation complete"
\ No newline at end of file
+print_elapsed_time "${START_TS}" "${END_TS}" "Python package installation complete"
diff --git a/cli/setup_training_datasets b/cli/setup_training_datasets
index c343e95..c2f4940 100755
--- a/cli/setup_training_datasets
+++ b/cli/setup_training_datasets
@@ -61,5 +61,15 @@ echo -e "\n===== Setting up Training Datasets =====\n"
     --cleanup-intermediate-files="${CLEANUP_INTERMEDIATE_FILES}" \
     --data-dir="${DATA_DIR}"
 
+"${PROGDIR}/setup_wham" \
+    --cleanup-archives="${CLEANUP_ARCHIVES}" \
+    --cleanup-intermediate-files="${CLEANUP_INTERMEDIATE_FILES}" \
+    --data-dir="${DATA_DIR}"
+
+"${PROGDIR}/setup_chime" \
+    --cleanup-archives="${CLEANUP_ARCHIVES}" \
+    --cleanup-intermediate-files="${CLEANUP_INTERMEDIATE_FILES}" \
+    --data-dir="${DATA_DIR}"
+
 END_TS=$EPOCHSECONDS
 print_elapsed_time "${START_TS}" "${END_TS}" "Training dataset setup"
diff --git a/cli/setup_wham b/cli/setup_wham
new file mode 100755
index 0000000..9f1c69c
--- /dev/null
+++ b/cli/setup_wham
@@ -0,0 +1,142 @@
+#!/bin/bash
+set -euo pipefail
+
+PROGPATH=$(realpath "$0")
+PROGDIR=$(dirname "${PROGPATH}")
+source "${PROGDIR}/shell.functions"
+
+if [ "${HELP}" == "true" ] ; then
+    cat <<EOF >&2
+Usage: $0 [ --cleanup-archives ] [ --cleanup-intermediate-files ] [ --data-dir=<data_dir> ]
+
+    --cleanup-archives
+                   : Automatically clean up any downloaded archives after
+                   : extraction.
+    --cleanup-intermediate-files
+                   : Automatically clean up intermediate extracted files
+                   : after conversion to 16k.
+    --data-dir=<data_dir>
+                   : Path to the data directory.
+                   : Default: ${DATA_DIR}
+
+EOF
+    exit 1
+fi
+
+mkdir -p "${DATA_DIR}/training_datasets/downloads" || :
+cd "${DATA_DIR}/training_datasets"
+
+echo "***** Checking WHAM *****"
+
+AUDIO_URL="https://my-bucket-a8b4b49c25c811ee9a7e8bba05fa24c7.s3.amazonaws.com/wham_noise.zip"
+AUDIO_ZIPFILE="wham_noise.zip"
+AUDIO_ZIP="./downloads/${AUDIO_ZIPFILE}"
+AUDIO_DIR="./wham"
+mkdir -p "${AUDIO_DIR}" || :
+AUDIO16K_DIR="./wham_16k"
+mkdir -p "${AUDIO16K_DIR}" || :
+AUDIO_FILECOUNT="./downloads/wham_filecount"
+AUDIO_IN_GLOB="*.wav"
+
+declare -A filecounts=( [${AUDIO_ZIPFILE}]=0 )
+get_filecounts filecounts "${AUDIO_FILECOUNT}"
+
+converter() {
+    source "${DATA_DIR}/.venv/bin/activate"
+    # The script is read from stdin; the quoted delimiter stops the shell expanding it.
+    python - "${AUDIO_DIR}" "${AUDIO16K_DIR}" <<-'EOF'
+import sys
+from pathlib import Path
+import numpy as np
+import scipy.io.wavfile
+import librosa
+from tqdm import tqdm
+
+def write_wav(dst: Path, data: np.ndarray, sr: int):
+    dst.parent.mkdir(parents=True, exist_ok=True)
+    x = np.clip(data, -1.0, 1.0)
+    scipy.io.wavfile.write(dst, sr, (x * 32767).astype(np.int16))
+
+def flatten_name(root: Path, src: Path) -> str:
+    rel = src.relative_to(root)
+    return "__".join(rel.parts)
+
+wham_in = Path(sys.argv[1]).resolve()
+wham_out = Path(sys.argv[2]).resolve()
+
+wavs = list(wham_in.rglob("*.wav"))
+print(f" WAV files: {len(wavs)}")
+print(" Converting WHAM -> 16k mono WAV")
+
+bad = []
+ok = 0
+skipped = 0
+for p in tqdm(wavs, desc=" WHAM -> WAV (resample 16k mono)"):
+    try:
+        out_name = flatten_name(wham_in, p)
+        outfile = wham_out / out_name
+        if outfile.exists():
+            skipped += 1
+            continue
+        y, _ = librosa.load(p, sr=16000, mono=True)
+        if y.size == 0:
+            raise ValueError("empty audio")
+        write_wav(outfile, y, 16000)
+        ok += 1
+    except Exception as e:
+        bad.append(f"{p}:{e}")
+
+if bad:
+    (wham_out / "wham_corrupted_files.log").write_text("\n".join(bad))
+print(f" WHAM complete ({ok} ok, {skipped} skipped, {len(bad)} failed)")
+EOF
+}
+
+expected_filecount=${filecounts[${AUDIO_ZIPFILE}]}
+actual_filecount=$(find "${AUDIO16K_DIR}" -name '*.wav' 2>/dev/null | wc -l) || :
+write_filecount=false
+
+if [ "${actual_filecount}" -ne 0 ] && [ "${actual_filecount}" -eq "${expected_filecount}" ] ; then
+    echo " Existing ${AUDIO16K_DIR} valid"
+else
+    actual_filecount=$(find "${AUDIO_DIR}" -name "${AUDIO_IN_GLOB}" 2>/dev/null | wc -l) || :
+    if [ "${actual_filecount}" -eq 0 ] || [ "${actual_filecount}" -ne "${expected_filecount}" ] ; then
+        if [ ! -f "${AUDIO_ZIP}" ] ; then
+            echo " Downloading ${AUDIO_ZIPFILE}"
+            # Stage as .part so an interrupted download is never reused as the archive.
+            curl -sfL "${AUDIO_URL}" -o "${AUDIO_ZIP}.part"
+            mv "${AUDIO_ZIP}.part" "${AUDIO_ZIP}"
+        fi
+        rm -rf "${AUDIO_DIR}" || :
+        mkdir -p "${AUDIO_DIR}" || :
+        echo " Unzipping ${AUDIO_ZIPFILE}"
+        unzip -q -d "${AUDIO_DIR}" "${AUDIO_ZIP}"
+    fi
+
+    if "${CLEANUP_ARCHIVES}" && [ -f "${AUDIO_ZIP}" ] ; then
+        echo " Cleaning up ${AUDIO_ZIPFILE}"
+        rm -rf "${AUDIO_ZIP}"
+    fi
+
+    converter
+    actual_filecount=$(find "${AUDIO16K_DIR}" -name "*.wav" 2>/dev/null | wc -l) || :
+    filecounts[${AUDIO_ZIPFILE}]="${actual_filecount}"
+    write_filecount=true
+fi
+
+if ${write_filecount} ; then
+    write_filecounts filecounts "${AUDIO_FILECOUNT}"
+fi
+
+if "${CLEANUP_ARCHIVES}" && [ -f "${AUDIO_ZIP}" ] ; then
+    echo " Cleaning up ${AUDIO_ZIPFILE}"
+    rm -rf "${AUDIO_ZIP}"
+fi
+
+if "${CLEANUP_INTERMEDIATE_FILES}" && [ -d "${AUDIO_DIR}" ] ; then
+    echo " Cleaning up ${AUDIO_DIR}"
+    rm -rf "${AUDIO_DIR}"
+fi
+
+echo " WHAM complete"
+exit 0
diff --git a/cli/wake_word_sample_augmenter b/cli/wake_word_sample_augmenter
index 0ac08ef..9e635af 100644
--- a/cli/wake_word_sample_augmenter
+++ b/cli/wake_word_sample_augmenter
@@ -23,6 +23,8 @@ parser.add_argument("--personal-output-dir", type=str, help="Personal features o
 parser.add_argument("--mit-rirs-16k-dir", type=str, help="MIT RIR input directory. Default: <data_dir>/training_datasets/mit_rirs_16k", required=False)
 parser.add_argument("--fma-16k-dir", type=str, help="FMA input directory. Default: <data_dir>/training_datasets/fma_16k", required=False)
 parser.add_argument("--audioset-16k-dir", type=str, help="Audioset input directory. Default: <data_dir>/training_datasets/audioset_16k", required=False)
+parser.add_argument("--wham-16k-dir", type=str, help="WHAM input directory. Default: <data_dir>/training_datasets/wham_16k", required=False)
+parser.add_argument("--chime-16k-dir", type=str, help="CHiME input directory. Default: <data_dir>/training_datasets/chime_16k", required=False)
 
 try:
     args = parser.parse_args()
@@ -71,6 +73,16 @@ if not args.audioset_16k_dir:
 else:
     args.audioset_16k_dir = os.path.realpath(args.audioset_16k_dir)
 
+if not args.wham_16k_dir:
+    args.wham_16k_dir = os.path.join(args.data_dir, "training_datasets", "wham_16k")
+else:
+    args.wham_16k_dir = os.path.realpath(args.wham_16k_dir)
+
+if not args.chime_16k_dir:
+    args.chime_16k_dir = os.path.join(args.data_dir, "training_datasets", "chime_16k")
+else:
+    args.chime_16k_dir = os.path.realpath(args.chime_16k_dir)
+
 def validate_directories(paths):
     for path in paths:
         if not os.path.exists(path):
@@ -78,7 +90,15 @@ def validate_directories(paths):
             return False
     return True
 
-required = [work_dir, args.input_dir, args.mit_rirs_16k_dir, args.fma_16k_dir, args.audioset_16k_dir]
+required = [
+    work_dir,
+    args.input_dir,
+    args.mit_rirs_16k_dir,
+    args.wham_16k_dir,
+    args.chime_16k_dir,
+    args.fma_16k_dir,
+    args.audioset_16k_dir,
+]
 if not validate_directories(required):
     parser.print_help()
     sys.exit(1)
@@ -117,7 +137,12 @@ from microwakeword.audio.spectrograms import SpectrogramGeneration
 START_TIME = datetime.now(timezone.utc).replace(microsecond=0)
 
 impulse_paths = [args.mit_rirs_16k_dir]
-background_paths = [args.fma_16k_dir, args.audioset_16k_dir]
+background_paths = [
+    args.wham_16k_dir,
+    args.chime_16k_dir,
+    args.fma_16k_dir,
+    args.audioset_16k_dir,
+]
 
 augmenter = Augmentation(
     augmentation_duration_s=3.2,
@@ -245,4 +270,4 @@ END_TIME = datetime.now(timezone.utc).replace(microsecond=0)
 et = END_TIME - START_TIME
print(f"\n{'=' * 80}") print(f"{'Augmentation completed.':>50s} Elapsed time: {et!s}") -print(f"{'=' * 80}\n") \ No newline at end of file +print(f"{'=' * 80}\n") diff --git a/cli/wake_word_sample_trainer b/cli/wake_word_sample_trainer index e32be53..33fdf64 100644 --- a/cli/wake_word_sample_trainer +++ b/cli/wake_word_sample_trainer @@ -51,23 +51,46 @@ fi # shellcheck source=/dev/null source "${DATA_DIR}/.venv/bin/activate" -# --- WSL2 GPU visibility fix (venv sometimes doesn't inherit WSL driver path) --- -# Keep a copy so we can restore/preserve on fallback if desired. +# Keep copies so we can restore/preserve across retries and fallback. ORIG_XLA_FLAGS="${XLA_FLAGS:-}" +ORIG_TF_XLA_FLAGS="${TF_XLA_FLAGS:-}" -if [ -d /usr/lib/wsl/lib ]; then - export LD_LIBRARY_PATH="/usr/lib/wsl/lib:${LD_LIBRARY_PATH:-}" - echo "ℹ️ WSL2 detected: LD_LIBRARY_PATH+=/usr/lib/wsl/lib" +normalize_bool() { + case "${1,,}" in + 1|true|yes|on) echo "true" ;; + *) echo "false" ;; + esac +} - # Blackwell / PTXAS workaround: only apply on WSL *and* only if user didn't set XLA_FLAGS +detect_gpu_compute_capability() { + if command -v nvidia-smi >/dev/null 2>&1; then + nvidia-smi --query-gpu=compute_cap --format=csv,noheader 2>/dev/null \ + | head -n 1 \ + | tr -d '[:space:]' + fi +} + +GPU_COMPUTE_CAPABILITY="$(detect_gpu_compute_capability)" +IS_BLACKWELL="false" +case "${GPU_COMPUTE_CAPABILITY}" in + 12.*) IS_BLACKWELL="true" ;; +esac + +ALLOW_CPU_FALLBACK_DEFAULT="true" +ALLOW_CPU_FALLBACK="$(normalize_bool "${MWW_ALLOW_CPU_FALLBACK:-${ALLOW_CPU_FALLBACK_DEFAULT}}")" + +if [ "${IS_BLACKWELL}" = "true" ]; then + echo "ℹ️ Blackwell GPU detected (compute capability ${GPU_COMPUTE_CAPABILITY})." + echo "ℹ️ Using GPU compatibility retries; CPU fallback is ${ALLOW_CPU_FALLBACK} (override with MWW_ALLOW_CPU_FALLBACK=true|false)." + + # Force driver PTX fallback when XLA needs ptxas. 
if [ -z "${XLA_FLAGS:-}" ]; then export XLA_FLAGS="--xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found" - echo "ℹ️ WSL2: setting XLA_FLAGS=${XLA_FLAGS}" + echo "ℹ️ Setting XLA_FLAGS=${XLA_FLAGS}" else echo "ℹ️ Using user-provided XLA_FLAGS=${XLA_FLAGS}" fi fi -# ----------------------------------------------------------------------------- check_directories() { for d in "$@" ; do @@ -226,6 +249,8 @@ GPU_FALLBACK_MARKERS=( "oom" "out of memory" "cuda_error_out_of_memory" + "cuda_error_invalid_handle" + "culaunchkernel" "failed to allocate" "cudnn" "cublas" @@ -255,52 +280,104 @@ run_attempt() { return ${PIPESTATUS[0]} } +is_gpu_runtime_failure() { + local log_lc m + log_lc="$(tr '[:upper:]' '[:lower:]' < "${TRAIN_LOG}" || true)" + + for m in "${GPU_FALLBACK_MARKERS[@]}"; do + if echo "${log_lc}" | grep -qF "${m}"; then + return 0 + fi + done + + # Catch unlisted TF GPU runtime failures (common on newer architectures). + if echo "${log_lc}" | grep -qF "device:gpu:0" \ + && echo "${log_lc}" | grep -qF "internalerror"; then + return 0 + fi + + return 1 +} + # --------- ENV (keep compatible; DO NOT add unsupported XLA flags) ---------- export TF_CPP_MIN_LOG_LEVEL="${TF_CPP_MIN_LOG_LEVEL:-2}" export TF_XLA_FLAGS="${TF_XLA_FLAGS:---tf_xla_auto_jit=0}" export NVIDIA_TF32_OVERRIDE="${NVIDIA_TF32_OVERRIDE:-1}" export TF_FORCE_GPU_ALLOW_GROWTH="${TF_FORCE_GPU_ALLOW_GROWTH:-true}" -export TF_GPU_ALLOCATOR="${TF_GPU_ALLOCATOR:-cuda_malloc_async}" +if [ "${IS_BLACKWELL}" = "true" ]; then + # TF 2.20 + Blackwell is often unstable with cuda_malloc_async. + unset TF_GPU_ALLOCATOR +else + export TF_GPU_ALLOCATOR="${TF_GPU_ALLOCATOR:-cuda_malloc_async}" +fi -if run_attempt "Attempt 1/2: GPU training (allow_growth + cuda_malloc_async)" ; then +TRAINING_DONE="false" + +if run_attempt "Attempt 1/3: GPU training (default runtime profile)" ; then echo "✅ Training complete (GPU path)." + TRAINING_DONE="true" else echo "⚠️ GPU attempt failed. 
Checking whether this looks like a GPU/OOM/runtime failure…" - log_lc="$(tr '[:upper:]' '[:lower:]' < "${TRAIN_LOG}" || true)" - looks_like_gpu_fail="false" - for m in "${GPU_FALLBACK_MARKERS[@]}"; do - if echo "${log_lc}" | grep -qF "${m}"; then - looks_like_gpu_fail="true" - break - fi - done + if ! is_gpu_runtime_failure; then + echo "❌ Training failed (does not look GPU/OOM/runtime). See: ${TRAIN_LOG}" >&2 + exit 1 + fi - if [ "${looks_like_gpu_fail}" = "true" ]; then - echo "↪️ Detected GPU/OOM/runtime failure markers. Falling back to CPU." + if [ "${IS_BLACKWELL}" = "true" ]; then + echo "↪️ Retrying on GPU with Blackwell compatibility profile (BFC allocator + driver PTX fallback)." + + unset TF_GPU_ALLOCATOR + export TF_XLA_FLAGS="${ORIG_TF_XLA_FLAGS:---tf_xla_auto_jit=0}" + if run_attempt "Attempt 2/3: GPU training (Blackwell compatibility profile)" ; then + echo "✅ Training complete (GPU Blackwell compatibility profile)." + TRAINING_DONE="true" + else + if is_gpu_runtime_failure; then + echo "↪️ Retrying on GPU with minimal runtime knobs (no TF_XLA_FLAGS)." + + unset TF_GPU_ALLOCATOR + unset TF_XLA_FLAGS + if run_attempt "Attempt 3/3: GPU training (Blackwell minimal runtime profile)" ; then + echo "✅ Training complete (GPU Blackwell minimal profile)." + TRAINING_DONE="true" + fi + fi + fi + fi +fi + +if [ "${TRAINING_DONE}" != "true" ]; then + if ! is_gpu_runtime_failure; then + echo "❌ Training failed (does not look GPU/OOM/runtime). See: ${TRAIN_LOG}" >&2 + exit 1 + fi + + if [ "${ALLOW_CPU_FALLBACK}" = "true" ]; then + echo "↪️ Detected GPU runtime failure markers. Falling back to CPU (MWW_ALLOW_CPU_FALLBACK=true)." export CUDA_VISIBLE_DEVICES="" unset TF_GPU_ALLOCATOR - - # CPU attempt should not inherit GPU/XLA runtime knobs unset TF_XLA_FLAGS - # Optional: clear XLA_FLAGS for CPU (usually irrelevant). If user had set it, restore. + # CPU attempt should not inherit GPU-specific XLA flags. 
if [ -n "${ORIG_XLA_FLAGS}" ]; then export XLA_FLAGS="${ORIG_XLA_FLAGS}" else unset XLA_FLAGS fi - if run_attempt "Attempt 2/2: CPU fallback (CUDA_VISIBLE_DEVICES='')" ; then + if run_attempt "CPU fallback: training (CUDA_VISIBLE_DEVICES='')" ; then echo "✅ Training complete (CPU fallback)." else - echo "❌ Training failed on BOTH GPU and CPU. See: ${TRAIN_LOG}" >&2 + echo "❌ Training failed on both GPU retries and CPU fallback. See: ${TRAIN_LOG}" >&2 exit 1 fi else - echo "❌ Training failed (does not look GPU/OOM/runtime). See: ${TRAIN_LOG}" >&2 + echo "❌ GPU training failed after compatibility retries. CPU fallback is disabled." >&2 + echo " To allow CPU fallback, set MWW_ALLOW_CPU_FALLBACK=true." >&2 + echo " See: ${TRAIN_LOG}" >&2 exit 1 fi fi @@ -349,4 +426,4 @@ echo "Metadata: ${json_path}" echo END_TS=$EPOCHSECONDS print_elapsed_time "${START_TS}" "${END_TS}" "Training completed." -echo \ No newline at end of file +echo diff --git a/train_wake_word b/train_wake_word index 5fb7e99..ddb058d 100755 --- a/train_wake_word +++ b/train_wake_word @@ -74,14 +74,36 @@ START_TS=$EPOCHSECONDS # ----------------------------------------------------------------------------- # TensorFlow / XLA environment (known-good, portable) # ----------------------------------------------------------------------------- -export TF_CPP_MIN_LOG_LEVEL=9 -export TF_FORCE_GPU_ALLOW_GROWTH=true -export TF_GPU_ALLOCATOR=cuda_malloc_async +detect_gpu_compute_capability() { + if command -v nvidia-smi >/dev/null 2>&1 ; then + nvidia-smi --query-gpu=compute_cap --format=csv,noheader 2>/dev/null \ + | head -n 1 \ + | tr -d '[:space:]' + fi +} + +GPU_COMPUTE_CAPABILITY="$(detect_gpu_compute_capability)" +IS_BLACKWELL=false +case "${GPU_COMPUTE_CAPABILITY}" in + 12.*) IS_BLACKWELL=true ;; +esac + +export TF_CPP_MIN_LOG_LEVEL="${TF_CPP_MIN_LOG_LEVEL:-9}" +export TF_FORCE_GPU_ALLOW_GROWTH="${TF_FORCE_GPU_ALLOW_GROWTH:-true}" # Hard-set TF_XLA_FLAGS to ONLY what we know this build supports. 
# Do NOT append user environment flags (can cause hard failures). -export TF_XLA_FLAGS="--tf_xla_auto_jit=0" -unset XLA_FLAGS +export TF_XLA_FLAGS="${TF_XLA_FLAGS:---tf_xla_auto_jit=0}" + +if ${IS_BLACKWELL} ; then + # TF 2.20 + Blackwell is often unstable with cuda_malloc_async. + unset TF_GPU_ALLOCATOR + [ -z "${XLA_FLAGS:-}" ] && export XLA_FLAGS="--xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found" + echo "ℹ️ Blackwell detected (compute capability ${GPU_COMPUTE_CAPABILITY}): using compatibility GPU defaults." +else + export TF_GPU_ALLOCATOR="${TF_GPU_ALLOCATOR:-cuda_malloc_async}" + unset XLA_FLAGS +fi export NVIDIA_TF32_OVERRIDE=1 export TF_CUDNN_WORKSPACE_LIMIT_IN_MB=512 @@ -141,4 +163,4 @@ print_elapsed_time --no-separators "${POST_GEN_TS}" "${POST_AUGMENT_TS}" "Augmen print_elapsed_time --no-separators "${POST_AUGMENT_TS}" "${END_TS}" "${TRAINING_STEPS} training steps" python -c $'msg="="*54 ; print(f"{msg:>80s}")' print_elapsed_time --no-separators "${START_TS}" "${END_TS}" "Total" -python -c $'print(f"{\'=\' * 80}")' \ No newline at end of file +python -c $'print(f"{\'=\' * 80}")'