Skip to content

Commit be046cb

Browse files
committed
api: fix symmetric interp mode and add lots of tests
1 parent 5e25bfb commit be046cb

11 files changed

Lines changed: 1409 additions & 97 deletions

File tree

devito/core/cpu.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def _normalize_kwargs(cls, **kwargs):
8787

8888
# Code generation options for derivatives
8989
o['expand'] = oo.pop('expand', cls.EXPAND)
90-
o['eval-mul-first'] = oo.pop('eval-mul-first', cls.MUL_FIRST)
90+
o['interp-mode'] = oo.pop('interp-mode', cls.INTERP_MODE)
9191
o['deriv-collect'] = oo.pop('deriv-collect', cls.DERIV_COLLECT)
9292
o['deriv-schedule'] = oo.pop('deriv-schedule', cls.DERIV_SCHEDULE)
9393
o['deriv-unroll'] = oo.pop('deriv-unroll', False)

devito/core/gpu.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def _normalize_kwargs(cls, **kwargs):
102102

103103
# Code generation options for derivatives
104104
o['expand'] = oo.pop('expand', cls.EXPAND)
105-
o['eval-mul-first'] = oo.pop('eval-mul-first', cls.MUL_FIRST)
105+
o['interp-mode'] = oo.pop('interp-mode', cls.INTERP_MODE)
106106
o['deriv-collect'] = oo.pop('deriv-collect', cls.DERIV_COLLECT)
107107
o['deriv-schedule'] = oo.pop('deriv-schedule', cls.DERIV_SCHEDULE)
108108
o['deriv-unroll'] = oo.pop('deriv-unroll', False)

devito/core/operator.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -125,10 +125,24 @@ class BasicOperator(Operator):
125125
finite-difference derivatives.
126126
"""
127127

128-
MUL_FIRST = False
128+
INTERP_MODE = 'direct'
129129
"""
130-
When evaluating expressions location, prioritize multiplication
131-
operations.
130+
Interpolation mode used by `Mul._eval_at` when projecting a multi-factor
131+
expression onto a target staggered location:
132+
133+
* `'direct'` (default): each factor is shifted to `func`'s location
134+
independently (`Function._eval_at` per arg). Cheapest stencil; the
135+
mode to pick unless you need an explicitly self-adjoint discretization.
136+
137+
* `'symmetric'`: when every factor lives at a staggering different from
138+
`func`'s, the symmetric form `I * (a * I^T * b)` is built -- all
139+
factors are gathered at the highest-priority "block" location via
140+
`I^T`, multiplied there, and the product is interpolated to `func`
141+
via `I`. Use this for operators whose continuous form decomposes as
142+
`I * A * I^T` (e.g. the elastic stiffness `σ = C ε`).
143+
144+
See `examples/userapi/08_staggered_interp.ipynb` for the maths and a
145+
worked elastic-stiffness example.
132146
"""
133147

134148
DERIV_COLLECT = True

devito/finite_differences/derivative.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,14 @@ class Derivative(sympy.Derivative, Differentiable, Pickable):
8989
evaluation are `x0`, `fd_order` and `side`.
9090
"""
9191

92-
_fd_priority = .9
92+
@cached_property
93+
def _fd_priority(self):
94+
# A Derivative inherits the priority of its underlying expression, so
95+
# that `highest_priority(C*v.dx)` and `highest_priority((C*v).dx)`
96+
# agree on the gather location and the two gathering paths
97+
# (`_gather_for_diff` and `Mul._eval_at(interp_mode='symmetric')`)
98+
# produce consistent answers.
99+
return getattr(self.expr, '_fd_priority', 0)
93100

