From eaf86df6c697288a56b83db915c5a3c6abb65d82 Mon Sep 17 00:00:00 2001 From: Steven Murray Date: Sat, 23 May 2026 22:00:22 +0200 Subject: [PATCH 1/2] Fix cached-quantity rollback after exceptions Ensure first-time cached quantity evaluation cleans up partial cache bookkeeping when the underlying calculation raises. Add regressions for both recovery after update and full transactional cleanup. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- src/hmf/_internals/_cache.py | 113 +++++++++++++++++++++++++---------- tests/test_framework.py | 60 ++++++++++++++++++- 2 files changed, 139 insertions(+), 34 deletions(-) diff --git a/src/hmf/_internals/_cache.py b/src/hmf/_internals/_cache.py index ba66673..de2d0fd 100644 --- a/src/hmf/_internals/_cache.py +++ b/src/hmf/_internals/_cache.py @@ -9,6 +9,7 @@ """ import warnings +from contextlib import suppress from copy import deepcopy from functools import update_wrapper @@ -24,6 +25,73 @@ def hidden_loc(obj, name): return ("_" + obj.__class__.__name__ + "__" + name).replace("___", "__") +def _rollback_failed_quantity_index(self: object, name: str) -> None: + """Remove partial dependency bookkeeping left by a failed cached quantity.""" + prop = hidden_loc(self, name) + recalc = getattr(self, hidden_loc(self, "recalc")) + recalc_prpa = getattr(self, hidden_loc(self, "recalc_prop_par")) + activeq = getattr(self, hidden_loc(self, "active_q")) + recalc_papr = getattr(self, hidden_loc(self, "recalc_par_prop")) + + activeq.discard(name) + recalc.pop(name, None) + recalc_prpa.pop(name, None) + + for quantities in recalc_papr.values(): + quantities.discard(name) + + with suppress(AttributeError): + delattr(self, prop) + + for subframework in [ + getattr(self, s) for s in getattr(self, hidden_loc(self, "subframeworks"), set()) + ]: + sub_recalc = getattr(subframework, hidden_loc(subframework, "recalc")) + sub_recalc_prpa = getattr(subframework, hidden_loc(subframework, "recalc_prop_par")) + sub_activeq = getattr(subframework, hidden_loc(subframework, "active_q")) + sub_recalc_papr = getattr(subframework, hidden_loc(subframework, "recalc_par_prop")) + + sub_name = ":" + name + sub_activeq.discard(sub_name) + sub_recalc.pop(sub_name, None) + sub_recalc_prpa.pop(sub_name, None) + + for quantities in sub_recalc_papr.values(): + quantities.discard(sub_name) + + +def _finalize_quantity_index(self: object, name: str, supered: bool) -> None: + """Persist dependency bookkeeping for a successfully evaluated cached quantity.""" + recalc = getattr(self, hidden_loc(self, "recalc")) + recalc_prpa = getattr(self, hidden_loc(self, "recalc_prop_par")) + activeq = getattr(self, hidden_loc(self, "active_q")) + recalc_papr = getattr(self, hidden_loc(self, "recalc_par_prop")) + + for par in recalc_prpa[name]: + recalc_papr[par].add(name) + + if not supered: + recalc_prpa[name] = deepcopy(recalc_prpa[name]) + activeq.remove(name) + + recalc[name] = False + + for subframework in [ + getattr(self, s) for s in getattr(self, hidden_loc(self, "subframeworks"), set()) + ]: + sub_recalc_prpa = getattr(subframework, hidden_loc(subframework, "recalc_prop_par")) + sub_recalc_papr = getattr(subframework, hidden_loc(subframework, "recalc_par_prop")) + sub_activeq = getattr(subframework, hidden_loc(subframework, "active_q")) + + sub_name = ":" + name + if sub_name in sub_recalc_prpa: + for par in sub_recalc_prpa[sub_name]: + sub_recalc_papr[par].add(sub_name) + + if sub_name in sub_activeq: + sub_activeq.remove(sub_name) + + def cached_quantity(f): """ A robust property caching decorator. @@ -62,14 +130,12 @@ def _get_property(self): _recalc = hidden_loc(self, "recalc") _recalc_prpa = hidden_loc(self, "recalc_prop_par") _activeq = hidden_loc(self, "active_q") - _recalc_papr = hidden_loc(self, "recalc_par_prop") _subframeworks = hidden_loc(self, "subframeworks") # actual objects recalc = getattr(self, _recalc) recalc_prpa = getattr(self, _recalc_prpa) activeq = getattr(self, _activeq) - recalc_papr = getattr(self, _recalc_papr) subframeworks = [getattr(self, s) for s in getattr(self, _subframeworks, set())] # First, if this property has already been indexed, @@ -112,37 +178,18 @@ def _get_property(self): recalc_prpa[name] = set() # Empty set to which parameter names will be added activeq.add(name) - # Go ahead and calculate the value -- each parameter accessed will add itself to the index. - value = f(self) - setattr(self, prop, value) - - # Invert the index - for par in recalc_prpa[name]: - recalc_papr[par].add(name) - - # Copy index to static dict, and remove the index (so that parameters don't keep - # on trying to add themselves) - if not supered: # If super, don't want to remove the name just yet. - recalc_prpa[name] = deepcopy(recalc_prpa[name]) - activeq.remove(name) - - # Add entry to master recalc list - recalc[name] = False - - # Invert sub-framework indices - subframeworks = [ - getattr(self, s) for s in getattr(self, _subframeworks, set()) - ] # have to get it again, because it's been updated - - for s in subframeworks: - if ":" + name in getattr(s, hidden_loc(s, "recalc_prop_par")): - for par in getattr(s, hidden_loc(s, "recalc_prop_par"))[":" + name]: - getattr(s, hidden_loc(s, "recalc_par_prop"))[par].add(":" + name) - - if ":" + name in getattr(s, hidden_loc(s, "active_q")): - getattr(s, hidden_loc(s, "active_q")).remove(":" + name) - - return value + try: + # Go ahead and calculate the value -- each parameter accessed will add itself + # to the index. If this fails, rollback the partial bookkeeping so the object + # remains in the same state as before the failed access. + value = f(self) + setattr(self, prop, value) + _finalize_quantity_index(self, name, supered) + return value + except Exception: + if not supered: + _rollback_failed_quantity_index(self, name) + raise update_wrapper(_get_property, f) diff --git a/tests/test_framework.py b/tests/test_framework.py index 8ea0885..d588526 100644 --- a/tests/test_framework.py +++ b/tests/test_framework.py @@ -1,5 +1,6 @@ import sys import typing +from copy import deepcopy import pytest from deprecation import fail_if_not_removed @@ -7,11 +8,40 @@ import hmf from hmf import GrowthFactor, MassFunction from hmf._internals import get_base_component, get_base_components, pluggable -from hmf._internals._cache import cached_quantity, parameter +from hmf._internals._cache import cached_quantity, hidden_loc, parameter from hmf._internals._framework import Component, Framework, get_mdl, get_model, get_model_ from hmf.density_field.transfer_models import TransferComponent +class _BrokenFramework(Framework): + def __init__(self, a=0, b=1): + self._validate = False + self.a = a + self.b = b + self._validate = True + + def validate(self): + pass + + @parameter("param") + def a(self, val): + return val + + @parameter("param") + def b(self, val): + return val + + @cached_quantity + def broken(self): + if self.a == 0: + raise ValueError("boom") + return self.a + self.b + + @cached_quantity + def dependent(self): + return self.broken + 1 + + def test_incorrect_argument(): with pytest.raises(TypeError): hmf.MassFunction(wrong_arg=3) @@ -310,6 +340,34 @@ def validate(self): assert parent.x == 3 +def test_failed_cached_quantity_recovers_after_update(): + """A failed quantity access should not block later successful recomputation.""" + framework = _BrokenFramework(a=0, b=1) + + with pytest.raises(ValueError, match="boom"): + framework.dependent + + framework.update(a=2) + + assert framework.dependent == 4 + + +def test_failed_cached_quantity_cleanup_is_transactional(): + """A failed quantity access should clean up cache bookkeeping on error.""" + framework = _BrokenFramework(a=0, b=1) + initial_recalc = deepcopy(getattr(framework, hidden_loc(framework, "recalc"))) + initial_recalc_prpa = deepcopy(getattr(framework, hidden_loc(framework, "recalc_prop_par"))) + initial_recalc_papr = deepcopy(getattr(framework, hidden_loc(framework, "recalc_par_prop"))) + + with pytest.raises(ValueError, match="boom"): + framework.dependent + + assert getattr(framework, hidden_loc(framework, "active_q")) == set() + assert getattr(framework, hidden_loc(framework, "recalc")) == initial_recalc + assert getattr(framework, hidden_loc(framework, "recalc_prop_par")) == initial_recalc_prpa + assert getattr(framework, hidden_loc(framework, "recalc_par_prop")) == initial_recalc_papr + + def test_get_dependencies_and_parameter_info(capsys): class Simple(Framework): def __init__(self): From d32f74e568c48b554c565ef9ed53e9a0d430513c Mon Sep 17 00:00:00 2001 From: Steven Murray Date: Sat, 23 May 2026 22:05:40 +0200 Subject: [PATCH 2/2] Cover subframework cache rollback Add a regression that exercises failed parent quantity evaluation after touching a subframework so Codecov sees the rollback path. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- tests/test_cache.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/tests/test_cache.py b/tests/test_cache.py index 0cb9344..bab8a08 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -72,6 +72,13 @@ def q(self): return self.p + self.sub.child_q +class _FailingParent(_Parent): + @cached_quantity + def q_fail(self): + _ = self.sub.child_q + raise ValueError("boom") + + class _DictLike: def __init__(self, data): self._data = data @@ -265,3 +272,28 @@ def test_subframework_delete_without_instance(): obj = _Parent() del obj.sub + + +def test_failed_cached_quantity_cleans_subframework_bookkeeping(): + obj = _FailingParent() + child = obj.sub + _ = child.child_q + + initial_child_recalc = dict(getattr(child, hidden_loc(child, "recalc"))) + initial_child_prpa = dict(getattr(child, hidden_loc(child, "recalc_prop_par"))) + initial_child_papr = { + key: value.copy() + for key, value in getattr(child, hidden_loc(child, "recalc_par_prop")).items() + } + + with pytest.raises(ValueError, match="boom"): + obj.q_fail + + assert getattr(obj, hidden_loc(obj, "active_q")) == set() + assert "q_fail" not in getattr(obj, hidden_loc(obj, "recalc")) + assert "q_fail" not in getattr(obj, hidden_loc(obj, "recalc_prop_par")) + + assert getattr(child, hidden_loc(child, "recalc")) == initial_child_recalc + assert getattr(child, hidden_loc(child, "recalc_prop_par")) == initial_child_prpa + assert getattr(child, hidden_loc(child, "recalc_par_prop")) == initial_child_papr + assert getattr(child, hidden_loc(child, "active_q")) == set()