increased a bit tolerance for pytorch/distributed/run_numerics.py #3095

Open
francesco-bertolotti wants to merge 1 commit into NVIDIA:main from francesco-bertolotti:f14-increase-tolerance-pytorch-distributed

@francesco-bertolotti
Contributor

Description

tests/pytorch/distributed/test_numerics.py::test_distributed[None] (the unquantized fp32/TF32 configuration) fails on A100 in the TransformerLayer gradient check:

[rank0] NUMERICAL CHECK FAILED: 12 not close enough at index 352 with
-0.0003907721256837249 vs -0.0003680467198137194 |
rel. error = 0.061745981275169545 (tol = 0.001) |
abs. error = 2.272540587000549e-05 (tol = 1e-05)

Parameter index 12 is layernorm_mlp.fc1_weight; the failure is a single near-zero gradient element in _test_transformer_layer_parallel(sequence_parallel=False).

Reproduction (observed on 4x A100 / sm80, TF32 enabled):

torchrun --nproc_per_node=4 tests/pytorch/distributed/run_numerics.py

Fix

Raise the fp32 atol from 1e-5 to 5e-5 (about 2.2x the observed absolute error of 2.27e-5), keeping rtol = 1e-3 unchanged.

     if dtype == torch.float32:
-        # TF32 has same mantissa bits as FP16
-        return {"rtol": 1e-3, "atol": 1e-5}
+        # TF32 has same mantissa bits as FP16. The atol is looser than for FP16
+        # because near-zero gradient elements can differ by a few 1e-5 between
+        # the TP-sharded and single-device GEMM reduction orders (observed on A100).
+        return {"rtol": 1e-3, "atol": 5e-5}
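
To see why the element from the failure log misses the old bound and clears the new one, here is a minimal sketch. It assumes the check follows the combined rule documented for torch.isclose, |a - b| <= atol + rtol * |b|; the actual comparison helper in run_numerics.py may combine the two tolerances differently, and `is_close` is a hypothetical stand-in.

```python
# Values copied from the failure log in the description.
actual = -0.0003907721256837249
expected = -0.0003680467198137194

abs_err = abs(actual - expected)
rel_err = abs_err / abs(expected)


def is_close(a, b, rtol, atol):
    """Hypothetical stand-in using the torch.isclose rule: |a - b| <= atol + rtol * |b|."""
    return abs(a - b) <= atol + rtol * abs(b)


# Old fp32 tolerances versus the ones proposed in this PR.
old_ok = is_close(actual, expected, rtol=1e-3, atol=1e-5)
new_ok = is_close(actual, expected, rtol=1e-3, atol=5e-5)

print(f"abs_err={abs_err:.4e} rel_err={rel_err:.4e}")
print(f"old tolerances pass: {old_ok}, new tolerances pass: {new_ok}")
```

Because the reference value is close to zero, the rtol term contributes only about 3.7e-7 to the bound, so atol alone decides whether this element passes.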

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Raise the fp32 (TF32) absolute tolerance in run_numerics.py::_get_tolerances from 1e-5 to 5e-5, with a comment explaining the TP-sharded vs single-GPU reduction-order noise on near-zero gradient elements. rtol and the fp16/bf16/quantized tolerances are unchanged.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Francesco Bertolotti <francesco.bertolotti@igenius.ai>
@github-actions github-actions Bot added the community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. label Jun 5, 2026
@greptile-apps
Contributor

greptile-apps Bot commented Jun 5, 2026

Greptile Summary

This PR loosens the absolute tolerance for torch.float32 numerical comparisons in the distributed numerics test from 1e-5 to 5e-5, fixing a spurious test failure on A100 GPUs where TF32 GEMM reduction-order differences between TP-sharded and single-device paths produce near-zero gradient elements that fall just outside the old bound.

  • _get_tolerances in run_numerics.py: atol for float32 raised to 5e-5 (~2.2× the observed worst-case error of 2.27e-5); rtol and all other dtype tolerances (float16, bfloat16, quantized) are untouched.
  • The updated comment explains the physical reason for the looser bound (TP-sharded vs single-GPU GEMM reduction order on A100 / sm80).

Confidence Score: 5/5

Safe to merge — it is a one-line tolerance bump in a test helper with no impact on library code.

The change is narrowly scoped to a single constant in a test utility function. The new value of 5e-5 is calibrated to the observed worst-case error (2.27e-5) with ~2.2x headroom, and the comment clearly attributes the looser bound to TF32 reduction-order noise in TP-sharded GEMMs. No library code, model logic, or other test configurations are touched.

No files require special attention.

Important Files Changed

Filename Overview
tests/pytorch/distributed/run_numerics.py Raises fp32/TF32 absolute tolerance in _get_tolerances from 1e-5 to 5e-5 with an explanatory comment; rtol and all other dtype tolerances are unchanged.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["_get_tolerances(dtype)"] --> B{QUANTIZATION?}
    B -->|fp8_cs| C["rtol=0.4, atol=0.25"]
    B -->|nvfp4| D["rtol=0.125, atol=0.12"]
    B -->|other quantized| E["rtol=0.125, atol=0.0625"]
    B -->|None| F{dtype}
    F -->|float16| G["rtol=1e-3, atol=1e-5"]
    F -->|bfloat16| H["rtol=1.6e-2, atol=1e-5"]
    F -->|float32 / TF32| I["rtol=1e-3, atol=5e-5 (changed from 1e-5)"]
    F -->|other| J["raise ValueError"]
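
The branches in the flowchart can be sketched as a plain function. This is a reconstruction from the chart only, not the code in run_numerics.py: the real `_get_tolerances` takes a torch.dtype and reads a quantization setting from elsewhere, whereas the `get_tolerances` signature below (string dtype name, explicit `quantization` argument) is a hypothetical simplification so the example runs without PyTorch.

```python
def get_tolerances(dtype_name, quantization=None):
    """Return rtol/atol following the flowchart branches (hypothetical signature)."""
    # Quantized recipes are checked first and ignore the dtype.
    if quantization == "fp8_cs":
        return {"rtol": 0.4, "atol": 0.25}
    if quantization == "nvfp4":
        return {"rtol": 0.125, "atol": 0.12}
    if quantization is not None:
        return {"rtol": 0.125, "atol": 0.0625}

    # Unquantized path: tolerances depend on the dtype.
    if dtype_name == "float16":
        return {"rtol": 1e-3, "atol": 1e-5}
    if dtype_name == "bfloat16":
        return {"rtol": 1.6e-2, "atol": 1e-5}
    if dtype_name == "float32":
        # Value proposed in this PR (previously atol=1e-5).
        return {"rtol": 1e-3, "atol": 5e-5}
    raise ValueError(f"Unsupported dtype ({dtype_name})")


print(get_tolerances("float32"))
```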

Reviews (1): Last reviewed commit: "increased a bit tolerance for pytorch/di..."

Comment on lines +210 to +212
# TF32 has same mantissa bits as FP16. The atol is looser than for FP16
# because near-zero gradient elements can differ by a few 1e-5 between
# the TP-sharded and single-device GEMM reduction orders (observed on A100).
Member

@timmoon10 timmoon10 Jun 6, 2026

This is a bit disturbing. Even if TF32 has errors, shouldn't it be strictly better than FP16?

This makes me think there are other differences going on, like maybe the FP32 GEMM kernel is different between TP and non-TP, while it is consistent for FP16?

@francesco-bertolotti
Contributor Author

francesco-bertolotti commented Jun 6, 2026

Hi @timmoon10,

I did go down this rabbit hole. This is what I think is going on.

The failing test runs a sharded transformer layer (TP=4) against an unsharded one. If they match, the test passes; otherwise it fails. The test runs several configurations, and the only one that fails is the one with ReLU activation.

Next, I ran the sharded and the unsharded transformer separately with CUBLASLT_LOG_LEVEL=5 and found a small difference in the algorithm selection:

# the sharded one
[2026-06-06 15:31:42][cublasLt][321674][Trace][cublasLtMatmul] 
A=0X153DBE21A800 Adesc=[type=R_32F rows=8 cols=64 ld=8] 
B=0X153DBE3B6000  Bdesc=[type=R_32F rows=8 cols=1024 ld=8] 
C=0X153D47880000  Cdesc=[type=R_32F rows=64 cols=1024 ld=64] 
D=0X153D47880000 Ddesc=[type=R_32F rows=64 cols=1024 ld=64] 
computeDesc=[computeType=COMPUTE_32F_FAST_TF32 scaleType=R_32F transa=OP_T] 
algo=[algoId=0 tile=MATMUL_TILE_128x32 ctaSwizzling=1] 
workSpace=0X153D4E000000 workSpaceSizeInBytes=4194304  beta=0  outOfPlace=0 stream=0X0
# the unsharded one
[2026-06-06 15:30:06][cublasLt][321377][Trace][cublasLtMatmul] 
A=0X1530C6213000 Adesc=[type=R_32F rows=32 cols=64 ld=32] 
B=0X1530C6358A00 Bdesc=[type=R_32F rows=32 cols=1024 ld=32] 
C=0X1530C6378A00 Cdesc=[type=R_32F rows=64 cols=1024 ld=64] 
D=0X1530C6378A00 Ddesc=[type=R_32F rows=64 cols=1024 ld=64] 
computeDesc=[computeType=COMPUTE_32F_FAST_TF32 scaleType=R_32F transa=OP_T] 
algo=[algoId=21 tile=MATMUL_TILE_64x64 stages=MATMUL_STAGES_16x6] 
workSpace=0X153056000000 workSpaceSizeInBytes=4194304 beta=0 outOfPlace=0 stream=0X0
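
When comparing many such trace entries, the fields of interest can be pulled out with a small parser. This is only a sketch: the `TRACE` string is an excerpt of the unsharded entry above, and `parse_algo` is a hypothetical helper, not part of cuBLASLt or the test suite.

```python
import re

# Excerpt of the unsharded trace entry (computeDesc and algo parts only).
TRACE = (
    "computeDesc=[computeType=COMPUTE_32F_FAST_TF32 scaleType=R_32F transa=OP_T] "
    "algo=[algoId=21 tile=MATMUL_TILE_64x64 stages=MATMUL_STAGES_16x6]"
)


def parse_algo(trace):
    """Extract computeType, algoId and tile from one cublasLtMatmul trace entry."""
    fields = {}
    for key in ("computeType", "algoId", "tile"):
        match = re.search(rf"{key}=(\w+)", trace)
        if match:
            fields[key] = match.group(1)
    return fields


print(parse_algo(TRACE))
```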

In particular, the sharded transformer selects algoId=0 while the unsharded one selects algoId=21. I looked up the numerical implementation flags of both with cublasLtMatmulAlgoCapGetAttribute: algoId=0 reports 0x80201, while algoId=21 reports 0x40202.

In cublasLt.h, the relevant flag bits are defined as:

#define CUBLASLT_NUMERICAL_IMPL_FLAGS_FMA             (0x01ull << 0)   // = 0x00001
#define CUBLASLT_NUMERICAL_IMPL_FLAGS_HMMA            (0x02ull << 0)   // = 0x00002
#define CUBLASLT_NUMERICAL_IMPL_FLAGS_ACCUMULATOR_32F (0x02ull << 8)   // = 0x00200
#define CUBLASLT_NUMERICAL_IMPL_FLAGS_INPUT_TF32      (0x04ull << 16)  // = 0x40000
#define CUBLASLT_NUMERICAL_IMPL_FLAGS_INPUT_32F       (0x08ull << 16)  // = 0x80000

which means:

| Algo | Input | Accumulator | Operation |
| --- | --- | --- | --- |
| algo 0 (0x80201) | INPUT_32F (0x80000) | ACCUMULATOR_32F (0x00200) | FMA (0x00001) |
| algo 21 (0x40202) | INPUT_TF32 (0x40000) | ACCUMULATOR_32F (0x00200) | HMMA (0x00002) |

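The raw masks are not exactly easy to read. However, I believe this means that algo 0 keeps full FP32 inputs and uses plain FMA instructions, while algo 21 rounds its inputs to TF32 (10 mantissa bits, the same as FP16) and runs on tensor cores. Both accumulate in FP32.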
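
The two masks can be decoded mechanically. The sketch below uses only the five bit values quoted from cublasLt.h above; the header defines more flags, so `decode` (a hypothetical helper) also returns any bits it does not recognize.

```python
# Bit values copied from the cublasLt.h defines quoted above.
FLAGS = {
    "FMA": 0x01 << 0,
    "HMMA": 0x02 << 0,
    "ACCUMULATOR_32F": 0x02 << 8,
    "INPUT_TF32": 0x04 << 16,
    "INPUT_32F": 0x08 << 16,
}


def decode(mask):
    """Return the names of the known flags set in mask, plus any unrecognized bits."""
    names = [name for name, bit in FLAGS.items() if mask & bit]
    leftover = mask & ~sum(FLAGS.values())
    return names, leftover


for algo_id, mask in ((0, 0x80201), (21, 0x40202)):
    names, leftover = decode(mask)
    print(f"algoId={algo_id}: {names} (unrecognized bits: {leftover:#x})")
```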

For further confirmation, I set the feedforward size in the transformer layer to 256. With that size, the sharded case also selects algo 21:

[2026-06-06 14:52:43][cublasLt][317853][Trace][cublasLtMatmul] 
A=0X151072213000 Adesc=[type=R_32F rows=32 cols=64 ld=32] 
B=0X15107239AC00 Bdesc=[type=R_32F rows=32 cols=1024 ld=32] 
C=0X151007600000 Cdesc=[type=R_32F rows=64 cols=1024 ld=64] 
D=0X151007600000 Ddesc=[type=R_32F rows=64 cols=1024 ld=64] 
computeDesc=[computeType=COMPUTE_32F_FAST_TF32 scaleType=R_32F transa=OP_T] 
algo=[algoId=21 tile=MATMUL_TILE_64x64 stages=MATMUL_STAGES_16x6] 
workSpace=0X151006000000 workSpaceSizeInBytes=4194304 beta=0 outOfPlace=0 stream=0X0

which ultimately makes the test pass.


Given all of this, I think it is fair to increase the tolerance a bit to make this test pass. Alternatively, we can make it pass by increasing the feedforward size so that cuBLAS selects the same algoId in both cases. Let me know which you prefer.
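
As a rough illustration of how much an FP32-input kernel and a TF32-input kernel can disagree on the same data, the sketch below emulates TF32 input rounding in pure Python. It rests on several simplifying assumptions: `to_tf32` truncates the mantissa instead of using the hardware rounding mode, accumulation happens in Python doubles rather than FP32, and the reduction length `k` and the random inputs are made up rather than taken from the test.

```python
import random
import struct


def to_f32(x):
    """Round a Python float to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_tf32(x):
    """Keep float32 sign, exponent and top 10 mantissa bits (truncation, a simplification)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits &= 0xFFFFE000  # clear the low 13 of the 23 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


rng = random.Random(0)
k = 32  # hypothetical reduction length
a = [to_f32(rng.uniform(-1.0, 1.0)) for _ in range(k)]
b = [to_f32(rng.uniform(-1.0, 1.0)) for _ in range(k)]

# Same dot product with full FP32 inputs and with TF32-rounded inputs.
full = sum(x * y for x, y in zip(a, b))
tf32 = sum(to_tf32(x) * to_tf32(y) for x, y in zip(a, b))

print(f"fp32 inputs: {full:.8f}")
print(f"tf32 inputs: {tf32:.8f}")
print(f"abs diff:    {abs(full - tf32):.3e}")
```

Each input loses up to roughly one part in 2^10 of its magnitude, so when the products largely cancel and the result is near zero, the absolute difference can be large relative to the result itself.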
