@@ -184,13 +184,18 @@ def coefficients(self):
184184 key = lambda x : coeff_priority .get (x , - 1 )
185185 return sorted (coefficients , key = key , reverse = True )[0 ]
186186
187- def _eval_at (self , func ):
187+ def _eval_at (self , func , ** kwargs ):
188188 if not func .is_Staggered :
189189 # Cartesian grid, do no waste time
190190 return self
191+ < << << << HEAD
191192 return self .func (* [
192193 getattr (a , '_eval_at' , lambda x : a )(func ) for a in self .args # noqa: B023
193194 ]) # false positive
195+ == == == =
196+ return self .func (* [getattr (a , '_eval_at' , lambda x , ** kw : a )(func , ** kwargs )
197+ for a in self .args ])
198+ >> >> >> > 198 aedce4 (api : add mul interp mode )
194199
195200 def _subs (self , old , new , ** hints ):
196201 if old == self :
@@ -501,12 +506,20 @@ def highest_priority(diff_op):
501506 # set of dimensions is used when multiple ones with the same
502507 # priority appear
503508 prio = lambda x : (getattr (x , '_fd_priority' , 0 ), len (x .dimensions ))
509+ < << << << HEAD
504510 prio_func = sorted (diff_op ._args_diff , key = prio , reverse = True )[0 ]
505511
506512 # The highest priority must be a Function
507513 if not isinstance (prio_func , AbstractFunction ):
508514 return highest_priority (prio_func )
509515 return prio_func
516+ == == == =
517+ args = DiffOp ._args_diff
518+ if not args :
519+ return DiffOp
520+ else :
521+ return sorted (DiffOp ._args_diff , key = prio , reverse = True )[0 ]
522+ >> >> >> > 198 aedce4 (api : add mul interp mode )
510523
511524
512525class DifferentiableOp (Differentiable ):
@@ -574,11 +587,16 @@ class DifferentiableFunction(DifferentiableOp):
574587 def __new__ (cls , * args , ** kwargs ):
575588 return cls .__sympy_class__ .__new__ (cls , * args , ** kwargs )
576589
590+ < << << << HEAD
577591 @property
578592 def _fd_priority (self ):
579593 if highest_priority (self ) is self :
580594 return super ()._fd_priority
581595 return highest_priority (self )._fd_priority
596+ == == == =
597+ def _eval_at (self , func , ** kwargs ):
598+ return self
599+ >> >> >> > 198 aedce4 (api : add mul interp mode )
582600
583601
584602class Add (DifferentiableOp , sympy .Add ):
@@ -669,6 +687,56 @@ def _gather_for_diff(self):
669687 other = self .func (* other )._eval_at (highest_priority (self ))
670688 return self .func (other , * derivs )
671689
690+ def _eval_at (self , func , mul_first = False , ** kwargs ):
691+ # Dont evaluate mul first
692+ if not mul_first :
693+ return super ()._eval_at (func , mul_first = mul_first )
694+
695+ # Not a basic a*b*c... expression, just defer to superclass
696+ if any (isinstance (f , DifferentiableOp ) for f in self .args ):
697+ return super ()._eval_at (func , mul_first = mul_first )
698+
699+ # Split Derivative and Differentiable args
700+ derivs , other = split (self .args , lambda e : isinstance (e , sympy .Derivative ))
701+
702+ if derivs :
703+ derivs = Differentiable ._eval_at (self .func (* derivs ), func ,
704+ mul_first = mul_first )
705+ else :
706+ derivs = 1
707+
708+ if not other :
709+ return derivs
710+ elif len (other ) > 1 :
711+ expr = self .func (* other )._gather_for_diff
712+ else :
713+ expr = other [0 ]
714+
715+ # Non differentiable expr (e.g., number)
716+ if not isinstance (expr , Differentiable ):
717+ return self .func (derivs , expr )
718+
719+ # Build mapper for dimensions that need to be interpolated
720+ mapper = {}
721+ for d in self .dimensions :
722+ try :
723+ if self .indices_ref [d ] is not func .indices_ref [d ]:
724+ mapper [d ] = func .indices_ref [d ]
725+ except KeyError :
726+ pass
727+
728+ # Nothing to interpolate
729+ if not mapper :
730+ return super ()._eval_at (func , mul_first = mul_first )
731+
732+ # Interpolate expr at the required indices
733+ interp = expr .diff (* mapper .keys (), deriv_order = [0 for _ in mapper ],
734+ fd_order = [self .interp_order for _ in mapper ],
735+ x0 = mapper )
736+
737+ # Return the full expression with Derivatives
738+ return self .func (derivs , interp )
739+
672740
673741class Pow (DifferentiableOp , sympy .Pow ):
674742 _fd_priority = 0
@@ -1020,7 +1088,7 @@ def _subs(self, old, new, **hints):
10201088
10211089class DiffDerivative (IndexDerivative , DifferentiableOp ):
10221090
1023- def _eval_at (self , func ):
1091+ def _eval_at (self , func , ** kwargs ):
10241092 # Like EvalDerivative, a DiffDerivative must have already been evaluated
10251093 # at a valid x0 and should not be re-evaluated at a different location
10261094 return self
@@ -1074,7 +1142,7 @@ def _new_rawargs(self, *args, **kwargs):
10741142 kwargs .pop ('is_commutative' , None )
10751143 return self .func (* args , ** kwargs )
10761144
1077- def _eval_at (self , func ):
1145+ def _eval_at (self , func , ** kwargs ):
10781146 # An EvalDerivative must have already been evaluated at a valid x0
10791147 # and should not be re-evaluated at a different location
10801148 return self
0 commit comments