Skip to content

Add layer by layer hidden state testing support to forward_pass_logit_checker.py#4173

Draft
snehalv2002 wants to merge 1 commit into
mainfrom
snehalv-dsv4-layer-analysis
Draft

Add layer by layer hidden state testing support to forward_pass_logit_checker.py#4173
snehalv2002 wants to merge 1 commit into
mainfrom
snehalv-dsv4-layer-analysis

Conversation

@snehalv2002

Copy link
Copy Markdown
Collaborator

Description

forward_pass_logit_checker.py now provides the option to compute the layer by layer diff for a model. It's quite common to have to debug issues in new model bringup by inspecting the layer by layer outputs so this change can save developers from writing individual implementations of the same thing. We use Flax Linen nn.Module.sow() which offers a native way to save intermediate tensors and is already in use around maxtext for saving things like moe load balancing losses. Currently we add support for the DeepSeek V2 16B model.

How to Use the Verification Tool

Step 1: Generate Golden Reference Hidden States

Use generate_hf_golden_logits.py on a machine with PyTorch/Hugging Face installed to dump reference hidden states for the target prompts:

python3 -m tests.assets.logits_generation.generate_hf_golden_logits \
    --model-id=<hf_model_id_or_path> \
    --output-path=<output_jsonl_path> \
    --prompts="I love to;Today is a;What is the" \
    --record-layerwise-hidden-states # new flag \
    --trust-remote-code=True

Step 2: Run MaxText Verification Checker

Run the verifier pointing to the generated golden dataset. Add the --compare_layerwise_hidden_states flag to enable
layerwise assertion:

python3 -m tests.utils.forward_pass_logit_checker \
    src/maxtext/configs/base.yml \
    base_output_directory=<output_dir> \
    load_parameters_path=<maxtext_model_checkpoint_path> \
    model_name=<model_name> \
    --golden_logits_path=<output_jsonl_path> \
    --atol=1e-4 \
    --rtol=1e-4 \
    --compare_layerwise_hidden_states  # new flag

How to Add Support for a New Model

1. Prerequisites in MaxText

Ensure the JAX model implementation sows its raw block outputs. In your model class (e.g. deepseek.py or your new model's block layers), add the following sowing logic right before returning the layer output:

if self.config.record_layerwise_hidden_states:
  self.sow(nnx.Intermediate, "layer_output", layer_output)

Additionally, update get_layer_index in tests/utils/forward_pass_logit_checker.py to correctly map the sowed PyTree keys of your new model to linear layer indices.

2. Reference Hidden States Generation

The generator script generate_hf_golden_logits.py uses PyTorch forward hooks to capture hidden states. Depending on how the model is supported by Hugging Face, follow the appropriate case:

Case A: Model is Natively Supported by HF Transformers (e.g. Llama, Mistral, Gemma)

Zero changes are required in the generator script.
• The script uses AutoModelForCausalLM to load the model.
• The hook registration logic dynamically resolves standard repository structures by looking for model.model.layers or
model.layers . Since natively supported models follow this structure, hooks will register automatically.
• Simply run the generator command (Step 1) passing the Hugging Face model ID.

Case B: Model is a Custom Remote Model (e.g. DeepSeek-V4, custom architectures)

If the model uses custom codebase files downloaded dynamically from the Hub (via trust_remote_code=True ), ensure the following:

  1. Layers Property Resolution: Verify if the custom model class exposes its decoder blocks in a list or ModuleList under model.layers or model.model.layers. If the custom model uses a non-standard variable name (e.g. model.transformer.blocks), update the hook registration block in generate_hf_golden_logits.py to add a fallback lookup:
layers = getattr(base_model, "layers", None)
if layers is None:
   layers = getattr(base_model, "custom_blocks_name", None) # Add fallback
  1. Layer Return Types: The hook intercepts the layer module output. Ensure that each layer module's forward returns either:
    • The raw activation tensor directly.
    • A tuple where the first element ( output[0] ) is the hidden state tensor (this is standard in Hugging Face to return
    attention weights alongside activations).
    • The script already automatically unpacks tuples.
  2. Loading Workaround (If applicable): If the custom model suffers from slow CPU memory initialization or bugs in transformers' automatic weight loading, you can wrap the loading section in generate_hf_golden_logits.py using from_config and manually map the state dict files (safetensors) from a locally cloned copy of the repository.

If the change fixes a bug or a Github issue, please include a link, e.g.,:
FIXES: b/123456
FIXES: #123456

You can also provide a comma-separated list. If you don't want to close a bug but
simply to reference it, use BUGS, e.g.:
BUGS: b/123456

Notice 1: Once all tests pass, the "pull ready" label will automatically be assigned.
This label is used for administrative purposes. Please do not add it manually.

Notice 2: For external contributions, our settings currently require an approval from a MaxText maintainer to trigger CI tests.

Tests

Please describe how you tested this change, and include any instructions and/or
commands to reproduce.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@codecov

codecov Bot commented Jun 16, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 0% with 2 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/maxtext/models/deepseek.py 0.00% 1 Missing and 1 partial ⚠️

📢 Thoughts on this report? Let us know!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant