This commit is contained in:
MasterPhooey
2026-01-21 06:14:32 -06:00
parent 423bbd15f5
commit 98a087c87d
2 changed files with 32 additions and 19 deletions

View File

@@ -178,24 +178,31 @@ echo " ===== Installing keras ====="
# keras 3.13 has "issues" so we need to back down to 3.12. # keras 3.13 has "issues" so we need to back down to 3.12.
pip_install "keras==3.12.0" pip_install "keras==3.12.0"
CUDA_DATA_DIR="${DATA_DIR}/cuda" # -----------------------------------------------------------------------------
LIBDEVICE_DIR="${CUDA_DATA_DIR}/nvvm/libdevice" # Optional CUDA data dir (GPU-only)
mkdir -p "${LIBDEVICE_DIR}" # Some stacks expect a CUDA "nvvm/libdevice" tree. We create one in /data/cuda
TRITON_LIBDEVICE="$( # and link Triton's libdevice if it exists. This is safe and does NOT enable
python - <<'PY' # any extra XLA flags by itself.
import glob # -----------------------------------------------------------------------------
import sys if ${GPU} ; then
CUDA_DATA_DIR="${DATA_DIR}/cuda"
LIBDEVICE_DIR="${CUDA_DATA_DIR}/nvvm/libdevice"
mkdir -p "${LIBDEVICE_DIR}"
TRITON_LIBDEVICE="$(
python - <<'PY'
import glob
paths = glob.glob("**/site-packages/triton/backends/nvidia/lib/libdevice.10.bc", recursive=True) paths = glob.glob("**/site-packages/triton/backends/nvidia/lib/libdevice.10.bc", recursive=True)
print(paths[0] if paths else "", end="") print(paths[0] if paths else "", end="")
PY PY
)" )"
if [ -n "${TRITON_LIBDEVICE}" ] ; then if [ -n "${TRITON_LIBDEVICE}" ] ; then
ln -sf "${TRITON_LIBDEVICE}" "${LIBDEVICE_DIR}/libdevice.10.bc" ln -sf "${TRITON_LIBDEVICE}" "${LIBDEVICE_DIR}/libdevice.10.bc"
echo " Linked Triton libdevice.10.bc to ${LIBDEVICE_DIR}" echo " Linked Triton libdevice.10.bc to ${LIBDEVICE_DIR}"
else else
echo " Triton libdevice.10.bc not found; XLA may require --xla_gpu_cuda_data_dir" echo " Triton libdevice.10.bc not found (ok)"
fi
fi fi
"${PROGDIR}/test_python" --data-dir="${DATA_DIR}" "${PROGDIR}/test_python" --data-dir="${DATA_DIR}"
@@ -205,4 +212,4 @@ END_TS=$EPOCHSECONDS
echo "Run 'source ${VENV}/bin/activate' to activate the new virtualenv in the current shell." echo "Run 'source ${VENV}/bin/activate' to activate the new virtualenv in the current shell."
print_elapsed_time "${START_TS}" "${END_TS}" "Python package installation complete" print_elapsed_time "${START_TS}" "${END_TS}" "Python package installation complete"

View File

@@ -67,17 +67,23 @@ echo "===== Running '${WAKE_WORD}(${WAKE_WORD_TITLE})' generation, augmentation
echo echo
START_TS=$EPOCHSECONDS START_TS=$EPOCHSECONDS
# -----------------------------------------------------------------------------
# TensorFlow / XLA environment (known-good, portable)
# -----------------------------------------------------------------------------
export TF_CPP_MIN_LOG_LEVEL=9 export TF_CPP_MIN_LOG_LEVEL=9
export TF_FORCE_GPU_ALLOW_GROWTH=true export TF_FORCE_GPU_ALLOW_GROWTH=true
export TF_GPU_ALLOCATOR=cuda_malloc_async export TF_GPU_ALLOCATOR=cuda_malloc_async
DEFAULT_XLA_FLAGS="--tf_xla_auto_jit=0 --xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found --xla_gpu_cuda_data_dir=${DATA_DIR}/cuda"
DEFAULT_XLA_RUNTIME_FLAGS="--xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found --xla_gpu_cuda_data_dir=${DATA_DIR}/cuda" # Hard-set TF_XLA_FLAGS to ONLY what we know this build supports.
export TF_XLA_FLAGS="${TF_XLA_FLAGS:+${TF_XLA_FLAGS} }${DEFAULT_XLA_FLAGS}" # Do NOT append user environment flags (can cause hard failures).
export XLA_FLAGS="${XLA_FLAGS:+${XLA_FLAGS} }${DEFAULT_XLA_RUNTIME_FLAGS}" export TF_XLA_FLAGS="--tf_xla_auto_jit=0"
unset XLA_FLAGS
export NVIDIA_TF32_OVERRIDE=1 export NVIDIA_TF32_OVERRIDE=1
export TF_CUDNN_WORKSPACE_LIMIT_IN_MB=512 export TF_CUDNN_WORKSPACE_LIMIT_IN_MB=512
export GLOG_minloglevel=2 export GLOG_minloglevel=2
export GRPC_VERBOSITY=ERROR export GRPC_VERBOSITY=ERROR
# -----------------------------------------------------------------------------
"${CLIDIR}/wake_word_sample_generator" \ "${CLIDIR}/wake_word_sample_generator" \
--samples=${SAMPLES} \ --samples=${SAMPLES} \
@@ -130,4 +136,4 @@ print_elapsed_time --no-separators "${POST_GEN_TS}" "${POST_AUGMENT_TS}" "Augmen
print_elapsed_time --no-separators "${POST_AUGMENT_TS}" "${END_TS}" "${TRAINING_STEPS} training steps" print_elapsed_time --no-separators "${POST_AUGMENT_TS}" "${END_TS}" "${TRAINING_STEPS} training steps"
python -c $'msg="="*54 ; print(f"{msg:>80s}")' python -c $'msg="="*54 ; print(f"{msg:>80s}")'
print_elapsed_time --no-separators "${START_TS}" "${END_TS}" "Total" print_elapsed_time --no-separators "${START_TS}" "${END_TS}" "Total"
python -c $'print(f"{\'=\' * 80}")' python -c $'print(f"{\'=\' * 80}")'