Skip to content

Commit e74e828

Browse files
committed
tweak mul mode
1 parent 14f2a7f commit e74e828

3 files changed

Lines changed: 67 additions & 46 deletions

File tree

devito/finite_differences/derivative.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ class Derivative(sympy.Derivative, Differentiable, Pickable):
8989
evaluation are `x0`, `fd_order` and `side`.
9090
"""
9191

92-
_fd_priority = 3
92+
_fd_priority = .9
9393

9494
__rargs__ = ('expr', '*dims')
9595
__rkwargs__ = ('side', 'deriv_order', 'fd_order', 'transpose', '_ppsubs',
@@ -472,7 +472,7 @@ def T(self):
472472

473473
return self._rebuild(transpose=adjoint)
474474

475-
def _eval_at(self, func, **kwargs):
475+
def _eval_at(self, func, mul_first=False, **kwargs):
476476
"""
477477
Evaluates the derivative at the location of `func`. It is necessary for staggered
478478
setup where one could have Eq(u(x + h_x/2), v(x).dx)) in which case v(x).dx
@@ -522,7 +522,8 @@ def _eval_at(self, func, **kwargs):
522522
return self._rebuild(self.expr, **rkw)
523523
args = [self.expr.func(*v) for v in mapper.values()]
524524
args.extend([a for a in self.expr.args if a not in self.expr._args_diff])
525-
args = [self._rebuild(a)._eval_at(func, **kwargs) for a in args]
525+
args = [self._rebuild(a)._eval_at(func, mul_first=mul_first, **kwargs)
526+
for a in args]
526527
return self.expr.func(*args)
527528
elif self.expr.is_Mul:
528529
# For Mul, We treat the basic case `u(x + h_x/2) * v(x) which is what appear

devito/finite_differences/differentiable.py

Lines changed: 54 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -185,17 +185,11 @@ def coefficients(self):
185185
return sorted(coefficients, key=key, reverse=True)[0]
186186

187187
def _eval_at(self, func, **kwargs):
188-
if not func.is_Staggered:
188+
if not func.is_Staggered and not self.is_Staggered:
189189
# Cartesian grid, do no waste time
190190
return self
191-
<<<<<<< HEAD
192-
return self.func(*[
193-
getattr(a, '_eval_at', lambda x: a)(func) for a in self.args # noqa: B023
194-
]) # false positive
195-
=======
196191
return self.func(*[getattr(a, '_eval_at', lambda x, **kw: a)(func, **kwargs)
197192
for a in self.args])
198-
>>>>>>> 198aedce4 (api: add mul interp mode)
199193

200194
def _subs(self, old, new, **hints):
201195
if old == self:
@@ -497,14 +491,19 @@ def has_free(self, *patterns):
497491
return all(i in self.free_symbols for i in patterns)
498492

499493

494+
<<<<<<< HEAD
500495
def highest_priority(diff_op):
501496
if not diff_op._args_diff:
502497
return diff_op
503498

499+
=======
500+
def highest_priority(DiffOp, ref=None):
501+
>>>>>>> c34bae6f0 (tweak mul mode)
504502
# We want to get the object with highest priority
505503
# We also need to make sure that the object with the largest
506504
# set of dimensions is used when multiple ones with the same
507505
# priority appear
506+
<<<<<<< HEAD
508507
prio = lambda x: (getattr(x, '_fd_priority', 0), len(x.dimensions))
509508
<<<<<<< HEAD
510509
prio_func = sorted(diff_op._args_diff, key=prio, reverse=True)[0]
@@ -514,6 +513,14 @@ def highest_priority(diff_op):
514513
return highest_priority(prio_func)
515514
return prio_func
516515
=======
516+
=======
517+
def stagg(x):
518+
try:
519+
return int(x.staggered == ref.staggered)
520+
except AttributeError:
521+
return 0
522+
prio = lambda x: (stagg(x), getattr(x, '_fd_priority', 0), len(x.dimensions))
523+
>>>>>>> c34bae6f0 (tweak mul mode)
517524
args = DiffOp._args_diff
518525
if not args:
519526
return DiffOp
@@ -692,50 +699,62 @@ def _eval_at(self, func, mul_first=False, **kwargs):
692699
if not mul_first:
693700
return super()._eval_at(func, mul_first=mul_first)
694701

