Document LSTMCell carry as the (c, h) tuple it actually is#5469

Open
adityasingh2400 wants to merge 1 commit into google:main from adityasingh2400:doc-lstmcell-carry-tuple-4124

@adityasingh2400

Summary

The LSTMCell.__call__ docstring described carry only as "the hidden state of the LSTM cell", which is misleading: carry is actually a tuple (c, h) of the cell state and the hidden state, as initialize_carry and the implementation both make clear (return (c, h) and c, h = carry).

This PR spells out the contract in the docstring:

carry: a tuple (c, h) of the cell state c and the hidden state h, both of shape (*batch, features). Typically initialized using LSTMCell.initialize_carry.
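To make the contract concrete, here is a minimal pure-Python sketch of how a `(c, h)` carry is unpacked, updated, and threaded through a sequence. It is a toy stand-in, not Flax's implementation: `toy_lstm_step` and `toy_initialize_carry` are hypothetical names, all weights are assumed to be 1 and biases 0, and scalars are used in place of arrays of shape `(*batch, features)`. Only the `c, h = carry` unpacking and the `return (new_c, new_h), new_h` shape of the result mirror what the PR describes.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def toy_initialize_carry():
    # Hypothetical helper: both states start at zero, returned as (c, h).
    return (0.0, 0.0)


def toy_lstm_step(carry, x):
    # The carry is a pair: cell state first, hidden state second.
    c, h = carry
    z = x + h  # assumed unit weights, zero biases (illustration only)
    i = sigmoid(z)    # input gate
    f = sigmoid(z)    # forget gate
    g = math.tanh(z)  # candidate cell value
    o = sigmoid(z)    # output gate
    new_c = f * c + i * g
    new_h = o * math.tanh(new_c)
    # Return the new carry plus the step output, which equals new_h.
    return (new_c, new_h), new_h


carry = toy_initialize_carry()
outputs = []
for x in [0.5, -1.0, 2.0]:
    carry, y = toy_lstm_step(carry, x)
    outputs.append(y)

c, h = carry  # final cell state and final hidden state
```

The point of the sketch is that code treating `carry` as a single hidden-state value would fail at the unpacking step, which is the confusion the updated docstring is meant to prevent.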

Applied to all four affected __call__ docstrings for parity:

  • flax/linen/recurrent.py LSTMCell
  • flax/linen/recurrent.py OptimizedLSTMCell
  • flax/nnx/nn/recurrent.py LSTMCell
  • flax/nnx/nn/recurrent.py OptimizedLSTMCell

The OptimizedLSTMCell docstrings previously pointed at LSTMCell.initialize_carry; those references are also updated to OptimizedLSTMCell.initialize_carry to match the class users actually have a handle to.

Fixes #4124.

Test plan

  • Pure docstring change, no runtime behavior modified.
  • No public API changes.
  • Wording matches the implementation (c, h = carry, return (new_c, new_h), new_h) and initialize_carry shapes.

@google-cla

google-cla Bot commented May 21, 2026

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

The LSTMCell.__call__ docstring described 'carry' as 'the hidden state
of the LSTM cell', leaving users to infer that it is actually a
tuple of cell state and hidden state both of shape (*batch, features),
typically created via LSTMCell.initialize_carry. Spell that contract
out in both the Linen and NNX twins, and mirror the same docstring
for OptimizedLSTMCell.

Fixes google#4124
@adityasingh2400 force-pushed the doc-lstmcell-carry-tuple-4124 branch from b63a4cd to 386bf30 on May 22, 2026 00:05
Development

Successfully merging this pull request may close these issues.

Clarification for LSTMCell Documentation

1 participant