Skip to content

Document None defaults for MultiHeadAttention args#5470

Open
adityasingh2400 wants to merge 1 commit into
google:mainfrom
adityasingh2400:doc-multihead-attention-none-defaults-4182
Open

Document None defaults for MultiHeadAttention args#5470
adityasingh2400 wants to merge 1 commit into
google:mainfrom
adityasingh2400:doc-multihead-attention-none-defaults-4182

Conversation

@adityasingh2400
Copy link
Copy Markdown

Summary

Several MultiHeadAttention attributes accept None, but the docstring did not say what happens when you pass it. This fills in the gaps for the four args called out in #4182:

  • qkv_features: defaults to inputs_q.shape[-1] (linen) / in_features (nnx).
  • out_features: defaults to inputs_q.shape[-1] (linen) / in_features (nnx).
  • deterministic: falls back to the deterministic argument passed to __call__.
  • precision: uses jax.lax.Precision.DEFAULT.

Applied symmetrically to flax/linen/attention.py (MultiHeadDotProductAttention and its MultiHeadAttention alias) and the NNX twin in flax/nnx/nn/attention.py.

Pure docstring change, no behavior change.

Fixes #4182

Test plan

  • python3 -c 'import ast; ast.parse(...)' on both files (parses clean).
  • Style matches surrounding entries in each docstring (linen uses sentence case, nnx uses lower case).

@google-cla
Copy link
Copy Markdown

google-cla Bot commented May 21, 2026

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

Several MultiHeadAttention args accept None, but the docstring did
not say what happens in that case. Spell it out for qkv_features,
out_features (both default to inputs_q.shape[-1]), deterministic
(falls back to the __call__ kwarg of the same name), and precision
(uses lax.Precision.DEFAULT). Apply the same expansion to the NNX
twin in flax/nnx/nn/attention.py.

Fixes google#4182
@adityasingh2400 adityasingh2400 force-pushed the doc-multihead-attention-none-defaults-4182 branch from 9cd83f8 to 0581162 Compare May 22, 2026 00:05
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

MultiHeadAttention documentation missing descriptions of None values for optional arguments

1 participant