
Add batched productmap #280

Merged
mj023 merged 22 commits into main from batched_vmap
Apr 7, 2026

Conversation

@mj023
Collaborator

@mj023 mj023 commented Mar 20, 2026

Problem

Sometimes running a model is not possible because of memory restrictions. The nested vmaps can lead JAX to create large intermediate arrays that are temporarily held in GPU memory. These arrays usually have the dimensions of the model's state-action space, so looping over batches of half the grid size along one of its dimensions can already halve peak memory usage. The batching comes at a cost, though: execution time gets progressively worse the smaller the batch size. For large batches the slowdown is bigger than I would have expected, given that not all computations can happen at the same time anyway.

New feature

This PR implements a batched version of productmap. For each state grid, the user can specify a batch size. Instead of using vmap to map the Q_and_F function along this grid, jax.lax.map is used, which either loops over batches of grid points or, if batch_size=0, works like vmap. The batched version is only used during the solution, as the state-action space for the simulation is already much smaller: it only depends on the number of simulated subjects.

Tasks

  • Add batched productmap
  • Refactor so not two versions of _base_productmap are needed
  • Fix typing, tests
  • Investigate speed drop

@read-the-docs-community

read-the-docs-community bot commented Mar 20, 2026

Documentation build overview

📚 pylcm | 🛠️ Build #32159082 | 📁 Comparing 253393a against latest (4c70a64)

  🔍 Preview build  

Files changed (32 files in total): 📝 32 modified | ➕ 0 added | ➖ 0 deleted
File Status
index.html 📝 modified
approximating-continuous-shocks/index.html 📝 modified
benchmarking/index.html 📝 modified
benchmarking-1/index.html 📝 modified
beta-delta/index.html 📝 modified
conventions/index.html 📝 modified
debugging/index.html 📝 modified
defining-models/index.html 📝 modified
dispatchers/index.html 📝 modified
function-representation/index.html 📝 modified
grids/index.html 📝 modified
index-1/index.html 📝 modified
index-2/index.html 📝 modified
index-3/index.html 📝 modified
index-4/index.html 📝 modified
installation/index.html 📝 modified
interpolation/index.html 📝 modified
mahler-yum-2024/index.html 📝 modified
mortality/index.html 📝 modified
pandas-interop/index.html 📝 modified
parameters/index.html 📝 modified
precautionary-savings/index.html 📝 modified
precautionary-savings-health/index.html 📝 modified
regimes/index.html 📝 modified
setup/index.html 📝 modified
shocks/index.html 📝 modified
solving-and-simulating/index.html 📝 modified
stochastic-transitions/index.html 📝 modified
tiny/index.html 📝 modified
tiny-example/index.html 📝 modified
transitions/index.html 📝 modified
write-economics/index.html 📝 modified

@github-actions

github-actions bot commented Mar 30, 2026

Benchmark results (HEAD only — no baseline comparison available)

Benchmark (253393a)                                Statistic         Value
Mahler-Yum                                         execution time    3.592 s
                                                   peak GPU mem      262 MB
                                                   compilation time  117.43 s
                                                   peak CPU mem      2.24 GB
Mortality                                          execution time    235.0 ms
                                                   peak GPU mem      542 MB
                                                   compilation time  10.59 s
                                                   peak CPU mem      1.25 GB
Precautionary Savings - Solve                      execution time    29.1 ms
                                                   peak GPU mem      8 MB
                                                   compilation time  5.06 s
                                                   peak CPU mem      1.06 GB
Precautionary Savings - Simulate                   execution time    138.4 ms
                                                   peak GPU mem      138 MB
                                                   compilation time  7.03 s
                                                   peak CPU mem      1.20 GB
Precautionary Savings - Solve & Simulate           execution time    171.7 ms
                                                   peak GPU mem      565 MB
                                                   compilation time  11.52 s
                                                   peak CPU mem      1.22 GB
Precautionary Savings - Solve & Simulate (irreg)   execution time    295.4 ms
                                                   peak GPU mem      2.18 GB
                                                   compilation time  12.27 s
                                                   peak CPU mem      1.27 GB

No merge-base results found locally. Run benchmarks on main first for a comparison.

@mj023
Collaborator Author

mj023 commented Mar 30, 2026

I wasn't yet able to create a good benchmark for a model where batching actually helps. It needs a model that is complex enough that the compiler can't optimize it well, but that is still runnable reasonably quickly. Whether batching helps seems to depend strongly on the model: with Marvin's retirement model it nearly cut memory usage in half, while for the Mahler & Yum model it does very little.

I also fixed an error in the MY model input creation and removed one of the tests, because productmap can now handle scalar inputs.

@mj023 mj023 requested a review from hmgaudecker March 30, 2026 23:52
@hmgaudecker
Member

Code review

Found 7 issues:

  1. batch_size field leaks into _ShockGrid.params. ContinuousGrid.batch_size is a dataclass field, but _ShockGrid._param_field_names only excludes "n_points". This means every shock grid's .params dict includes batch_size: 0 as if it were a distribution parameter, violating the API contract that .params maps "distribution's parameters' names to their specified values."

"""Get the gridpoints used for discretization.
Returns NaN of the correct shape when required params are missing (i.e., will
only be passed at runtime).
"""
if not self.is_fully_specified:
return jnp.full(self.n_points, jnp.nan)
return self.compute_gridpoints(**self.params)
def get_transition_probs(self) -> FloatND:
"""Get the transition probabilities at the gridpoints.
Returns NaN of the correct shape when required params are missing (i.e., will
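A minimal sketch of the suggested fix (class and field names are illustrative stand-ins, not pylcm's actual definitions): exclude all structural fields, not just "n_points", when building .params.

```python
from dataclasses import dataclass, fields

# Structural fields that must never be reported as distribution parameters.
# Hypothetical set; pylcm's real exclusion list lives in _param_field_names.
_NON_PARAM_FIELDS = frozenset({"n_points", "batch_size"})


@dataclass(frozen=True)
class ShockGridSketch:
    n_points: int
    batch_size: int = 0
    loc: float = 0.0
    scale: float = 1.0

    @property
    def params(self) -> dict[str, float]:
        # Only genuine distribution parameters survive the filter.
        return {
            f.name: getattr(self, f.name)
            for f in fields(self)
            if f.name not in _NON_PARAM_FIELDS
        }


grid = ShockGridSketch(n_points=5, batch_size=2, loc=0.0, scale=1.0)
```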

  2. Grid base class missing batch_size attribute — processing.py calls grid.batch_size on MappingProxyType[str, Grid], but Grid has no batch_size. The # ty: ignore[unresolved-attribute] suppresses the type error rather than fixing the root cause. Adding batch_size as an abstract property or concrete field on Grid would make the type system enforce the contract.

    Q_and_F=Q_and_F,
    batch_sizes={name: grid.batch_size for name, grid in all_grids.items()},  # ty: ignore[unresolved-attribute]
    action_names=state_action_space.action_names,
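The suggested contract can be sketched with an abstract property (illustrative class names, not the real hierarchy):

```python
from abc import ABC, abstractmethod


class GridSketch(ABC):
    """Stand-in for the Grid base class.

    Declaring batch_size as part of the interface lets accessing it on any
    grid type-check without a ty:ignore comment.
    """

    @property
    @abstractmethod
    def batch_size(self) -> int: ...


class ContinuousGridSketch(GridSketch):
    def __init__(self, batch_size: int = 0) -> None:
        self._batch_size = batch_size

    @property
    def batch_size(self) -> int:
        return self._batch_size


grids = {"wealth": ContinuousGridSketch(batch_size=50), "health": ContinuousGridSketch()}
batch_sizes = {name: g.batch_size for name, g in grids.items()}
```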

  3. batch_sizes parameter missing from the get_max_Q_over_a docstring Args section (AGENTS.md says "Google-style docstrings" with all parameters documented)

def get_max_Q_over_a(
    *,
    Q_and_F: Callable[..., tuple[FloatND, BoolND]],
    batch_sizes: dict[str, int],
    action_names: tuple[str, ...],
    state_names: tuple[str, ...],
) -> MaxQOverAFunction:
    r"""Get the function returning the maximum of Q over all actions.

    The state-action value function $Q$ is defined as:

    ```{math}
    Q(x, a) = H(U(x, a), \mathbb{E}[V(x', a') | x, a]),
    ```

    with $H(U, v) = u + \beta \cdot v$ as the leading case (which is the only one
    that is pre-implemented in LCM).

    Fixing a state, maximizing over all feasible actions, we get the $V$ function:

    ```{math}
    V(x) = \max_{a} Q(x, a).
    ```

    This last step is handled by the function returned here.

    Args:
        Q_and_F: A function that takes a state-action combination and returns the
            action value of that combination and whether the state-action
            combination is feasible.
        action_names: Tuple of action variable names.
        state_names: Tuple of state names.

    Returns:
        V, i.e., the function that calculates the maximum of the Q-function over
        all feasible actions.

  4. Multiple stale vmap references in dispatchers.py — docstrings say "Like vmap" and "arguments over which we apply vmap", comments say "just vmap over all vars", and the result variable is named vmapped, but the implementation uses jax.lax.map. The _base_productmap_batched docstring also says "Cannot have keyword-only arguments" but the implementation actively handles them via **kwargs closures.

"""Map func over the Cartesian product of product_axes and execute in batches.
Like vmap, this function does not preserve the function signature.
Args:
func: The function to be dispatched. Cannot have keyword-only arguments.
product_axes: Tuple with names of arguments over which we apply vmap.
batch_sizes: Dict with the batch sizes for each product_axis.
Returns:
A callable with the same arguments as func. See `product_map` for details.
"""

  5. No tests for batch_size > 0 — the core new feature (memory-saving batched execution) has zero tests verifying correct numerical output when batch_size > 0. All test changes only use batch_size=0.

  6. Action grid batch_size values are silently ignored — processing.py builds batch_sizes from all grids (states + actions), but get_max_Q_over_a only passes batch_sizes to the outer productmap over state_names (line 84), not the inner productmap over action_names (line 64). Similarly, simulation_spacemap hardcodes batch_sizes=dict.fromkeys(action_names, 0). Users setting batch_size on action grids would see no effect.

extra_param_names = _get_extra_param_names(
    Q_and_F=Q_and_F, action_names=action_names, state_names=state_names
)

Q_and_F = productmap(
    func=Q_and_F,
    variables=action_names,
)

@with_signature(
    args=["next_regime_to_V_arr", *action_names, *state_names, *extra_param_names],
    return_annotation="FloatND",
    enforce=False,
)
def max_Q_over_a(
    next_regime_to_V_arr: MappingProxyType[RegimeName, FloatND],
    **states_actions_params: Array,
) -> FloatND:
    Q_arr, F_arr = Q_and_F(
        next_regime_to_V_arr=next_regime_to_V_arr,
        **states_actions_params,
    )
    return Q_arr.max(where=F_arr, initial=-jnp.inf)

return productmap(func=max_Q_over_a, variables=state_names, batch_sizes=batch_sizes)

  7. Stale cross-reference in docstring — _base_productmap_batched Returns section says "See product_map for details" but the function is named productmap (no underscore).

batch_sizes: Dict with the batch sizes for each product_axis.
Returns:
A callable with the same arguments as func. See `product_map` for details.

🤖 Generated with Claude Code

- If this code review was useful, please react with 👍. Otherwise, react with 👎.

hmgaudecker and others added 10 commits April 7, 2026 10:15
- Exclude batch_size from _ShockGrid._param_field_names (and subclass
  overrides) so it no longer leaks into .params as a distribution parameter
- Add batch_size property to Grid base class, removing ty:ignore workaround
- Add batch_sizes to get_max_Q_over_a docstring Args section
- Fix stale vmap references in dispatchers.py: rename vmapped → mapped,
  update docstrings/comments to reference jax.lax.map, fix product_map →
  productmap typo, fix "axe" → "axis"
- Add parametrized test verifying batch_size > 0 produces identical results
- Reject non-zero batch_size on action grids with a clear error message
- Filter batch_sizes to state names only in processing.py

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Wrap dict.__getitem__ in lambda for _LevelMapping.label_to_index
  (ty requires Callable[[object], int], not bound method)
- Remove stale ty:ignore[no-matching-overload] in result.py
- Add ty:ignore[invalid-assignment] for pandas StringArray subscript
  assignment (valid at runtime, not yet modeled by ty)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Reserve all_grids for the nested regime-keyed mapping; use grids when the
parameter holds a flat dict for one regime.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Avoids ugly [pd.NA] * n list construction and removes ty: ignore needed
because StringArray doesn't type-check with numpy array assignment.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Callers now pass batch_sizes explicitly instead of relying on None →
all-zeros conversion inside productmap. This pushes the semantics to the
right level: each call site declares its batching intent.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Consistent with all other functions in dispatchers.py.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…_index

- Replace `ty: ignore[invalid-return-type]` with explicit `cast("FloatND", ...)`
  in `_base_productmap_batched`
- Rename `_LevelMapping.label_to_index` to `get_code_from_label` to clarify it
  is a callable, not a mapping; tighten type from `Callable[[object], int]` to
  `Callable[[str], int]`
- Drop lambda wrapper in `_grid_level_mapping`, use `__getitem__` directly

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Member

@hmgaudecker hmgaudecker left a comment


Very nice, thank you!!!

I made a bunch of cosmetic edits, but did not touch the core changes. Do double-check before merging, please!

@mj023 mj023 merged commit 20669fd into main Apr 7, 2026
10 checks passed
@mj023 mj023 deleted the batched_vmap branch April 7, 2026 23:20
hmgaudecker added a commit that referenced this pull request Apr 8, 2026
productmap now requires batch_sizes as a keyword argument (#280).
Both the function_representation and dispatchers notebooks were missing
it, causing doc builds to fail.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>