Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/maxtext/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,7 @@ fused_qkv: false
fused_mlp: false

record_internal_nn_metrics: 0
record_layerwise_hidden_states: false

# Output directory
# Create a GCS bucket, e.g. my-maxtext-outputs and set this to "gs://my-maxtext-outputs/"
Expand Down
1 change: 1 addition & 0 deletions src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1773,6 +1773,7 @@ class Metrics(BaseModel):
gcs_metrics: bool = Field(False, description="If True, save metrics to GCS.")
save_config_to_gcs: bool = Field(False, description="If True, save config to GCS.")
record_internal_nn_metrics: int = Field(0, description="Record internal neural network metrics.")
record_layerwise_hidden_states: bool = Field(False, description="Record layer-by-layer hidden states.")
prometheus_port: int = Field(0, description="Port for Prometheus metrics server. 0 disables it.")
enable_checkpoint_cloud_logger: bool = Field(False, description="Enables structured logging for checkpointing.")
enable_tunix_perf_metrics: bool = Field(
Expand Down
3 changes: 3 additions & 0 deletions src/maxtext/models/deepseek.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,9 @@ def post_process(self, layer_output, load_balance_loss, moe_bias_updates, kv_cac
jnp.sum(layer_output == 0) / jnp.size(layer_output),
)

if self.config.record_layerwise_hidden_states:
self.sow(nnx.Intermediate, "layer_output", layer_output)

if self.config.scan_layers:
return layer_output, None
return layer_output, kv_cache
Expand Down
35 changes: 35 additions & 0 deletions tests/assets/logits_generation/generate_hf_golden_logits.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def save_golden_logits(
trust_remote_code,
image_paths,
output_format,
record_layerwise_hidden_states=False,
):
"""save golden logits"""
if hf_model_path is None:
Expand Down Expand Up @@ -142,18 +143,44 @@ def save_golden_logits(
input_ids = tokenizer.encode(prompt_text, return_tensors="pt")
inputs = {"input_ids": input_ids}

captured_hidden_states = []
def make_hook_fn(layer_idx):
def hook_fn(module, inputs, output):
tensor = output[0] if isinstance(output, tuple) else output
captured_hidden_states.append((layer_idx, tensor.detach().to(torch.float32).cpu()))
return hook_fn

hooks = []
if record_layerwise_hidden_states:
base_model = getattr(model, "model", model)
layers = getattr(base_model, "layers", None)
if layers is None:
raise ValueError(f"Could not find layers in model structure: {model}")
for idx, layer in enumerate(layers):
hooks.append(layer.register_forward_hook(make_hook_fn(idx)))

# 2. Run inference
with torch.no_grad():
outputs = model(**inputs)
logits = outputs.logits.cpu().to(torch.float32).numpy()

# Remove hooks
for hook in hooks:
hook.remove()

# 3. Populate final data dictionary with tensors from inputs and logits
for key, value in inputs.items():
new_key = "tokens" if key == "input_ids" else key
val_np = value.cpu().numpy()
data_to_save[new_key] = val_np[0] if val_np.ndim > 0 else val_np
data_to_save["logits"] = logits[0]

if record_layerwise_hidden_states:
# Sort by layer index to ensure sequential order
captured_hidden_states.sort(key=lambda x: x[0])
layer_hidden_states_list = [x[1].numpy()[0] for x in captured_hidden_states]
data_to_save["layer_hidden_states"] = layer_hidden_states_list

print(f"Token length is {len(data_to_save['tokens'])} for prompt: {prompt_text}")
print(f"raw ids: {data_to_save['tokens']}")

Expand All @@ -162,6 +189,8 @@ def save_golden_logits(
for key, value in data_to_save.items():
if isinstance(value, np.ndarray):
data_to_save[key] = value.tolist()
elif key == "layer_hidden_states" and isinstance(value, list):
data_to_save[key] = [x.tolist() if isinstance(x, np.ndarray) else x for x in value]

all_data_to_save.append(data_to_save)

Expand Down Expand Up @@ -220,6 +249,11 @@ def main(raw_args=None) -> None:
default="json",
help="The output format for the golden logits. (json, pickle)",
)
parser.add_argument(
"--record-layerwise-hidden-states",
action="store_true",
help="Record layer-by-layer intermediate hidden states.",
)
args = parser.parse_args(raw_args)
prompts = args.prompts.split(";")
image_paths = args.image_paths.split(";") if args.image_paths else []
Expand All @@ -240,6 +274,7 @@ def main(raw_args=None) -> None:
args.trust_remote_code,
image_paths,
args.output_format,
args.record_layerwise_hidden_states,
)


Expand Down
3 changes: 1 addition & 2 deletions tests/end_to_end/tpu/deepseek/v2-16b/test_deepseek.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

# Example Usage: export HF_TOKEN=<huggingface_access_token>; export BASE_OUTPUT_PATH=<GCS_bucket_path>; bash test_deepseek.sh

# The golden logit can be generated by:
# python3 -m tests.assets.logits_generation.generate_hf_golden_logits --model-id=deepseek-ai/DeepSeek-V2-Lite --output-path=golden_data_deepseek2-16b.jsonl --prompts='I love to;Today is a;What is the' --hf-model-path=$local_bf16_path --trust-remote-code=False

set -ex
Expand Down Expand Up @@ -62,7 +61,7 @@ if [ ! -f "${GOLDEN_LOGITS_DISK_LOCATION}" ]; then
gcloud storage cp ${GOLDEN_LOGITS_PATH} ${GOLDEN_LOGITS_DISK_LOCATION}
fi

python3 -m tests.utils.forward_pass_logit_checker ${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}/base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=forward_logits_check load_parameters_path=${SCANNED_CKPT_PATH} scan_layers=true attention=dot_product per_device_batch_size=1 model_name=${MODEL_NAME} max_prefill_predict_length=4 max_target_length=4 async_checkpointing=false sparse_matmul=false ici_fsdp_parallelism=1 ici_expert_parallelism=4 checkpoint_storage_concurrent_gb=1024 weight_dtype=float32 dtype=float32 activations_in_float32=true matmul_precision=highest float32_logits=true float32_qk_product=true --golden_logits_path=${GOLDEN_LOGITS_DISK_LOCATION} --atol=1e-4 --rtol=1e-4 --max_kl_div=5e-6
python3 -m tests.utils.forward_pass_logit_checker ${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}/base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=forward_logits_check load_parameters_path=${SCANNED_CKPT_PATH} scan_layers=true attention=dot_product per_device_batch_size=1 model_name=${MODEL_NAME} max_prefill_predict_length=4 max_target_length=4 async_checkpointing=false sparse_matmul=false ici_fsdp_parallelism=1 ici_expert_parallelism=4 checkpoint_storage_concurrent_gb=1024 weight_dtype=float32 dtype=float32 activations_in_float32=true matmul_precision=highest float32_logits=true float32_qk_product=true --golden_logits_path=${GOLDEN_LOGITS_DISK_LOCATION} --atol=1e-4 --rtol=1e-4 --max_kl_div=5e-6 --compare_layerwise_logits

# Run pre-training - tokamax_gmm implementation
python3 -m maxtext.trainers.pre_train.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=tokamax_gmm_pre_training model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} dataset_type=synthetic enable_checkpointing=false attention=flash sparse_matmul=True use_tokamax_gmm=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=4 steps=5 max_target_length=1024 ici_fsdp_parallelism=4
Expand Down
Loading
Loading