blackwell/wham & chim datasets

This commit is contained in:
MasterPhooey
2026-03-09 19:48:35 -05:00
parent 4c4750a7bd
commit 94903783cb
7 changed files with 517 additions and 42 deletions

View File

@@ -74,14 +74,36 @@ START_TS=$EPOCHSECONDS
# -----------------------------------------------------------------------------
# TensorFlow / XLA environment (known-good, portable)
# -----------------------------------------------------------------------------
# Baseline TensorFlow settings, assigned unconditionally (any value inherited
# from the caller's environment is overwritten here).
# NOTE(review): the same three variables are exported again further down using
# "${VAR:-default}" fallbacks. After these unconditional assignments those
# fallbacks can never apply -- confirm whether these three lines are leftovers
# that should be removed.
export TF_CPP_MIN_LOG_LEVEL=9
export TF_FORCE_GPU_ALLOW_GROWTH=true
export TF_GPU_ALLOCATOR=cuda_malloc_async
#######################################
# Report the compute capability of the first GPU listed by nvidia-smi.
# Arguments: none
# Outputs:   the first line of nvidia-smi's compute_cap query with every
#            whitespace character removed (no trailing newline); nothing at
#            all when nvidia-smi is not available as a command.
# Returns:   0 when nvidia-smi is unavailable, otherwise the status of the
#            query pipeline.
#######################################
detect_gpu_compute_capability() {
  # Nothing to query without nvidia-smi: print nothing and succeed.
  command -v nvidia-smi >/dev/null 2>&1 || return 0

  # nvidia-smi's own diagnostics are discarded; only the first row is kept.
  nvidia-smi --query-gpu=compute_cap --format=csv,noheader 2>/dev/null |
    head -n 1 |
    tr -d '[:space:]'
}
#######################################
# Decide whether a compute capability belongs to a Blackwell-generation GPU.
# Blackwell parts do not share one major version:
#   10.x  data-centre parts (B200 / GB200, B300 / GB300)
#   11.x  Jetson Thor
#   12.x  GeForce RTX 50 series, RTX PRO Blackwell, GB10
# Arguments: $1 - compute capability as "major.minor" (may be empty)
# Returns:   0 when $1 is a Blackwell capability, 1 otherwise (including
#            an empty or missing argument).
#######################################
is_blackwell_compute_capability() {
  case "${1:-}" in
    10.* | 11.* | 12.*) return 0 ;;
    *) return 1 ;;
  esac
}

# Capability string printed by detect_gpu_compute_capability (defined
# elsewhere in this file); empty when that function prints nothing.
GPU_COMPUTE_CAPABILITY="$(detect_gpu_compute_capability)"

# IS_BLACKWELL holds the literal word "true" or "false" so later code can run
# it directly as a command in an `if`.
IS_BLACKWELL=false
if is_blackwell_compute_capability "${GPU_COMPUTE_CAPABILITY}" ; then
  IS_BLACKWELL=true
fi
# Keep a log level / allow-growth value that is already set (and non-empty)
# in the environment; otherwise fall back to 9 / true.
export TF_CPP_MIN_LOG_LEVEL="${TF_CPP_MIN_LOG_LEVEL:-9}"
export TF_FORCE_GPU_ALLOW_GROWTH="${TF_FORCE_GPU_ALLOW_GROWTH:-true}"
# Hard-set TF_XLA_FLAGS to ONLY what we know this build supports.
# Do NOT append user environment flags (can cause hard failures).
export TF_XLA_FLAGS="--tf_xla_auto_jit=0"
unset XLA_FLAGS
# NOTE(review): TF_XLA_FLAGS was assigned unconditionally two lines up, so the
# ":-" fallback below can never apply; and because XLA_FLAGS was just unset,
# the -z test in the Blackwell branch always succeeds. The hard-set policy
# described above conflicts with the "keep the caller's value" form below --
# confirm which one is intended and drop the other.
export TF_XLA_FLAGS="${TF_XLA_FLAGS:---tf_xla_auto_jit=0}"
# IS_BLACKWELL holds the word "true" or "false", which is run as a command.
if ${IS_BLACKWELL} ; then
# TF 2.20 + Blackwell is often unstable with cuda_malloc_async.
unset TF_GPU_ALLOCATOR
# Supply the ptxas-fallback flag only when XLA_FLAGS is empty or unset.
[ -z "${XLA_FLAGS:-}" ] && export XLA_FLAGS="--xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found"
echo " Blackwell detected (compute capability ${GPU_COMPUTE_CAPABILITY}): using compatibility GPU defaults."
else
# Any other GPU (or none detected): default the allocator to
# cuda_malloc_async unless one is already set, and clear XLA_FLAGS.
export TF_GPU_ALLOCATOR="${TF_GPU_ALLOCATOR:-cuda_malloc_async}"
unset XLA_FLAGS
fi
# Assigned unconditionally; any inherited value is overwritten.
export NVIDIA_TF32_OVERRIDE=1
export TF_CUDNN_WORKSPACE_LIMIT_IN_MB=512
@@ -141,4 +163,4 @@ print_elapsed_time --no-separators "${POST_GEN_TS}" "${POST_AUGMENT_TS}" "Augmen
# Timing summary. print_elapsed_time is not defined in the visible part of
# this file; it is called with a start timestamp, an end timestamp and a label.
print_elapsed_time --no-separators "${POST_AUGMENT_TS}" "${END_TS}" "${TRAINING_STEPS} training steps"
# Partial rule: 54 '=' characters right-aligned in an 80-column field.
python -c $'msg="="*54 ; print(f"{msg:>80s}")'
print_elapsed_time --no-separators "${START_TS}" "${END_TS}" "Total"
# Closing rule: 80 '=' characters.
# NOTE(review): the same command appears twice in a row, so the rule is
# printed twice -- confirm that this is intended.
python -c $'print(f"{\'=\' * 80}")'
python -c $'print(f"{\'=\' * 80}")'