Benchmark results (HEAD only; no baseline comparison available).
No merge-base results found locally. Run benchmarks on main first for a comparison.
I wasn't yet able to create a good benchmark for a model where batching actually helps. It needs a model that is complex enough that the compiler can't optimize it well, but that still runs reasonably quickly. Whether batching actually helps seems to depend strongly on the model: with Marvin's retirement model it nearly cut memory usage in half, while for the Mahler & Yum model it does very little. I also fixed an error in the MY model input creation and removed one of the tests, because …
Code review found 7 issues:

- Lines 82 to 95 in c729e8e
- pylcm/src/lcm/regime_building/processing.py, lines 1267 to 1269 in c729e8e
- pylcm/src/lcm/regime_building/max_Q_over_a.py, lines 21 to 57 in c729e8e
- pylcm/src/lcm/utils/dispatchers.py, lines 221 to 232 in c729e8e
- pylcm/src/lcm/regime_building/max_Q_over_a.py, lines 60 to 85 in c729e8e
- pylcm/src/lcm/utils/dispatchers.py, lines 228 to 231 in c729e8e

🤖 Generated with Claude Code
- Exclude `batch_size` from `_ShockGrid._param_field_names` (and subclass overrides) so it no longer leaks into `.params` as a distribution parameter
- Add `batch_size` property to the `Grid` base class, removing the `ty:ignore` workaround
- Add `batch_sizes` to the `get_max_Q_over_a` docstring Args section
- Fix stale vmap references in dispatchers.py: rename `vmapped` → `mapped`, update docstrings/comments to reference `jax.lax.map`, fix the `product_map` → `productmap` typo, fix "axe" → "axis"
- Add a parametrized test verifying that `batch_size > 0` produces identical results
- Reject non-zero `batch_size` on action grids with a clear error message
- Filter `batch_sizes` to state names only in processing.py

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Wrap `dict.__getitem__` in a lambda for `_LevelMapping.label_to_index` (ty requires `Callable[[object], int]`, not a bound method)
- Remove the stale `ty:ignore[no-matching-overload]` in result.py
- Add `ty:ignore[invalid-assignment]` for the pandas StringArray subscript assignment (valid at runtime, not yet modeled by ty)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Reserve all_grids for the nested regime-keyed mapping; use grids when the parameter holds a flat dict for one regime. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Avoids the ugly `[pd.NA] * n` list construction and removes the `ty: ignore` that was needed because StringArray doesn't type-check with numpy array assignment. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Callers now pass batch_sizes explicitly instead of relying on None → all-zeros conversion inside productmap. This pushes the semantics to the right level: each call site declares its batching intent. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Consistent with all other functions in dispatchers.py. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Replace `ty: ignore[invalid-return-type]` with explicit `cast("FloatND", ...)`
in `_base_productmap_batched`
- Rename `_LevelMapping.label_to_index` to `get_code_from_label` to clarify it
is a callable, not a mapping; tighten type from `Callable[[object], int]` to
`Callable[[str], int]`
- Drop lambda wrapper in `_grid_level_mapping`, use `__getitem__` directly
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
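The last point above rests on the fact that a dict's bound `__getitem__` is itself a callable, so no lambda wrapper is needed. A minimal illustration (the names `codes` and `get_code_from_label` here are only borrowed from the commit for flavor; the real mapping lives in pylcm):

```python
# A label-to-code mapping, as a plain dict.
codes = {"low": 0, "medium": 1, "high": 2}

# The bound method is directly usable wherever a Callable[[str], int]
# is expected; no `lambda label: codes[label]` wrapper required.
get_code_from_label = codes.__getitem__

assert get_code_from_label("medium") == 1
assert [get_code_from_label(x) for x in ["low", "high"]] == [0, 2]
```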
hmgaudecker left a comment:
Very nice, thank you!!!
I made a bunch of cosmetic edits, but did not touch the core changes. Do double-check before merging, please!
productmap now requires batch_sizes as a keyword argument (#280). Both the function_representation and dispatchers notebooks were missing it, causing doc builds to fail. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Problem
Sometimes running a model is not possible because of memory restrictions. The nested
`vmap`s can lead to JAX creating large arrays for intermediate results that are temporarily stored in GPU memory. These arrays usually have the dimensions of the model's state-action space, so looping over batches of half the grid size along one of its dimensions can already halve the peak memory usage. Batching comes at a cost, though: execution time gets progressively worse the smaller the batch size. For big batches the slowdown is larger than I would have expected, given that not all the computations can happen at the same time anyway.

New feature
This PR implements a batched version of
`productmap`. The user can specify a batch size for each state's grid. Instead of using `vmap` to map the `Q_and_F_Function` along such a grid, `jax.lax.map` will be used, which either loops over batches of grid points or, if `batch_size=0`, works like `vmap`. The batched version is only used during the solution, as the state-action space for the simulation is already much smaller: it only depends on the number of simulated subjects.

Tasks
… `_base_productmap` are needed