Skip to content

Commit 14f2a7f

Browse files
committed
api: add mul interp mode
1 parent 107d256 commit 14f2a7f

10 files changed

Lines changed: 92 additions & 13 deletions

File tree

devito/core/cpu.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ def _normalize_kwargs(cls, **kwargs):
8787

8888
# Code generation options for derivatives
8989
o['expand'] = oo.pop('expand', cls.EXPAND)
90+
o['eval-mul-first'] = oo.pop('eval-mul-first', cls.MUL_FIRST)
9091
o['deriv-collect'] = oo.pop('deriv-collect', cls.DERIV_COLLECT)
9192
o['deriv-schedule'] = oo.pop('deriv-schedule', cls.DERIV_SCHEDULE)
9293
o['deriv-unroll'] = oo.pop('deriv-unroll', False)

devito/core/gpu.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ def _normalize_kwargs(cls, **kwargs):
102102

103103
# Code generation options for derivatives
104104
o['expand'] = oo.pop('expand', cls.EXPAND)
105+
o['eval-mul-first'] = oo.pop('eval-mul-first', cls.MUL_FIRST)
105106
o['deriv-collect'] = oo.pop('deriv-collect', cls.DERIV_COLLECT)
106107
o['deriv-schedule'] = oo.pop('deriv-schedule', cls.DERIV_SCHEDULE)
107108
o['deriv-unroll'] = oo.pop('deriv-unroll', False)

devito/core/operator.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,12 @@ class BasicOperator(Operator):
125125
finite-difference derivatives.
126126
"""
127127

128+
MUL_FIRST = False
129+
"""
130+
When evaluating expressions location, prioritize multiplication
131+
operations.
132+
"""
133+
128134
DERIV_COLLECT = True
129135
"""
130136
Factorize finite-difference derivatives exploiting the linearity of the FD

devito/finite_differences/derivative.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -472,7 +472,7 @@ def T(self):
472472

473473
return self._rebuild(transpose=adjoint)
474474

