|
1 | 1 | from collections import ChainMap |
| 2 | +from contextlib import suppress |
2 | 3 | from functools import cached_property, singledispatch |
3 | 4 | from itertools import product |
4 | 5 |
|
@@ -185,8 +186,10 @@ def coefficients(self): |
185 | 186 | return sorted(coefficients, key=key, reverse=True)[0] |
186 | 187 |
|
187 | 188 | def _eval_at(self, func, **kwargs): |
188 | | - return self.func(*[getattr(a, '_eval_at', lambda x, **kw: a)(func, **kwargs) |
189 | | - for a in self.args]) |
| 189 | + return self.func(*[ |
| 190 | + getattr(a, '_eval_at', lambda x, **kw: a)(func, **kwargs) # noqa: B023 |
| 191 | + for a in self.args # false positive: lambda is invoked in-place |
| 192 | + ]) |
190 | 193 |
|
191 | 194 | def _subs(self, old, new, **hints): |
192 | 195 | if old == self: |
@@ -576,9 +579,6 @@ def _fd_priority(self): |
576 | 579 | return super()._fd_priority |
577 | 580 | return highest_priority(self)._fd_priority |
578 | 581 |
|
579 | | - def _eval_at(self, func, **kwargs): |
580 | | - return self |
581 | | - |
582 | 582 |
|
583 | 583 | class Add(DifferentiableOp, sympy.Add): |
584 | 584 | __sympy_class__ = sympy.Add |
@@ -668,67 +668,63 @@ def _gather_for_diff(self): |
668 | 668 | other = self.func(*other)._eval_at(highest_priority(self)) |
669 | 669 | return self.func(other, *derivs) |
670 | 670 |
|
671 | | - def _eval_at(self, func, mul_first=False, **kwargs): |
672 | | - # Dont evaluate mul first |
673 | | - if not mul_first: |
674 | | - return super()._eval_at(func, mul_first=mul_first) |
675 | | - |
676 | | - # Same staggering, no need to interpolate |
677 | | - if self.staggered == func.staggered: |
678 | | - return self |
679 | | - |
680 | | - # Get highest priority function for evaluation |
681 | | - func0 = highest_priority(self, ref=func) |
682 | | - |
683 | | - # Not a basic a*b*c... expression, expand |
684 | | - if any(isinstance(f, DifferentiableOp) for f in self.args): |
685 | | - return diffify(self._eval_expand_mul())._eval_at(func, mul_first=mul_first) |
686 | | - |
687 | | - # Split Derivative and Differentiable args |
688 | | - derivs, other = split(self.args, lambda e: isinstance(e, sympy.Derivative)) |
689 | | - |
690 | | - # Evaluate all at highest priority function |
691 | | - if derivs: |
692 | | - derivs = self.func(*[d._eval_at(func, mul_first=mul_first) for d in derivs]) |
693 | | - else: |
694 | | - derivs = 1 |
695 | | - |
696 | | - if not other: |
697 | | - return derivs |
698 | | - expr = self.func(*other) |
699 | | - |
700 | | - # Non differentiable expr (e.g., number) |
701 | | - if not isinstance(expr, Differentiable): |
702 | | - return self.func(derivs, expr) |
703 | | - |
704 | | - # Evaluate expression at func_args |
705 | | - print(f"\nEvaluating expr {expr} at func0 {func0} for func {func} from {self}") |
706 | | - expr = Differentiable._eval_at(expr, func0, mul_first=False) |
707 | | - |
708 | | - # Interpolate derivatives at func0 |
709 | | - x0 = {d: v for d, v in func0.indices_ref.getters.items() |
710 | | - if not d.is_Time and v is not func.indices_ref.getters.get(d, d)} |
711 | | - if x0 and not derivs == 1: |
712 | | - print(f"Interpolating derivs {derivs} x0={derivs.x0} at {x0}") |
713 | | - derivs = derivs.diff(*x0.keys(), deriv_order=(0,)*len(x0), |
714 | | - fd_order=(self.interp_order,)*len(x0), |
715 | | - x0=x0) |
716 | | - newexpr = self.func(derivs, expr) |
717 | | - |
718 | | - # Finally at func |
719 | | - if not func.staggered == func0.staggered: |
720 | | - x0_f = {d: v for d, v in func.indices_ref.getters.items() |
721 | | - if not d.is_Time and v is not func0.indices_ref.getters.get(d)} |
722 | | - if x0_f: |
723 | | - print(f"Final interpolation of derivs {self.func(derivs, expr)} at func {x0_f}") |
724 | | - return newexpr.diff(*x0_f.keys(), deriv_order=(0,)*len(x0_f), |
725 | | - fd_order=(self.interp_order,)*len(x0_f), |
726 | | - x0=x0_f) |
| 671 | + def _eval_at(self, func, interp_mode='direct', **kwargs): |
| 672 | + """ |
| 673 | + Evaluate a Mul at the location of `func`. |
| 674 | +
|
| 675 | + Two modes: |
| 676 | +
|
| 677 | + - `interp_mode='direct'` (default): per-arg evaluation; each factor is |
| 678 | + independently evaluated at `func`'s location via |
| 679 | + `Differentiable._eval_at`. |
| 680 | +
|
| 681 | + - `interp_mode='symmetric'`: when every Differentiable factor has a |
| 682 | + staggering different from `func`'s, apply the `I * (a * I^T * b)` |
| 683 | + form: |
| 684 | +
|
| 685 | + 1. Pick a `block` location -- the highest-priority factor's |
| 686 | + staggering (NODE is the highest priority, so coefficient-like |
| 687 | + NODE factors win, as in the `I * C * I^T` elastic stiffness |
| 688 | + pattern). Each factor not at the block is brought there via |
| 689 | + `I^T` (an explicit 0-order FD interpolation operator). |
| 690 | + Derivatives additionally set `x0` on their own derivative |
| 691 | + dimensions to `func`'s indices. |
| 692 | + 2. The product is formed at `block`'s location. |
| 693 | + 3. The whole product is interpolated to `func` via `I` (an |
| 694 | + explicit 0-order FD operator). |
| 695 | +
|
| 696 | + When the trigger does not hold (e.g. some factor already matches |
| 697 | + `func`'s staggering), we fall back to `direct`. |
| 698 | + """ |
| 699 | + if interp_mode != 'symmetric': |
| 700 | + return super()._eval_at(func, **kwargs) |
| 701 | + |
| 702 | + diff_args = [a for a in self.args if isinstance(a, Differentiable)] |
| 703 | + other_args = [a for a in self.args if not isinstance(a, Differentiable)] |
| 704 | + |
| 705 | + # Symmetric form requires every Differentiable factor to differ from |
| 706 | + # func; otherwise direct evaluation is cleaner and equivalent. |
| 707 | + if len(diff_args) < 2 or \ |
| 708 | + any(a.staggered == func.staggered for a in diff_args): |
| 709 | + return super()._eval_at(func, **kwargs) |
| 710 | + |
| 711 | + block_indices = highest_priority(self).indices_ref |
| 712 | + |
| 713 | + # Bring each factor to block's location (I^T where needed) |
| 714 | + new_factors = list(other_args) |
| 715 | + for a in diff_args: |
| 716 | + if isinstance(a, sympy.Derivative): |
| 717 | + source = _post_x0_indices(a, func) |
| 718 | + a = a._rebuild(x0={dim: func.indices_ref[dim] for dim in a.dims |
| 719 | + if dim in func.indices_ref.getters}) |
727 | 720 | else: |
728 | | - return newexpr |
729 | | - else: |
730 | | - # Return the full expression with Derivatives |
731 | | - return newexpr |
| 721 | + source = a.indices_ref |
| 722 | + new_factors.append(_interp_at(a, source, block_indices, |
| 723 | + self.interp_order)) |
| 724 | + |
| 725 | + # Final I from block's location to func |
| 726 | + return _interp_at(self.func(*new_factors), block_indices, |
| 727 | + func.indices_ref, self.interp_order) |
732 | 728 |
|
733 | 729 |
|
734 | 730 | class Pow(DifferentiableOp, sympy.Pow): |
@@ -1255,6 +1251,63 @@ def _diff2sympy(obj): |
1255 | 1251 | evalf_table[Pow] = evalf_table[sympy.Pow] |
1256 | 1252 |
|
1257 | 1253 |
|
| 1254 | +def _interp_mapper(source, target, dims): |
| 1255 | + """ |
| 1256 | + Build a `{dim: target_index}` mapper for dimensions in `dims` where |
| 1257 | + `source[dim]` differs from `target[dim]`. |
| 1258 | +
|
| 1259 | + `source` and `target` are dict-like `{dim: index_expr}` (e.g. a plain |
| 1260 | + dict or a `DimensionTuple`). Dimensions missing from either side are |
| 1261 | + skipped silently. |
| 1262 | + """ |
| 1263 | + mapper = {} |
| 1264 | + for d in dims: |
| 1265 | + try: |
| 1266 | + s = source[d] |
| 1267 | + t = target[d] |
| 1268 | + except (KeyError, IndexError): |
| 1269 | + continue |
| 1270 | + if s is not t: |
| 1271 | + mapper[d] = t |
| 1272 | + return mapper |
| 1273 | + |
| 1274 | + |
| 1275 | +def _interp_at(expr, source, target, interp_order): |
| 1276 | + """ |
| 1277 | + Build a symbolic 0-order FD interpolation operator on `expr` that maps |
| 1278 | + values from `source` indices to `target` indices, only on the |
| 1279 | + dimensions where the two locations differ. |
| 1280 | + """ |
| 1281 | + if not isinstance(expr, Differentiable): |
| 1282 | + return expr |
| 1283 | + |
| 1284 | + mapper = _interp_mapper(source, target, expr.dimensions) |
| 1285 | + if not mapper: |
| 1286 | + return expr |
| 1287 | + |
| 1288 | + return expr.diff(*mapper.keys(), |
| 1289 | + deriv_order=(0,) * len(mapper), |
| 1290 | + fd_order=(interp_order,) * len(mapper), |
| 1291 | + x0=mapper) |
| 1292 | + |
| 1293 | + |
| 1294 | +def _post_x0_indices(deriv, func): |
| 1295 | + """ |
| 1296 | + Conceptual indices of `deriv` after setting `x0` on its own derivative |
| 1297 | + dimensions to `func`'s indices. Derivative dims take `func`'s indices; |
| 1298 | + other dims keep the underlying expression's natural location (so that |
| 1299 | + `interp_for_fd` does not introduce a spurious second shift). |
| 1300 | + """ |
| 1301 | + ref = {} |
| 1302 | + for dim in deriv.dimensions: |
| 1303 | + if dim in deriv.dims and dim in func.indices_ref.getters: |
| 1304 | + ref[dim] = func.indices_ref[dim] |
| 1305 | + else: |
| 1306 | + with suppress(KeyError): |
| 1307 | + ref[dim] = deriv.indices_ref[dim] |
| 1308 | + return ref |
| 1309 | + |
| 1310 | + |
1258 | 1311 | # Interpolation for finite differences |
1259 | 1312 | @singledispatch |
1260 | 1313 | def interp_for_fd(expr, x0, **kwargs): |
|
0 commit comments