diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index c973087a09..6fcc75aa1f 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -427,6 +427,7 @@ fused_qkv: false fused_mlp: false record_internal_nn_metrics: 0 +record_layerwise_hidden_states: false # Output directory # Create a GCS bucket, e.g. my-maxtext-outputs and set this to "gs://my-maxtext-outputs/" diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index b70b7238d3..416a17b072 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -1773,6 +1773,7 @@ class Metrics(BaseModel): gcs_metrics: bool = Field(False, description="If True, save metrics to GCS.") save_config_to_gcs: bool = Field(False, description="If True, save config to GCS.") record_internal_nn_metrics: int = Field(0, description="Record internal neural network metrics.") + record_layerwise_hidden_states: bool = Field(False, description="Record layer-by-layer hidden states.") prometheus_port: int = Field(0, description="Port for Prometheus metrics server. 
0 disables it.") enable_checkpoint_cloud_logger: bool = Field(False, description="Enables structured logging for checkpointing.") enable_tunix_perf_metrics: bool = Field( diff --git a/src/maxtext/models/deepseek.py b/src/maxtext/models/deepseek.py index 27e1a6f7ad..bd511a5203 100644 --- a/src/maxtext/models/deepseek.py +++ b/src/maxtext/models/deepseek.py @@ -257,6 +257,9 @@ def post_process(self, layer_output, load_balance_loss, moe_bias_updates, kv_cac jnp.sum(layer_output == 0) / jnp.size(layer_output), ) + if self.config.record_layerwise_hidden_states: + self.sow(nnx.Intermediate, "layer_output", layer_output) + if self.config.scan_layers: return layer_output, None return layer_output, kv_cache diff --git a/tests/assets/logits_generation/generate_hf_golden_logits.py b/tests/assets/logits_generation/generate_hf_golden_logits.py index c57d58c380..07a2e2c107 100644 --- a/tests/assets/logits_generation/generate_hf_golden_logits.py +++ b/tests/assets/logits_generation/generate_hf_golden_logits.py @@ -71,6 +71,7 @@ def save_golden_logits( trust_remote_code, image_paths, output_format, + record_layerwise_hidden_states=False, ): """save golden logits""" if hf_model_path is None: @@ -142,11 +143,31 @@ def save_golden_logits( input_ids = tokenizer.encode(prompt_text, return_tensors="pt") inputs = {"input_ids": input_ids} + captured_hidden_states = [] + def make_hook_fn(layer_idx): + def hook_fn(module, inputs, output): + tensor = output[0] if isinstance(output, tuple) else output + captured_hidden_states.append((layer_idx, tensor.detach().to(torch.float32).cpu())) + return hook_fn + + hooks = [] + if record_layerwise_hidden_states: + base_model = getattr(model, "model", model) + layers = getattr(base_model, "layers", None) + if layers is None: + raise ValueError(f"Could not find layers in model structure: {model}") + for idx, layer in enumerate(layers): + hooks.append(layer.register_forward_hook(make_hook_fn(idx))) + # 2. 
Run inference with torch.no_grad(): outputs = model(**inputs) logits = outputs.logits.cpu().to(torch.float32).numpy() + # Remove hooks + for hook in hooks: + hook.remove() + # 3. Populate final data dictionary with tensors from inputs and logits for key, value in inputs.items(): new_key = "tokens" if key == "input_ids" else key @@ -154,6 +175,12 @@ def save_golden_logits( data_to_save[new_key] = val_np[0] if val_np.ndim > 0 else val_np data_to_save["logits"] = logits[0] + if record_layerwise_hidden_states: + # Sort by layer index to ensure sequential order + captured_hidden_states.sort(key=lambda x: x[0]) + layer_hidden_states_list = [x[1].numpy()[0] for x in captured_hidden_states] + data_to_save["layer_hidden_states"] = layer_hidden_states_list + print(f"Token length is {len(data_to_save['tokens'])} for prompt: {prompt_text}") print(f"raw ids: {data_to_save['tokens']}") @@ -162,6 +189,8 @@ def save_golden_logits( for key, value in data_to_save.items(): if isinstance(value, np.ndarray): data_to_save[key] = value.tolist() + elif key == "layer_hidden_states" and isinstance(value, list): + data_to_save[key] = [x.tolist() if isinstance(x, np.ndarray) else x for x in value] all_data_to_save.append(data_to_save) @@ -220,6 +249,11 @@ def main(raw_args=None) -> None: default="json", help="The output format for the golden logits. 
(json, pickle)", ) + parser.add_argument( + "--record-layerwise-hidden-states", + action="store_true", + help="Record layer-by-layer intermediate hidden states.", + ) args = parser.parse_args(raw_args) prompts = args.prompts.split(";") image_paths = args.image_paths.split(";") if args.image_paths else [] @@ -240,6 +274,7 @@ def main(raw_args=None) -> None: args.trust_remote_code, image_paths, args.output_format, + args.record_layerwise_hidden_states, ) diff --git a/tests/end_to_end/tpu/deepseek/v2-16b/test_deepseek.sh b/tests/end_to_end/tpu/deepseek/v2-16b/test_deepseek.sh index 01956268af..beb61a67a7 100644 --- a/tests/end_to_end/tpu/deepseek/v2-16b/test_deepseek.sh +++ b/tests/end_to_end/tpu/deepseek/v2-16b/test_deepseek.sh @@ -9,7 +9,7 @@ # Example Usage: export HF_TOKEN=; export BASE_OUTPUT_PATH=; bash test_deepseek.sh -# The golden logit can be generated by: -# python3 -m tests.assets.logits_generation.generate_hf_golden_logits --model-id=deepseek-ai/DeepSeek-V2-Lite --output-path=golden_data_deepseek2-16b.jsonl --prompts='I love to;Today is a;What is the' --hf-model-path=$local_bf16_path --trust-remote-code=False +# The golden logits (including layer-wise hidden states) can be generated by: +# python3 -m tests.assets.logits_generation.generate_hf_golden_logits --model-id=deepseek-ai/DeepSeek-V2-Lite --output-path=golden_data_deepseek2-16b.jsonl --prompts='I love to;Today is a;What is the' --hf-model-path=$local_bf16_path --trust-remote-code=False --record-layerwise-hidden-states set -ex @@ -62,7 +62,7 @@ if [ ! 
-f "${GOLDEN_LOGITS_DISK_LOCATION}" ]; then gcloud storage cp ${GOLDEN_LOGITS_PATH} ${GOLDEN_LOGITS_DISK_LOCATION} fi -python3 -m tests.utils.forward_pass_logit_checker ${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}/base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=forward_logits_check load_parameters_path=${SCANNED_CKPT_PATH} scan_layers=true attention=dot_product per_device_batch_size=1 model_name=${MODEL_NAME} max_prefill_predict_length=4 max_target_length=4 async_checkpointing=false sparse_matmul=false ici_fsdp_parallelism=1 ici_expert_parallelism=4 checkpoint_storage_concurrent_gb=1024 weight_dtype=float32 dtype=float32 activations_in_float32=true matmul_precision=highest float32_logits=true float32_qk_product=true --golden_logits_path=${GOLDEN_LOGITS_DISK_LOCATION} --atol=1e-4 --rtol=1e-4 --max_kl_div=5e-6 +python3 -m tests.utils.forward_pass_logit_checker ${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}/base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=forward_logits_check load_parameters_path=${SCANNED_CKPT_PATH} scan_layers=true attention=dot_product per_device_batch_size=1 model_name=${MODEL_NAME} max_prefill_predict_length=4 max_target_length=4 async_checkpointing=false sparse_matmul=false ici_fsdp_parallelism=1 ici_expert_parallelism=4 checkpoint_storage_concurrent_gb=1024 weight_dtype=float32 dtype=float32 activations_in_float32=true matmul_precision=highest float32_logits=true float32_qk_product=true --golden_logits_path=${GOLDEN_LOGITS_DISK_LOCATION} --atol=1e-4 --rtol=1e-4 --max_kl_div=5e-6 --compare_layerwise_hidden_states # Run pre-training - tokamax_gmm implementation python3 -m maxtext.trainers.pre_train.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=tokamax_gmm_pre_training model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} dataset_type=synthetic 
enable_checkpointing=false attention=flash sparse_matmul=True use_tokamax_gmm=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=4 steps=5 max_target_length=1024 ici_fsdp_parallelism=4 diff --git a/tests/utils/forward_pass_logit_checker.py b/tests/utils/forward_pass_logit_checker.py index 92adb4e921..f606a97298 100644 --- a/tests/utils/forward_pass_logit_checker.py +++ b/tests/utils/forward_pass_logit_checker.py @@ -246,6 +246,8 @@ def get_data(golden_data_point, config): def main(config, test_args): # pylint: disable=W0621 """Test the Whole Model of model_name""" + if test_args.compare_layerwise_hidden_states: + config.record_layerwise_hidden_states = True init_rng = jax.random.PRNGKey(config.init_weights_seed) init_rng, rng1 = jax.random.split(init_rng) devices_array = maxtext_utils.create_device_mesh(config) @@ -291,6 +293,14 @@ def main(config, test_args): # pylint: disable=W0621 for golden_data_index, golden_data_point in enumerate(golden_data): max_logging.log(f"\n--- Comparing forward pass for golden data index: {golden_data_index} ---") ids, decoder_segment_ids, decoder_positions, golden_logits, seq_len, images = get_data(golden_data_point, config) + if test_args.compare_layerwise_hidden_states: + if "layer_hidden_states" not in golden_data_point: + raise KeyError( + "golden data point is missing 'layer_hidden_states' key, but --compare_layerwise_hidden_states is enabled." + ) + # golden_layer_hidden_states shape: list of [seq_len, ...] 
+ golden_layer_hidden_states = [np.asarray(x, dtype=np.float32) for x in golden_data_point["layer_hidden_states"]] + max_logging.log("maxtext forward pass") if state is None: full_train_logits = model( @@ -301,15 +311,27 @@ def main(config, test_args): # pylint: disable=W0621 enable_dropout=False, ) else: - full_train_logits = model.apply( - state.params, - ids, - decoder_positions, - decoder_segment_ids, - encoder_images=images, - enable_dropout=False, - rngs={"aqt": init_rng}, - ) + if test_args.compare_layerwise_hidden_states: + full_train_logits, intermediate_outputs = model.apply( + state.params, + ids, + decoder_positions, + decoder_segment_ids, + encoder_images=images, + enable_dropout=False, + rngs={"aqt": init_rng}, + mutable=["intermediates"], + ) + else: + full_train_logits = model.apply( + state.params, + ids, + decoder_positions, + decoder_segment_ids, + encoder_images=images, + enable_dropout=False, + rngs={"aqt": init_rng}, + ) full_train_logits = jax.experimental.multihost_utils.process_allgather(full_train_logits, tiled=True) # if full_train_logits shape is [num_hosts, batch_size, seq_len, vocab_size] @@ -319,65 +341,156 @@ def main(config, test_args): # pylint: disable=W0621 full_train_logits = full_train_logits[:, :seq_len, :] token_size = int(test_args.token_size) if test_args.token_size else seq_len - if full_train_logits.shape[-1] != golden_logits.shape[-1]: - max_logging.log( - f"Vocab size mismatch: train logits vocab size {full_train_logits.shape[-1]}, " - f"golden logits vocab size {golden_logits.shape[-1]}. " - "Comparing up to the smaller vocab size." 
- ) - min_vocab_size = min(full_train_logits.shape[-1], golden_logits.shape[-1]) - start_index = 1 if test_args.skip_first_token else 0 - # shape [seq_len, vocab_size] - train_logits_slice = full_train_logits[0, start_index:token_size, :min_vocab_size] - golden_logits_slice = golden_logits[start_index:token_size, :min_vocab_size] - - if train_logits_slice.shape[0] > 2: - max_logging.log(f"\n[logits: token {start_index + 2}]") - max_logging.log(f"{golden_logits_slice[2]=}") - max_logging.log(f"{train_logits_slice[2]=}") - - # Calculate absolute and relative differences for detailed reporting - abs_diff = jnp.abs(train_logits_slice - golden_logits_slice) - - # To avoid division by zero, add a small epsilon where golden_logits_slice is zero - safe_golden_logits = jnp.where(golden_logits_slice == 0, 1e-8, golden_logits_slice) - rel_diff = abs_diff / jnp.abs(safe_golden_logits) - - max_abs_diff_idx = jnp.unravel_index(jnp.argmax(abs_diff), abs_diff.shape) - max_rel_diff_idx = jnp.unravel_index(jnp.argmax(rel_diff), rel_diff.shape) - - max_abs_diff_val = abs_diff[max_abs_diff_idx] - max_rel_diff_val = rel_diff[max_rel_diff_idx] - msg = ( - "\n[numerical difference]\n" - f"Max absolute difference: {max_abs_diff_val:.4e} at index {max_abs_diff_idx}\n" - f" (Train: {train_logits_slice[max_abs_diff_idx]:.4e}, Golden: {golden_logits_slice[max_abs_diff_idx]:.4e})\n" - f"Max relative difference: {max_rel_diff_val:.4e} at index {max_rel_diff_idx}\n" - f" (Train: {train_logits_slice[max_rel_diff_idx]:.4e}, Golden: {golden_logits_slice[max_rel_diff_idx]:.4e})" - ) - max_logging.log(msg) - if test_args.clip_logits_epsilon is not None: - model_probabilities = jnp.clip(jax.nn.softmax(train_logits_slice, axis=-1), min=test_args.clip_logits_epsilon) - golden_probabilities = jnp.clip(jax.nn.softmax(golden_logits_slice, axis=-1), min=test_args.clip_logits_epsilon) - else: - model_probabilities = jax.nn.softmax(train_logits_slice, axis=-1) - golden_probabilities = 
jax.nn.softmax(golden_logits_slice, axis=-1) - - if golden_probabilities.shape[0] > 1: - max_logging.log(f"\n[probability: token {start_index + 1}]") - max_logging.log(f"{golden_probabilities[1]=}") - max_logging.log(f"{model_probabilities[1]=}") - - kl_div = jax.numpy.sum(jax.scipy.special.kl_div(golden_probabilities, model_probabilities), axis=-1) - max_kl_div_val = jax.numpy.max(kl_div) - max_kl_div_idx = jax.numpy.argmax(kl_div) - max_logging.log( - f"\n[KL divergence]\n" - f"KL divergence = {kl_div}, max KL divergence = {max_kl_div_val} at index {max_kl_div_idx}, " - f"the corresponding token id is {ids[0, max_kl_div_idx + start_index]}" - ) + def get_layer_index(path_keys, config): + for key in path_keys: + if not isinstance(key, str): + continue + if key.startswith("dense_layers_"): + idx = int(key.split("_")[-1]) + return idx + elif key.startswith("moe_layers_"): + idx = int(key.split("_")[-1]) + return config.first_num_dense_layers + idx + elif key.startswith("layers_"): + idx = int(key.split("_")[-1]) + return idx + elif key == "dense_layers": + return 0 + elif key == "moe_layers": + return config.first_num_dense_layers + elif key == "layers": + return 0 + return 9999 + + def compare_and_assert(train_data, golden_data, name="final", is_logits=True): + if train_data.shape[-1] != golden_data.shape[-1]: + max_logging.log( + f"[{name}] Dimension mismatch: train {train_data.shape[-1]}, " + f"golden {golden_data.shape[-1]}. " + "Comparing up to the smaller size." + ) + min_last_dim = min(train_data.shape[-1], golden_data.shape[-1]) + + # Dynamically build slice to handle arbitrary intermediate dimensions (e.g. hyper-connections) + # We slice batch index 0, token range, and the last dimension. 
+ start_index = 1 if test_args.skip_first_token else 0  # skip token 0 when requested + slice_obj = [0, slice(start_index, token_size)] + slice_obj.extend([slice(None)] * (train_data.ndim - 3)) + slice_obj.append(slice(None, min_last_dim)) + slice_tuple = tuple(slice_obj) + + train_data_slice = train_data[slice_tuple] + golden_data_slice = golden_data[slice_tuple[1:]] + + if train_data_slice.shape[0] > 2: + max_logging.log(f"\n[{name} values: token {start_index + 2}]") + max_logging.log(f"{golden_data_slice[2]=}") + max_logging.log(f"{train_data_slice[2]=}") + + # Calculate absolute and relative differences for detailed reporting + abs_diff = jnp.abs(train_data_slice - golden_data_slice) + + # To avoid division by zero, add a small epsilon where golden_data_slice is zero + safe_golden_data = jnp.where(golden_data_slice == 0, 1e-8, golden_data_slice) + rel_diff = abs_diff / jnp.abs(safe_golden_data) + + max_abs_diff_idx = jnp.unravel_index(jnp.argmax(abs_diff), abs_diff.shape) + max_rel_diff_idx = jnp.unravel_index(jnp.argmax(rel_diff), rel_diff.shape) + + max_abs_diff_val = abs_diff[max_abs_diff_idx] + max_rel_diff_val = rel_diff[max_rel_diff_idx] + msg = ( + f"\n[{name} numerical difference]\n" + f"Max absolute difference: {max_abs_diff_val:.4e} at index {max_abs_diff_idx}\n" + f" (Train: {train_data_slice[max_abs_diff_idx]:.4e}, Golden: {golden_data_slice[max_abs_diff_idx]:.4e})\n" + f"Max relative difference: {max_rel_diff_val:.4e} at index {max_rel_diff_idx}\n" + f" (Train: {train_data_slice[max_rel_diff_idx]:.4e}, Golden: {golden_data_slice[max_rel_diff_idx]:.4e})" + ) + max_logging.log(msg) + + if is_logits: + if test_args.clip_logits_epsilon is not None: + model_probabilities = jnp.clip( + jax.nn.softmax(train_data_slice, axis=-1), min=test_args.clip_logits_epsilon + ) + golden_probabilities = jnp.clip( + jax.nn.softmax(golden_data_slice, axis=-1), min=test_args.clip_logits_epsilon + ) + else: + model_probabilities = jax.nn.softmax(train_data_slice, axis=-1) + golden_probabilities = jax.nn.softmax(golden_data_slice, axis=-1) + + 
if golden_probabilities.shape[0] > 1: + max_logging.log(f"\n[{name} probability: token {start_index + 1}]") + max_logging.log(f"{golden_probabilities[1]=}") + max_logging.log(f"{model_probabilities[1]=}") + + kl_div = jax.numpy.sum(jax.scipy.special.kl_div(golden_probabilities, model_probabilities), axis=-1) + max_kl_div_val = jax.numpy.max(kl_div) + max_kl_div_idx = jax.numpy.argmax(kl_div) + max_logging.log( + f"\n[{name} KL divergence]\n" + f"KL divergence = {kl_div}, max KL divergence = {max_kl_div_val} at index {max_kl_div_idx}, " + f"the corresponding token id is {ids[0, max_kl_div_idx + start_index]}" + ) + + if test_args.atol is not None: + max_logging.log(f"\n[{name} test criteria]") + max_logging.log( + f"Checking Numerical Differences between train and golden against " + f"atol={test_args.atol} rtol={test_args.rtol}." + ) + rtol_val = float(test_args.rtol) + atol_val = float(test_args.atol) + assert jax.numpy.allclose( + train_data_slice, golden_data_slice, rtol=rtol_val, atol=atol_val, equal_nan=False + ), f"[{name}] Values do not match closely enough. Required rtol={test_args.rtol}, atol={test_args.atol}." + + if is_logits and test_args.max_kl_div is not None: + max_logging.log( + f"Checking KL Divergence between train distribution and golden distribution against " + f"threshold {test_args.max_kl_div}." + ) + assert jax.numpy.all( + kl_div < test_args.max_kl_div, + ), ( + f"[{name}] KL divergence values exceed the specified threshold of {test_args.max_kl_div}. 
" + f"Max divergence: {jax.numpy.max(kl_div)}" + ) + + compare_and_assert(full_train_logits, golden_logits, "final") + + all_layer_hidden_states_processed = [] + if test_args.compare_layerwise_hidden_states and state is not None: + sowed_leaves = [] + for path, val in jax.tree_util.tree_leaves_with_path(intermediate_outputs): + path_keys = [k.key for k in path if hasattr(k, "key")] + if "layer_output" in path_keys: + idx = get_layer_index(path_keys, config) + sowed_leaves.append((idx, val)) + + # Sort by index to match sequential layer order + sowed_leaves.sort(key=lambda item: item[0]) + + all_layer_activations = [] + for _, val in sowed_leaves: + if val.ndim == 3: + all_layer_activations.append(val) + elif val.ndim == 4: + for j in range(val.shape[0]): + all_layer_activations.append(val[j]) + + for i, layer_act in enumerate(all_layer_activations): + # Gather activations from all hosts + layer_act = jax.experimental.multihost_utils.process_allgather(layer_act, tiled=True) + # Slice to sequence length. Shapes can be [num_hosts * batch, seq_len, d] + # or [num_hosts * batch, seq_len, hc_mult, d] for hyper-connections + layer_act = layer_act[:, :seq_len, ...] + all_layer_hidden_states_processed.append(layer_act) + + compare_and_assert(layer_act, golden_layer_hidden_states[i], f"layer_{i}", is_logits=False) if jax.process_index() == 0 and test_args.output_logits_path: data_to_save = { @@ -385,32 +498,10 @@ def main(config, test_args): # pylint: disable=W0621 "tokens": ids[0, :seq_len].tolist(), "logits": full_train_logits[0].tolist(), } + if test_args.compare_layerwise_hidden_states: + data_to_save["layer_hidden_states"] = [x[0].tolist() for x in all_layer_hidden_states_processed] all_data_to_save.append(data_to_save) - if test_args.atol is not None: - max_logging.log("\n[test criteria]") - max_logging.log( - f"Checking Numerical Differences between train logits and golden logits against " - f"atol={test_args.atol} rtol={test_args.rtol}." 
- ) - rtol_val = float(test_args.rtol) - atol_val = float(test_args.atol) - assert jax.numpy.allclose( - train_logits_slice, golden_logits_slice, rtol=rtol_val, atol=atol_val, equal_nan=False - ), f"Logits do not match closely enough. Required rtol={test_args.rtol}, atol={test_args.atol}." - - if test_args.max_kl_div is not None: - max_logging.log( - f"Checking KL Divergence between train distribution and golden distribution against " - f"threshold {test_args.max_kl_div}." - ) - assert jax.numpy.all( - kl_div < test_args.max_kl_div, - ), ( - f"KL divergence values exceed the specified threshold of {test_args.max_kl_div}. " - f"Max divergence: {jax.numpy.max(kl_div)}" - ) - else: """Comparing maxtext model with HF model on-the-fly""" if test_args.hf_model_path == "": @@ -583,6 +674,13 @@ def main(config, test_args): # pylint: disable=W0621 parser.add_argument("--golden_logits_path", type=str, required=False, default="") parser.add_argument("--hf_model_path", type=str, required=False, default="") parser.add_argument("--run_hf_model", type=bool, required=False, default=False) + parser.add_argument( + "--compare_layerwise_hidden_states", + action="store_true", + required=False, + default=False, + help="Compare layer-by-layer intermediate hidden states against golden data.", + ) parser.add_argument("--output_logits_path", type=str, required=False, default="") parser.add_argument("--gcs_output_logits_path", type=str, required=False, default="") parser.add_argument("--clip_logits_epsilon", type=float, required=False, default=None)