475-
def _eval_at(self, func):
475+
def _eval_at(self, func, **kwargs):
476476
"""
477477
Evaluates the derivative at the location of `func`. It is necessary for staggered
478478
setup where one could have Eq(u(x + h_x/2), v(x).dx)) in which case v(x).dx
@@ -522,7 +522,7 @@ def _eval_at(self, func):
522522
return self._rebuild(self.expr, **rkw)
523523
args = [self.expr.func(*v) for v in mapper.values()]
524524
args.extend([a for a in self.expr.args if a not in self.expr._args_diff])
525-
args = [self._rebuild(a)._eval_at(func) for a in args]
525+
args = [self._rebuild(a)._eval_at(func, **kwargs) for a in args]
526526
return self.expr.func(*args)
527527
elif self.expr.is_Mul:
528528
# For Mul, We treat the basic case `u(x + h_x/2) * v(x) which is what appear
@@ -593,7 +593,7 @@ def _eval_fd(self, expr, **kwargs):
593593
res = generic_derivative(expr, self.dims[0], self.fd_order[0],
594594
self.deriv_order[0], weights=self.weights,
595595
side=self.side, matvec=self.transpose,
596-
x0=self.x0, expand=expand)
596+
x0=x0_deriv, expand=expand)
597597

598598
# Step 4: Apply substitutions
599599
for e in self._ppsubs:

devito/finite_differences/differentiable.py

Lines changed: 71 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -184,13 +184,18 @@ def coefficients(self):
184184
key = lambda x: coeff_priority.get(x, -1)
185185
return sorted(coefficients, key=key, reverse=True)[0]
186186

187-
def _eval_at(self, func):
187+
def _eval_at(self, func, **kwargs):
188188
if not func.is_Staggered:
189189
# Cartesian grid, do no waste time
190190
return self
191+
<<<<<<< HEAD
191192
return self.func(*[
192193
getattr(a, '_eval_at', lambda x: a)(func) for a in self.args # noqa: B023
193194
]) # false positive
195+
=======
196+
return self.func(*[getattr(a, '_eval_at', lambda x, **kw: a)(func, **kwargs)
197+
for a in self.args])
198+
>>>>>>> 198aedce4 (api: add mul interp mode)
194199

195200
def _subs(self, old, new, **hints):
196201
if old == self:
@@ -501,12 +506,20 @@ def highest_priority(diff_op):
501506
# set of dimensions is used when multiple ones with the same
502507
# priority appear
503508
prio = lambda x: (getattr(x, '_fd_priority', 0), len(x.dimensions))
509+
<<<<<<< HEAD
504510
prio_func = sorted(diff_op._args_diff, key=prio, reverse=True)[0]
505511

506512
# The highest priority must be a Function
507513
if not isinstance(prio_func, AbstractFunction):
508514
return highest_priority(prio_func)
509515
return prio_func
516+
=======
517+
args = DiffOp._args_diff
518+
if not args:
519+
return DiffOp
520+
else:
521+
return sorted(DiffOp._args_diff, key=prio, reverse=True)[0]
522+
>>>>>>> 198aedce4 (api: add mul interp mode)
510523

511524

512525
class DifferentiableOp(Differentiable):
@@ -574,11 +587,16 @@ class DifferentiableFunction(DifferentiableOp):
574587
def __new__(cls, *args, **kwargs):
575588
return cls.__sympy_class__.__new__(cls, *args, **kwargs)
576589

590+
<<<<<<< HEAD
577591
@property
578592
def _fd_priority(self):
579593
if highest_priority(self) is self:
580594
return super()._fd_priority
581595
return highest_priority(self)._fd_priority
596+
=======
597+
def _eval_at(self, func, **kwargs):
598+
return self
599+
>>>>>>> 198aedce4 (api: add mul interp mode)
582600

583601

584602
class Add(DifferentiableOp, sympy.Add):
@@ -669,6 +687,56 @@ def _gather_for_diff(self):
669687
other = self.func(*other)._eval_at(highest_priority(self))
670688
return self.func(other, *derivs)
671689

690+
def _eval_at(self, func, mul_first=False, **kwargs):
691+
# Dont evaluate mul first
692+
if not mul_first:
693+
return super()._eval_at(func, mul_first=mul_first)
694+
695+
# Not a basic a*b*c... expression, just defer to superclass
696+
if any(isinstance(f, DifferentiableOp) for f in self.args):
697+
return super()._eval_at(func, mul_first=mul_first)
698+
699+
# Split Derivative and Differentiable args
700+
derivs, other = split(self.args, lambda e: isinstance(e, sympy.Derivative))
701+
702+
if derivs:
703+
derivs = Differentiable._eval_at(self.func(*derivs), func,
704+
mul_first=mul_first)
705+
else:
706+
derivs = 1
707+
708+
if not other:
709+
return derivs
710+
elif len(other) > 1:
711+
expr = self.func(*other)._gather_for_diff
712+
else:
713+
expr = other[0]
714+
715+
# Non differentiable expr (e.g., number)
716+
if not isinstance(expr, Differentiable):
717+
return self.func(derivs, expr)
718+
719+
# Build mapper for dimensions that need to be interpolated
720+
mapper = {}
721+
for d in self.dimensions:
722+
try:
723+
if self.indices_ref[d] is not func.indices_ref[d]:
724+
mapper[d] = func.indices_ref[d]
725+
except KeyError:
726+
pass
727+
728+
# Nothing to interpolate
729+
if not mapper:
730+
return super()._eval_at(func, mul_first=mul_first)
731+
732+
# Interpolate expr at the required indices
733+
interp = expr.diff(*mapper.keys(), deriv_order=[0 for _ in mapper],
734+
fd_order=[self.interp_order for _ in mapper],
735+
x0=mapper)
736+
737+
# Return the full expression with Derivatives
738+
return self.func(derivs, interp)
739+
672740

673741
class Pow(DifferentiableOp, sympy.Pow):
674742
_fd_priority = 0
@@ -1020,7 +1088,7 @@ def _subs(self, old, new, **hints):
10201088

10211089
class DiffDerivative(IndexDerivative, DifferentiableOp):
10221090

1023-
def _eval_at(self, func):
1091+
def _eval_at(self, func, **kwargs):
10241092
# Like EvalDerivative, a DiffDerivative must have already been evaluated
10251093
# at a valid x0 and should not be re-evaluated at a different location
10261094
return self
@@ -1074,7 +1142,7 @@ def _new_rawargs(self, *args, **kwargs):
10741142
kwargs.pop('is_commutative', None)
10751143
return self.func(*args, **kwargs)
10761144

1077-
def _eval_at(self, func):
1145+
def _eval_at(self, func, **kwargs):
10781146
# An EvalDerivative must have already been evaluated at a valid x0
10791147
# and should not be re-evaluated at a different location
10801148
return self

devito/operator/operator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,7 @@ def _lower_exprs(cls, expressions, **kwargs):
342342
* Shift indices for domain alignment.
343343
"""
344344
expand = kwargs['options'].get('expand', True)
345+
mul_first = kwargs['options'].get('eval-mul-first', False)
345346

