diff --git a/cli/wake_word_sample_trainer b/cli/wake_word_sample_trainer index a9e55dc..e32be53 100644 --- a/cli/wake_word_sample_trainer +++ b/cli/wake_word_sample_trainer @@ -51,6 +51,24 @@ fi # shellcheck source=/dev/null source "${DATA_DIR}/.venv/bin/activate" +# --- WSL2 GPU visibility fix (venv sometimes doesn't inherit WSL driver path) --- +# Keep a copy so we can restore/preserve on fallback if desired. +ORIG_XLA_FLAGS="${XLA_FLAGS:-}" + +if [ -d /usr/lib/wsl/lib ]; then + export LD_LIBRARY_PATH="/usr/lib/wsl/lib:${LD_LIBRARY_PATH:-}" + echo "ℹ️ WSL2 detected: LD_LIBRARY_PATH+=/usr/lib/wsl/lib" + + # Blackwell / PTXAS workaround: only apply on WSL *and* only if user didn't set XLA_FLAGS + if [ -z "${XLA_FLAGS:-}" ]; then + export XLA_FLAGS="--xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found" + echo "ℹ️ WSL2: setting XLA_FLAGS=${XLA_FLAGS}" + else + echo "ℹ️ Using user-provided XLA_FLAGS=${XLA_FLAGS}" + fi +fi +# ----------------------------------------------------------------------------- + check_directories() { for d in "$@" ; do [ -d "$d" ] || { echo "ERROR: Directory $d not found" >&2 ; exit 1 ; } @@ -162,7 +180,6 @@ if [ "${HAS_PERSONAL}" = "true" ]; then type: mmap EOF )" - perl -0777 -i -pe "s#__PERSONAL_FEATURE_MARKER__#${personal_block}#g" "${YAML_PATH}" else sed -i -e "/__PERSONAL_FEATURE_MARKER__/d" "${YAML_PATH}" @@ -241,7 +258,6 @@ run_attempt() { # --------- ENV (keep compatible; DO NOT add unsupported XLA flags) ---------- export TF_CPP_MIN_LOG_LEVEL="${TF_CPP_MIN_LOG_LEVEL:-2}" export TF_XLA_FLAGS="${TF_XLA_FLAGS:---tf_xla_auto_jit=0}" -unset XLA_FLAGS export NVIDIA_TF32_OVERRIDE="${NVIDIA_TF32_OVERRIDE:-1}" export TF_FORCE_GPU_ALLOW_GROWTH="${TF_FORCE_GPU_ALLOW_GROWTH:-true}" @@ -269,7 +285,13 @@ else # CPU attempt should not inherit GPU/XLA runtime knobs unset TF_XLA_FLAGS - unset XLA_FLAGS + + # Optional: clear XLA_FLAGS for CPU (usually irrelevant). If user had set it, restore. + if [ -n "${ORIG_XLA_FLAGS}" ]; then + export XLA_FLAGS="${ORIG_XLA_FLAGS}" + else + unset XLA_FLAGS + fi if run_attempt "Attempt 2/2: CPU fallback (CUDA_VISIBLE_DEVICES='')" ; then echo "✅ Training complete (CPU fallback)."