Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 80 additions & 33 deletions src/hmf/_internals/_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
"""

import warnings
from contextlib import suppress
from copy import deepcopy
from functools import update_wrapper

Expand All @@ -24,6 +25,73 @@ def hidden_loc(obj, name):
return ("_" + obj.__class__.__name__ + "__" + name).replace("___", "__")


def _rollback_failed_quantity_index(self: object, name: str) -> None:
"""Remove partial dependency bookkeeping left by a failed cached quantity."""
prop = hidden_loc(self, name)
recalc = getattr(self, hidden_loc(self, "recalc"))
recalc_prpa = getattr(self, hidden_loc(self, "recalc_prop_par"))
activeq = getattr(self, hidden_loc(self, "active_q"))
recalc_papr = getattr(self, hidden_loc(self, "recalc_par_prop"))

activeq.discard(name)
recalc.pop(name, None)
recalc_prpa.pop(name, None)

for quantities in recalc_papr.values():
quantities.discard(name)

with suppress(AttributeError):
delattr(self, prop)

for subframework in [
getattr(self, s) for s in getattr(self, hidden_loc(self, "subframeworks"), set())
]:
sub_recalc = getattr(subframework, hidden_loc(subframework, "recalc"))
sub_recalc_prpa = getattr(subframework, hidden_loc(subframework, "recalc_prop_par"))
sub_activeq = getattr(subframework, hidden_loc(subframework, "active_q"))
sub_recalc_papr = getattr(subframework, hidden_loc(subframework, "recalc_par_prop"))

sub_name = ":" + name
sub_activeq.discard(sub_name)
sub_recalc.pop(sub_name, None)
sub_recalc_prpa.pop(sub_name, None)

for quantities in sub_recalc_papr.values():
quantities.discard(sub_name)


def _finalize_quantity_index(self: object, name: str, supered: bool) -> None:
"""Persist dependency bookkeeping for a successfully evaluated cached quantity."""
recalc = getattr(self, hidden_loc(self, "recalc"))
recalc_prpa = getattr(self, hidden_loc(self, "recalc_prop_par"))
activeq = getattr(self, hidden_loc(self, "active_q"))
recalc_papr = getattr(self, hidden_loc(self, "recalc_par_prop"))

for par in recalc_prpa[name]:
recalc_papr[par].add(name)

if not supered:
recalc_prpa[name] = deepcopy(recalc_prpa[name])
activeq.remove(name)

recalc[name] = False

for subframework in [
getattr(self, s) for s in getattr(self, hidden_loc(self, "subframeworks"), set())
]:
sub_recalc_prpa = getattr(subframework, hidden_loc(subframework, "recalc_prop_par"))
sub_recalc_papr = getattr(subframework, hidden_loc(subframework, "recalc_par_prop"))
sub_activeq = getattr(subframework, hidden_loc(subframework, "active_q"))

sub_name = ":" + name
if sub_name in sub_recalc_prpa:
for par in sub_recalc_prpa[sub_name]:
sub_recalc_papr[par].add(sub_name)

if sub_name in sub_activeq:
sub_activeq.remove(sub_name)


def cached_quantity(f):
"""
A robust property caching decorator.
Expand Down Expand Up @@ -62,14 +130,12 @@ def _get_property(self):
_recalc = hidden_loc(self, "recalc")
_recalc_prpa = hidden_loc(self, "recalc_prop_par")
_activeq = hidden_loc(self, "active_q")
_recalc_papr = hidden_loc(self, "recalc_par_prop")
_subframeworks = hidden_loc(self, "subframeworks")

# actual objects
recalc = getattr(self, _recalc)
recalc_prpa = getattr(self, _recalc_prpa)
activeq = getattr(self, _activeq)
recalc_papr = getattr(self, _recalc_papr)
subframeworks = [getattr(self, s) for s in getattr(self, _subframeworks, set())]

# First, if this property has already been indexed,
Expand Down Expand Up @@ -112,37 +178,18 @@ def _get_property(self):
recalc_prpa[name] = set() # Empty set to which parameter names will be added
activeq.add(name)

# Go ahead and calculate the value -- each parameter accessed will add itself to the index.
value = f(self)
setattr(self, prop, value)

# Invert the index
for par in recalc_prpa[name]:
recalc_papr[par].add(name)

# Copy index to static dict, and remove the index (so that parameters don't keep
# on trying to add themselves)
if not supered: # If super, don't want to remove the name just yet.
recalc_prpa[name] = deepcopy(recalc_prpa[name])
activeq.remove(name)

# Add entry to master recalc list
recalc[name] = False

# Invert sub-framework indices
subframeworks = [
getattr(self, s) for s in getattr(self, _subframeworks, set())
] # have to get it again, because it's been updated

for s in subframeworks:
if ":" + name in getattr(s, hidden_loc(s, "recalc_prop_par")):
for par in getattr(s, hidden_loc(s, "recalc_prop_par"))[":" + name]:
getattr(s, hidden_loc(s, "recalc_par_prop"))[par].add(":" + name)