695-
# Not a basic a*b*c... expression, just defer to superclass
702+
# Same staggering, no need to interpolate
703+
if self.staggered == func.staggered:
704+
return self
705+
706+
# Get highest priority function for evaluation
707+
func0 = highest_priority(self, ref=func)
708+
709+
# Not a basic a*b*c... expression, expand
696710
if any(isinstance(f, DifferentiableOp) for f in self.args):
697-
return super()._eval_at(func, mul_first=mul_first)
711+
return diffify(self._eval_expand_mul())._eval_at(func, mul_first=mul_first)
698712

699713
# Split Derivative and Differentiable args
700714
derivs, other = split(self.args, lambda e: isinstance(e, sympy.Derivative))
701715

716+
# Evaluate all at highest priority function
702717
if derivs:
703-
derivs = Differentiable._eval_at(self.func(*derivs), func,
704-
mul_first=mul_first)
718+
derivs = self.func(*[d._eval_at(func, mul_first=mul_first) for d in derivs])
705719
else:
706720
derivs = 1
707721

708722
if not other:
709723
return derivs
710-
elif len(other) > 1:
711-
expr = self.func(*other)._gather_for_diff
712-
else:
713-
expr = other[0]
724+
expr = self.func(*other)
714725

715726
# Non differentiable expr (e.g., number)
716727
if not isinstance(expr, Differentiable):
717728
return self.func(derivs, expr)
718729

719-
# Build mapper for dimensions that need to be interpolated
720-
mapper = {}
721-
for d in self.dimensions:
722-
try:
723-
if self.indices_ref[d] is not func.indices_ref[d]:
724-
mapper[d] = func.indices_ref[d]
725-
except KeyError:
726-
pass
727-
728-
# Nothing to interpolate
729-
if not mapper:
730-
return super()._eval_at(func, mul_first=mul_first)
731-
732-
# Interpolate expr at the required indices
733-
interp = expr.diff(*mapper.keys(), deriv_order=[0 for _ in mapper],
734-
fd_order=[self.interp_order for _ in mapper],
735-
x0=mapper)
736-
737-
# Return the full expression with Derivatives
738-
return self.func(derivs, interp)
730+
# Evaluate expression at func_args
731+
print(f"\nEvaluating expr {expr} at func0 {func0} for func {func} from {self}")
732+
expr = Differentiable._eval_at(expr, func0, mul_first=False)
733+
734+
# Interpolate derivatives at func0
735+
x0 = {d: v for d, v in func0.indices_ref.getters.items()
736+
if not d.is_Time and v is not func.indices_ref.getters.get(d, d)}
737+
if x0 and not derivs == 1:
738+
print(f"Interpolating derivs {derivs} x0={derivs.x0} at {x0}")
739+
derivs = derivs.diff(*x0.keys(), deriv_order=(0,)*len(x0),
740+
fd_order=(self.interp_order,)*len(x0),
741+
x0=x0)
742+
newexpr = self.func(derivs, expr)
743+
744+
# Finally at func
745+
if not func.staggered == func0.staggered:
746+
x0_f = {d: v for d, v in func.indices_ref.getters.items()
747+
if not d.is_Time and v is not func0.indices_ref.getters.get(d)}
748+
if x0_f:
749+
print(f"Final interpolation of derivs {self.func(derivs, expr)} at func {x0_f}")
750+
return newexpr.diff(*x0_f.keys(), deriv_order=(0,)*len(x0_f),
751+
fd_order=(self.interp_order,)*len(x0_f),
752+
x0=x0_f)
753+
else:
754+
return newexpr
755+
else:
756+
# Return the full expression with Derivatives
757+
return newexpr
739758

740759

741760
class Pow(DifferentiableOp, sympy.Pow):

devito/types/utils.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -57,14 +57,15 @@ class Staggering(DimensionTuple):
5757
def on_node(self):
5858
return not self or all(s == 0 for s in self)
5959

60-
@property
61-
def _ref(self):
62-
if not self:
63-
return None
64-
elif self.on_node:
65-
return NODE
66-
else:
67-
return tuple(d for d, s in zip(self.getters, self, strict=True) if s == 1)
60+
def __eq__(self, other):
61+
if not isinstance(other, Staggering):
62+
return False
63+
all_same = self and other and all(a == b for a, b in zip(self, other))
64+
all_node = self.on_node and other.on_node
65+
66+
return all_same or all_node
67+
68+
__hash__ = DimensionTuple.__hash__
6869

6970

7071
class IgnoreDimSort(tuple):

0 commit comments

Comments
 (0)