[Bug Report] [macOS-arm64] Cached eager attention NaNs in transformers v5 — blocks bridge KV-cache generation #1322

@jlarson4

Summary

TransformerBridge.generate() produces all-NaN logits at the cached step on macOS-arm64 + transformers v5 + PyTorch 2.7.1, even though the bridge already forces eager attention, the cache contents are finite, and the model is on CPU. The same forward computation without past_key_values returns correct logits. The bug appears to live in HuggingFace transformers' cached eager-attention path (and/or PyTorch's CPU matmul on Apple Silicon), not in transformer_lens.

Two tests are skipped on macOS-arm64 to unblock the MPS CI job; they remain active on Linux:

  • tests/unit/model_bridge/test_bridge_generate_no_tokenizer.py::test_generate_without_tokenizer_stop_at_eos_false_kv_cache
  • tests/unit/model_bridge/test_bridge_generate_no_tokenizer.py::test_generate_return_type_str_without_tokenizer_errors

Reproduction (CI only; not reproducible locally on an M4 Mac, untested on an M1 Mac)

CI matrix where it fires: macos-latest (GitHub Actions managed Apple Silicon) + Python 3.11.9 + PyTorch 2.7.1 + transformers 5.8.1.

import torch
from transformer_lens.model_bridge import TransformerBridge

bridge = TransformerBridge.boot_transformers("distilgpt2", device="cpu")
bridge.tokenizer = None
tokens = torch.tensor([[15496, 11, 314, 1101, 257]], dtype=torch.long)

# NaNs on macOS-arm64 CI; finite locally and on Linux CI:
bridge.generate(
    tokens,
    max_new_tokens=3,
    stop_at_eos=False,
    use_past_kv_cache=True,
    return_type="tokens",
    verbose=False,
)

Direct HF reproduction (no bridge involved):

o0 = bridge.original_model(tokens, use_cache=True)
next_id = o0.logits[:, -1, :].argmax(-1, keepdim=True)

# NaN on macOS-arm64 CI, finite elsewhere:
o1 = bridge.original_model(
    next_id,
    past_key_values=o0.past_key_values,
    use_cache=True,
    attention_mask=torch.ones((1, 6), dtype=torch.long),
    position_ids=torch.tensor([[5]], dtype=torch.long),
)

# Always finite (same effective input):
bridge.original_model(torch.cat([tokens, next_id], dim=1))

Diagnostic evidence

From CI run 26312216246:

[DIAG] step0_logits: nan=False inf=False shape=(1, 5, 50257)            ← step 0 fine
[DIAG] cache_type=DynamicCache seq_len=5 layers=6
[DIAG] cache_layer_0: K_nan=False V_nan=False K_shape=(1, 12, 5, 64)    ← cache contents fine
[DIAG] step1_with_mask_and_pos: nan=True inf=False shape=(1, 1, 50257)  ← fails
[DIAG] step1_full_no_cache: nan=False inf=False shape=(1, 6, 50257)     ← same comp w/o cache: fine

So: the cache is valid, the kwargs are correct, and the full-recompute path works. Only the combination of past_key_values with a 1-token input produces NaNs on the runner.
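For anyone re-running these checks, lines in the same `[DIAG]` format can be produced with a small helper along the following lines. This is a minimal sketch: `diag` is a hypothetical name, not a function in transformer_lens or in the CI run, and it only reports whether any NaN or Inf is present, not where.

```python
import torch


def diag(label: str, t: torch.Tensor) -> str:
    """Format one [DIAG]-style line for a tensor, print it, and return it."""
    has_nan = bool(torch.isnan(t).any())
    has_inf = bool(torch.isinf(t).any())
    line = f"[DIAG] {label}: nan={has_nan} inf={has_inf} shape={tuple(t.shape)}"
    print(line)
    return line
```

Applied to the HF reproduction above, this would be called as `diag("step0_logits", o0.logits)` and `diag("step1_with_mask_and_pos", o1.logits)`.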

Environment

Component         Value
Runner            macos-latest (Apple Silicon, GitHub Actions)
Python            3.11.9
PyTorch           2.7.1
transformers      5.8.1
transformer_lens  dev (commit da370b85 and later)
Model             distilgpt2 (GPT2LMHeadModel, 6 layers, 12 heads, d_head=64)
Attention impl    eager (forced by bridge at bridge.py:210-213)
Cache type        DynamicCache (transformers v5 default)
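Because the failure seems tied to the runner image, it may help to capture the same fields programmatically on both a passing machine and the failing runner and diff the results. The sketch below uses only the standard library. `environment_fingerprint` is a hypothetical helper, and the distribution names passed to it are assumptions; any package that cannot be found is reported as `None` rather than raising.

```python
import platform
from importlib import metadata


def environment_fingerprint(packages=("torch", "transformers", "transformer_lens")):
    """Collect OS, CPU architecture, Python, and package versions into a dict."""
    info = {
        "system": platform.system(),
        "machine": platform.machine(),
        # mac_ver() returns an empty release string on non-macOS hosts.
        "macos_version": platform.mac_ver()[0] or None,
        "python": platform.python_version(),
    }
    for name in packages:
        try:
            info[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            info[name] = None  # not installed under this distribution name
    return info
```

Printing `environment_fingerprint()` as a CI step on each runner would give a record to compare against a local machine where the bug does not reproduce.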

Investigation already done

  1. Confirmed the bridge wrapper and raw bridge.original_model(...) produce bitwise-identical step-0 logits.
  2. Confirmed bridge forces attn_implementation="eager" at load via sources/transformers.py:586.
  3. Confirmed not missing-kwargs: passing explicit attention_mask=(batch, total_len) and position_ids=(batch, 1) (mirroring HF.generate's contract) does not help.
  4. Confirmed not cache corruption: K and V cached after step 0 are NaN-free.
  5. Confirmed Linux is unaffected: all Linux CI configs (Python 3.10 / 3.11 / 3.12) pass.
  6. Not reproducible on my M4 Mac (Tahoe-build, both Python 3.11 and 3.12). Strongly suggests runner-image specifics, possibly Apple Silicon CPU generation, macOS version, or Accelerate framework version.
  7. Related upstream issues that share the SDPA-NaN pattern but aren't this exact bug (we're already on eager):

Current workaround

The two failing tests are skipped on macOS-arm64 via:

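import platform

import pytest

# True only on Apple Silicon macOS, where the cached step returns NaN logits in CI.
_MACOS_ARM64 = platform.system() == "Darwin" and platform.machine() == "arm64"

# Decorator applied to each of the two tests listed in the Summary.
@pytest.mark.skipif(_MACOS_ARM64, reason="Upstream macOS-arm64 KV-cache NaN; see #<this-issue>.")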

Linux coverage of the same bridge code path is unaffected.

Action items

  • Reproduce on a known-bad runner image (or pin the exact GitHub Actions image and confirm)
  • Determine whether the bug is in transformers' eager_attention_forward cached path or in PyTorch CPU matmul on Apple Silicon
  • File upstream (transformers and/or pytorch) once narrowed to a minimal repro
  • Remove the skipif markers once the upstream bug is fixed
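As a possible starting point for the second action item, a single cached attention step can be replayed with plain tensor ops at the same shapes, bypassing transformers entirely. This is a rough sketch under stated assumptions, not the transformers implementation: `probe_cached_attention` is a hypothetical name, the tensors are random rather than real distilgpt2 activations, and no attention mask is applied. It concatenates a 5-position "cache" with one new key/value and runs the matmul, softmax, matmul sequence for a 1-token query.

```python
import math

import torch


def probe_cached_attention(n_heads=12, past_len=5, d_head=64, seed=0):
    """Run one 1-token attention step against cached K/V and report NaNs per stage."""
    gen = torch.Generator().manual_seed(seed)
    # Stand-ins for the cached keys/values and the new token's projections.
    past_k = torch.randn(1, n_heads, past_len, d_head, generator=gen)
    past_v = torch.randn(1, n_heads, past_len, d_head, generator=gen)
    new_k = torch.randn(1, n_heads, 1, d_head, generator=gen)
    new_v = torch.randn(1, n_heads, 1, d_head, generator=gen)
    q = torch.randn(1, n_heads, 1, d_head, generator=gen)

    # Append the new position to the cache along the sequence axis.
    k = torch.cat([past_k, new_k], dim=2)
    v = torch.cat([past_v, new_v], dim=2)

    # Scaled dot-product attention written out as explicit ops.
    scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(d_head)
    probs = torch.softmax(scores, dim=-1)
    out = torch.matmul(probs, v)

    return {
        "scores_nan": bool(torch.isnan(scores).any()),
        "probs_nan": bool(torch.isnan(probs).any()),
        "out_nan": bool(torch.isnan(out).any()),
        "out_shape": tuple(out.shape),
    }
```

If this reports a NaN on the failing runner, the first stage that goes bad would point below transformers. A clean result would not rule out PyTorch, since the real activations and memory layout may differ from this random-input probe.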
