mirror of
https://github.com/TaterTotterson/microWakeWord-Trainer-Nvidia-Docker.git
synced 2026-06-12 20:10:19 -06:00
blackwell/wham & chim datasets
This commit is contained in:
@@ -74,14 +74,36 @@ START_TS=$EPOCHSECONDS
|
||||
# -----------------------------------------------------------------------------
|
||||
# TensorFlow / XLA environment (known-good, portable)
|
||||
# -----------------------------------------------------------------------------
|
||||
export TF_CPP_MIN_LOG_LEVEL=9
|
||||
export TF_FORCE_GPU_ALLOW_GROWTH=true
|
||||
export TF_GPU_ALLOCATOR=cuda_malloc_async
|
||||
detect_gpu_compute_capability() {
|
||||
if command -v nvidia-smi >/dev/null 2>&1 ; then
|
||||
nvidia-smi --query-gpu=compute_cap --format=csv,noheader 2>/dev/null \
|
||||
| head -n 1 \
|
||||
| tr -d '[:space:]'
|
||||
fi
|
||||
}
|
||||
|
||||
GPU_COMPUTE_CAPABILITY="$(detect_gpu_compute_capability)"
|
||||
IS_BLACKWELL=false
|
||||
case "${GPU_COMPUTE_CAPABILITY}" in
|
||||
12.*) IS_BLACKWELL=true ;;
|
||||
esac
|
||||
|
||||
export TF_CPP_MIN_LOG_LEVEL="${TF_CPP_MIN_LOG_LEVEL:-9}"
|
||||
export TF_FORCE_GPU_ALLOW_GROWTH="${TF_FORCE_GPU_ALLOW_GROWTH:-true}"
|
||||
|
||||
# Hard-set TF_XLA_FLAGS to ONLY what we know this build supports.
|
||||
# Do NOT append user environment flags (can cause hard failures).
|
||||
export TF_XLA_FLAGS="--tf_xla_auto_jit=0"
|
||||
unset XLA_FLAGS
|
||||
export TF_XLA_FLAGS="${TF_XLA_FLAGS:---tf_xla_auto_jit=0}"
|
||||
|
||||
if ${IS_BLACKWELL} ; then
|
||||
# TF 2.20 + Blackwell is often unstable with cuda_malloc_async.
|
||||
unset TF_GPU_ALLOCATOR
|
||||
[ -z "${XLA_FLAGS:-}" ] && export XLA_FLAGS="--xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found"
|
||||
echo "ℹ️ Blackwell detected (compute capability ${GPU_COMPUTE_CAPABILITY}): using compatibility GPU defaults."
|
||||
else
|
||||
export TF_GPU_ALLOCATOR="${TF_GPU_ALLOCATOR:-cuda_malloc_async}"
|
||||
unset XLA_FLAGS
|
||||
fi
|
||||
|
||||
export NVIDIA_TF32_OVERRIDE=1
|
||||
export TF_CUDNN_WORKSPACE_LIMIT_IN_MB=512
|
||||
@@ -141,4 +163,4 @@ print_elapsed_time --no-separators "${POST_GEN_TS}" "${POST_AUGMENT_TS}" "Augmen
|
||||
print_elapsed_time --no-separators "${POST_AUGMENT_TS}" "${END_TS}" "${TRAINING_STEPS} training steps"
|
||||
python -c $'msg="="*54 ; print(f"{msg:>80s}")'
|
||||
print_elapsed_time --no-separators "${START_TS}" "${END_TS}" "Total"
|
||||
python -c $'print(f"{\'=\' * 80}")'
|
||||
python -c $'print(f"{\'=\' * 80}")'
|
||||
|
||||
Reference in New Issue
Block a user