94101
__rargs__ = ('expr', '*dims')
95102
__rkwargs__ = ('side', 'deriv_order', 'fd_order', 'transpose', '_ppsubs',
@@ -472,7 +479,7 @@ def T(self):
472479

473480
return self._rebuild(transpose=adjoint)
474481

475-
def _eval_at(self, func, mul_first=False, **kwargs):
482+
def _eval_at(self, func, interp_mode='direct', **kwargs):
476483
"""
477484
Evaluates the derivative at the location of `func`. It is necessary for staggered
478485
setup where one could have Eq(u(x + h_x/2), v(x).dx)) in which case v(x).dx
@@ -522,7 +529,7 @@ def _eval_at(self, func, mul_first=False, **kwargs):
522529
return self._rebuild(self.expr, **rkw)
523530
args = [self.expr.func(*v) for v in mapper.values()]
524531
args.extend([a for a in self.expr.args if a not in self.expr._args_diff])
525-
args = [self._rebuild(a)._eval_at(func, mul_first=mul_first, **kwargs)
532+
args = [self._rebuild(a)._eval_at(func, interp_mode=interp_mode, **kwargs)
526533
for a in args]
527534
return self.expr.func(*args)
528535
elif self.expr.is_Mul:
@@ -594,7 +601,7 @@ def _eval_fd(self, expr, **kwargs):
594601
res = generic_derivative(expr, self.dims[0], self.fd_order[0],
595602
self.deriv_order[0], weights=self.weights,
596603
side=self.side, matvec=self.transpose,
597-
x0=x0_deriv, expand=expand)
604+
x0=self.x0, expand=expand)
598605

599606
# Step 4: Apply substitutions
600607
for e in self._ppsubs:

devito/finite_differences/differentiable.py

Lines changed: 118 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from collections import ChainMap
2+
from contextlib import suppress
23
from functools import cached_property, singledispatch
34
from itertools import product
45

@@ -185,8 +186,10 @@ def coefficients(self):
185186
return sorted(coefficients, key=key, reverse=True)[0]
186187

187188
def _eval_at(self, func, **kwargs):
188-
return self.func(*[getattr(a, '_eval_at', lambda x, **kw: a)(func, **kwargs)
189-
for a in self.args])
189+
return self.func(*[
190+
getattr(a, '_eval_at', lambda x, **kw: a)(func, **kwargs) # noqa: B023
191+
for a in self.args # false positive: lambda is invoked in-place
192+
])
190193

191194
def _subs(self, old, new, **hints):
192195
if old == self:
@@ -576,9 +579,6 @@ def _fd_priority(self):
576579
return super()._fd_priority
577580
return highest_priority(self)._fd_priority
578581

579-
def _eval_at(self, func, **kwargs):
580-
return self
581-
582582

583583
class Add(DifferentiableOp, sympy.Add):
584584
__sympy_class__ = sympy.Add
@@ -668,67 +668,63 @@ def _gather_for_diff(self):
668668
other = self.func(*other)._eval_at(highest_priority(self))
669669
return self.func(other, *derivs)
670670

671-
def _eval_at(self, func, mul_first=False, **kwargs):
672-
# Dont evaluate mul first
673-
if not mul_first:
674-
return super()._eval_at(func, mul_first=mul_first)
675-
676-
# Same staggering, no need to interpolate
677-
if self.staggered == func.staggered:
678-
return self
679-
680-
# Get highest priority function for evaluation
681-
func0 = highest_priority(self, ref=func)
682-
683-
# Not a basic a*b*c... expression, expand
684-
if any(isinstance(f, DifferentiableOp) for f in self.args):
685-
return diffify(self._eval_expand_mul())._eval_at(func, mul_first=mul_first)
686-
687-
# Split Derivative and Differentiable args
688-
derivs, other = split(self.args, lambda e: isinstance(e, sympy.Derivative))
689-
690-
# Evaluate all at highest priority function
691-
if derivs:
692-
derivs = self.func(*[d._eval_at(func, mul_first=mul_first) for d in derivs])
693-
else:
694-
derivs = 1
695-
696-
if not other:
697-
return derivs
698-
expr = self.func(*other)
699-
700-
# Non differentiable expr (e.g., number)
701-
if not isinstance(expr, Differentiable):
702-
return self.func(derivs, expr)
703-
704-
# Evaluate expression at func_args
705-
print(f"\nEvaluating expr {expr} at func0 {func0} for func {func} from {self}")
706-
expr = Differentiable._eval_at(expr, func0, mul_first=False)
707-
708-
# Interpolate derivatives at func0
709-
x0 = {d: v for d, v in func0.indices_ref.getters.items()
710-
if not d.is_Time and v is not func.indices_ref.getters.get(d, d)}
711-
if x0 and not derivs == 1:
712-
print(f"Interpolating derivs {derivs} x0={derivs.x0} at {x0}")
713-
derivs = derivs.diff(*x0.keys(), deriv_order=(0,)*len(x0),
714-
fd_order=(self.interp_order,)*len(x0),
715-
x0=x0)
716-
newexpr = self.func(derivs, expr)
717-
718-
# Finally at func
719-
if not func.staggered == func0.staggered:
720-
x0_f = {d: v for d, v in func.indices_ref.getters.items()
721-
if not d.is_Time and v is not func0.indices_ref.getters.get(d)}
722-
if x0_f:
723-
print(f"Final interpolation of derivs {self.func(derivs, expr)} at func {x0_f}")
724-
return newexpr.diff(*x0_f.keys(), deriv_order=(0,)*len(x0_f),
725-
fd_order=(self.interp_order,)*len(x0_f),
726-
x0=x0_f)
671+
def _eval_at(self, func, interp_mode='direct', **kwargs):
672+
"""
673+
Evaluate a Mul at the location of `func`.
674+
675+
Two modes:
676+
677+
- `interp_mode='direct'` (default): per-arg evaluation; each factor is
678+
independently evaluated at `func`'s location via
679+
`Differentiable._eval_at`.
680+
681+
- `interp_mode='symmetric'`: when every Differentiable factor has a
682+
staggering different from `func`'s, apply the `I * (a * I^T * b)`
683+
form:
684+
685+
1. Pick a `block` location -- the highest-priority factor's
686+
staggering (NODE is the highest priority, so coefficient-like
687+
NODE factors win, as in the `I * C * I^T` elastic stiffness
688+
pattern). Each factor not at the block is brought there via
689+
`I^T` (an explicit 0-order FD interpolation operator).
690+
Derivatives additionally set `x0` on their own derivative
691+
dimensions to `func`'s indices.
692+
2. The product is formed at `block`'s location.
693+
3. The whole product is interpolated to `func` via `I` (an
694+
explicit 0-order FD operator).
695+
696+
When the trigger does not hold (e.g. some factor already matches
697+
`func`'s staggering), we fall back to `direct`.
698+
"""
699+
if interp_mode != 'symmetric':
700+
return super()._eval_at(func, **kwargs)
701+
702+
diff_args = [a for a in self.args if isinstance(a, Differentiable)]
703+
other_args = [a for a in self.args if not isinstance(a, Differentiable)]
704+
705+
# Symmetric form requires every Differentiable factor to differ from
706+
# func; otherwise direct evaluation is cleaner and equivalent.
707+
if len(diff_args) < 2 or \
708+
any(a.staggered == func.staggered for a in diff_args):
709+
return super()._eval_at(func, **kwargs)
710+
711+
block_indices = highest_priority(self).indices_ref
712+
713+
# Bring each factor to block's location (I^T where needed)
714+
new_factors = list(other_args)
715+
for a in diff_args:
716+
if isinstance(a, sympy.Derivative):
717+
source = _post_x0_indices(a, func)
718+
a = a._rebuild(x0={dim: func.indices_ref[dim] for dim in a.dims
719+
if dim in func.indices_ref.getters})
727720
else:
728-
return newexpr
729-
else:
730-
# Return the full expression with Derivatives
731-
return newexpr
721+
source = a.indices_ref
722+
new_factors.append(_interp_at(a, source, block_indices,
723+
self.interp_order))
724+
725+
# Final I from block's location to func
726+
return _interp_at(self.func(*new_factors), block_indices,
727+
func.indices_ref, self.interp_order)
732728

733729

734730
class Pow(DifferentiableOp, sympy.Pow):
@@ -1255,6 +1251,63 @@ def _diff2sympy(obj):
12551251
evalf_table[Pow] = evalf_table[sympy.Pow]
12561252

12571253

1254+
def _interp_mapper(source, target, dims):
1255+
"""
1256+
Build a `{dim: target_index}` mapper for dimensions in `dims` where
1257+
`source[dim]` differs from `target[dim]`.
1258+
1259+
`source` and `target` are dict-like `{dim: index_expr}` (e.g. a plain
1260+
dict or a `DimensionTuple`). Dimensions missing from either side are
1261+
skipped silently.
1262+
"""
1263+
mapper = {}
1264+
for d in dims:
1265+
try:
1266+
s = source[d]
1267+
t = target[d]
1268+
except (KeyError, IndexError):
1269+
continue
1270+
if s is not t:
1271+
mapper[d] = t
1272+
return mapper
1273+
1274+
1275+
def _interp_at(expr, source, target, interp_order):
1276+
"""
1277+
Build a symbolic 0-order FD interpolation operator on `expr` that maps
1278+
values from `source` indices to `target` indices, only on the
1279+
dimensions where the two locations differ.
1280+
"""
1281+
if not isinstance(expr, Differentiable):
1282+
return expr
1283+
1284+
mapper = _interp_mapper(source, target, expr.dimensions)
1285+
if not mapper:
1286+
return expr
1287+
1288+
return expr.diff(*mapper.keys(),
1289+
deriv_order=(0,) * len(mapper),
1290+
fd_order=(interp_order,) * len(mapper),
1291+
x0=mapper)
1292+
1293+
1294+
def _post_x0_indices(deriv, func):
1295+
"""
1296+
Conceptual indices of `deriv` after setting `x0` on its own derivative
1297+
dimensions to `func`'s indices. Derivative dims take `func`'s indices;
1298+
other dims keep the underlying expression's natural location (so that
1299+
`interp_for_fd` does not introduce a spurious second shift).
1300+
"""
1301+
ref = {}
1302+
for dim in deriv.dimensions:
1303+
if dim in deriv.dims and dim in func.indices_ref.getters:
1304+
ref[dim] = func.indices_ref[dim]
1305+
else:
1306+
with suppress(KeyError):
1307+
ref[dim] = deriv.indices_ref[dim]
1308+
return ref
1309+
1310+
12581311
# Interpolation for finite differences
12591312
@singledispatch
12601313
def interp_for_fd(expr, x0, **kwargs):

devito/operator/operator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,7 @@ def _lower_exprs(cls, expressions, **kwargs):
341341
* Shift indices for domain alignment.
342342
"""
343343
expand = kwargs['options'].get('expand', True)
344-
mul_first = kwargs['options'].get('eval-mul-first', False)
344+
interp_mode = kwargs['options'].get('interp-mode', 'direct')
345345

346346
# Specialization is performed on unevaluated expressions
347347
expressions = cls._specialize_dsl(expressions, **kwargs)
@@ -352,7 +352,7 @@ def _lower_exprs(cls, expressions, **kwargs):
352352
# ModuloDimensions
353353
if not expand:
354354
expand = lambda d: d.is_Stepping
355-
expressions = flatten([i._evaluate(expand=expand, mul_first=mul_first)
355+
expressions = flatten([i._evaluate(expand=expand, interp_mode=interp_mode)
356356
for i in expressions])
357357

358358
# Scalarize the tensor equations, if any

devito/types/dense.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from devito.deprecations import deprecations
1616
from devito.exceptions import InvalidArgument
1717
from devito.finite_differences import Differentiable, generate_fd_shortcuts
18+
from devito.finite_differences.differentiable import _interp_mapper
1819
from devito.finite_differences.tools import fd_weights_registry
1920
from devito.logger import debug, warning
2021
from devito.mpi import MPI
@@ -1116,22 +1117,24 @@ def __fd_setup__(self):
11161117

11171118
@cached_property
11181119
def _fd_priority(self):
1120+
# NODE takes precedence: coefficients are conventionally stored at the
1121+
# cell centre, so when we gather a product onto a single location
1122+
# (either via _gather_for_diff or symmetric Mul._eval_at), NODE is the
1123+
# natural one to pick.
11191124
return 1.2 if self.staggered.on_node else 1.1
11201125

11211126
def _eval_at(self, func, **kwargs):
11221127
if self.staggered == func.staggered or self.interp_order == 0:
11231128
return self
11241129

1125-
mapper = {}
1126-
for d in self.dimensions:
1127-
try:
1128-
if self.indices_ref[d] is not func.indices_ref[d]:
1129-
f_idx = func.indices_ref[d]._subs(func.dimensions[d], d)
1130-
mapper[self.indices_ref[d]] = f_idx
1131-
except KeyError:
1132-
pass
1130+
# Dims where self and func indices differ -> {dim: func_idx}
1131+
diff = _interp_mapper(self.indices_ref, func.indices_ref, self.dimensions)
1132+
1133+
# Translate into a subs mapper {self_idx: func_idx} aligned on self's dims
1134+
subs_map = {self.indices_ref[d]: t._subs(func.dimensions[d], d)
1135+
for d, t in diff.items()}
11331136

1134-
return self.subs(mapper)
1137+
return self.subs(subs_map)
11351138

11361139
@classmethod
11371140
def __staggered_setup__(cls, dimensions, staggered=None, **kwargs):

devito/types/utils.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -58,12 +58,14 @@ def on_node(self):
5858
return not self or all(s == 0 for s in self)
5959

6060
def __eq__(self, other):
61-
if not isinstance(other, Staggering):
62-
return False
63-
all_same = self and other and all(a == b for a, b in zip(self, other))
64-
all_node = self.on_node and other.on_node
61+
# Two empty-or-all-zero Staggerings are equivalent regardless of arity
62+
# (a Function declared with `staggered=NODE` and one declared without
63+
# both live at the cell centre).
64+
if isinstance(other, Staggering) and self.on_node and other.on_node:
65+
return True
66+
return tuple.__eq__(self, other)
6567

66-
return all_same or all_node
68+
__hash__ = DimensionTuple.__hash__
6769

6870
@property
6971
def _ref(self):
@@ -74,8 +76,6 @@ def _ref(self):
7476
else:
7577
return tuple(d for d, s in zip(self.getters, self, strict=True) if s == 1)
7678

77-
__hash__ = DimensionTuple.__hash__
78-
7979

8080
class IgnoreDimSort(tuple):
8181
"""A tuple subclass used to wrap the implicit_dims to indicate

0 commit comments

Comments
 (0)