346347
# Specialization is performed on unevaluated expressions
347348
expressions = cls._specialize_dsl(expressions, **kwargs)
@@ -352,7 +353,8 @@ def _lower_exprs(cls, expressions, **kwargs):
352353
# ModuloDimensions
353354
if not expand:
354355
expand = lambda d: d.is_Stepping
355-
expressions = flatten([i._evaluate(expand=expand) for i in expressions])
356+
expressions = flatten([i._evaluate(expand=expand, mul_first=mul_first)
357+
for i in expressions])
356358

357359
# Scalarize the tensor equations, if any
358360
expressions = [j for i in expressions for j in i._flatten]

devito/types/dense.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1116,7 +1116,7 @@ def __fd_setup__(self):
11161116

11171117
@cached_property
11181118
def _fd_priority(self):
1119-
return 1 if self.staggered.on_node else 2
1119+
return 1.2 if self.staggered.on_node else 1.1
11201120

11211121
def _eval_at(self, func):
11221122
if self.staggered == func.staggered or self.interp_order == 0:
@@ -1545,7 +1545,7 @@ def __shape_setup__(cls, **kwargs):
15451545

15461546
@cached_property
15471547
def _fd_priority(self):
1548-
return 2.1 if self.staggered.on_node else 2.2
1548+
return 2.1 if self.staggered.on_node else 2
15491549

15501550
@property
15511551
def time_order(self):

devito/types/equation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def _evaluate(self, **kwargs):
110110
"""
111111
try:
112112
lhs = self.lhs._evaluate(**kwargs)
113-
rhs = self.rhs._eval_at(self.lhs)._evaluate(**kwargs)
113+
rhs = self.rhs._eval_at(self.lhs, **kwargs)._evaluate(**kwargs)
114114
except AttributeError:
115115
lhs, rhs = self._evaluate_args(**kwargs)
116116
eq = self.func(lhs, rhs, subdomain=self.subdomain,

devito/types/sparse.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -695,7 +695,7 @@ def _dist_scatter(self, alias=None, data=None):
695695
mapper.update(self._dist_subfunc_scatter(sf))
696696
return mapper
697697

698-
def _eval_at(self, func):
698+
def _eval_at(self, func, **kwargs):
699699
return self
700700

701701
def _halo_exchange(self):

devito/types/tensor.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,12 +170,13 @@ def __getattr__(self, name):
170170
f'{self.__class__!r} object has no attribute {name!r}'
171171
) from e
172172

173-
def _eval_at(self, func):
173+
def _eval_at(self, func, **kwargs):
174174
"""
175175
Evaluate tensor at func location
176176
"""
177177
def entries(i, j, func):
178-
return getattr(self[i, j], '_eval_at', lambda x: self[i, j])(func[i, j])
178+
return getattr(self[i, j], '_eval_at',
179+
lambda x: self[i, j])(func[i, j], **kwargs)
179180
entry = lambda i, j: entries(i, j, func)
180181
return self._new(self.rows, self.cols, entry)
181182

0 commit comments

Comments
 (0)