diff --git a/src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py b/src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py index bff36397af..e3f2785fde 100644 --- a/src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py +++ b/src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py @@ -252,7 +252,7 @@ def __call__( # Ensure inputs are at least 2D with a batch dimension input_ids = jnp.expand_dims(input_ids, axis=1) - input_positions = jnp.expand_dims(attention_metadata.input_positions, axis=1) + input_positions = jnp.expand_dims(attention_metadata.input_positions, axis=-1) with self.mesh, nn.logical_axis_rules(self.maxtext_config.logical_axis_rules): aux_hidden_states = []