if ":" + name in getattr(s, hidden_loc(s, "active_q")):
getattr(s, hidden_loc(s, "active_q")).remove(":" + name)

return value
try:
# Go ahead and calculate the value -- each parameter accessed will add itself
# to the index. If this fails, rollback the partial bookkeeping so the object
# remains in the same state as before the failed access.
value = f(self)
setattr(self, prop, value)
_finalize_quantity_index(self, name, supered)
return value
except Exception:
if not supered:
_rollback_failed_quantity_index(self, name)
raise

update_wrapper(_get_property, f)

Expand Down
32 changes: 32 additions & 0 deletions tests/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,13 @@ def q(self):
return self.p + self.sub.child_q


class _FailingParent(_Parent):
@cached_quantity
def q_fail(self):
_ = self.sub.child_q
raise ValueError("boom")


class _DictLike:
def __init__(self, data):
self._data = data
Expand Down Expand Up @@ -265,3 +272,28 @@ def test_subframework_delete_without_instance():
obj = _Parent()

del obj.sub


def test_failed_cached_quantity_cleans_subframework_bookkeeping():
obj = _FailingParent()
child = obj.sub
_ = child.child_q

initial_child_recalc = dict(getattr(child, hidden_loc(child, "recalc")))
initial_child_prpa = dict(getattr(child, hidden_loc(child, "recalc_prop_par")))
initial_child_papr = {
key: value.copy()
for key, value in getattr(child, hidden_loc(child, "recalc_par_prop")).items()
}

with pytest.raises(ValueError, match="boom"):
obj.q_fail

assert getattr(obj, hidden_loc(obj, "active_q")) == set()
assert "q_fail" not in getattr(obj, hidden_loc(obj, "recalc"))
assert "q_fail" not in getattr(obj, hidden_loc(obj, "recalc_prop_par"))

assert getattr(child, hidden_loc(child, "recalc")) == initial_child_recalc
assert getattr(child, hidden_loc(child, "recalc_prop_par")) == initial_child_prpa
assert getattr(child, hidden_loc(child, "recalc_par_prop")) == initial_child_papr
assert getattr(child, hidden_loc(child, "active_q")) == set()
60 changes: 59 additions & 1 deletion tests/test_framework.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,47 @@
import sys
import typing
from copy import deepcopy

import pytest
from deprecation import fail_if_not_removed

import hmf
from hmf import GrowthFactor, MassFunction
from hmf._internals import get_base_component, get_base_components, pluggable
from hmf._internals._cache import cached_quantity, parameter
from hmf._internals._cache import cached_quantity, hidden_loc, parameter
from hmf._internals._framework import Component, Framework, get_mdl, get_model, get_model_
from hmf.density_field.transfer_models import TransferComponent


class _BrokenFramework(Framework):
def __init__(self, a=0, b=1):
self._validate = False
self.a = a
self.b = b
self._validate = True

def validate(self):
pass

@parameter("param")
def a(self, val):
return val

@parameter("param")
def b(self, val):
return val

@cached_quantity
def broken(self):
if self.a == 0:
raise ValueError("boom")
return self.a + self.b

@cached_quantity
def dependent(self):
return self.broken + 1


def test_incorrect_argument():
with pytest.raises(TypeError):
hmf.MassFunction(wrong_arg=3)
Expand Down Expand Up @@ -310,6 +340,34 @@ def validate(self):
assert parent.x == 3


def test_failed_cached_quantity_recovers_after_update():
"""A failed quantity access should not block later successful recomputation."""
framework = _BrokenFramework(a=0, b=1)

with pytest.raises(ValueError, match="boom"):
framework.dependent

framework.update(a=2)

assert framework.dependent == 4


def test_failed_cached_quantity_cleanup_is_transactional():
"""A failed quantity access should clean up cache bookkeeping on error."""
framework = _BrokenFramework(a=0, b=1)
initial_recalc = deepcopy(getattr(framework, hidden_loc(framework, "recalc")))
initial_recalc_prpa = deepcopy(getattr(framework, hidden_loc(framework, "recalc_prop_par")))
initial_recalc_papr = deepcopy(getattr(framework, hidden_loc(framework, "recalc_par_prop")))

with pytest.raises(ValueError, match="boom"):
framework.dependent

assert getattr(framework, hidden_loc(framework, "active_q")) == set()
assert getattr(framework, hidden_loc(framework, "recalc")) == initial_recalc
assert getattr(framework, hidden_loc(framework, "recalc_prop_par")) == initial_recalc_prpa
assert getattr(framework, hidden_loc(framework, "recalc_par_prop")) == initial_recalc_papr


def test_get_dependencies_and_parameter_info(capsys):
class Simple(Framework):
def __init__(self):
Expand Down
Loading