@@ -185,17 +185,11 @@ def coefficients(self):
185185 return sorted (coefficients , key = key , reverse = True )[0 ]
186186
187187 def _eval_at (self , func , ** kwargs ):
188- if not func .is_Staggered :
188+ if not func .is_Staggered and not self . is_Staggered :
189189 # Cartesian grid, do no waste time
190190 return self
191- < << << << HEAD
192- return self .func (* [
193- getattr (a , '_eval_at' , lambda x : a )(func ) for a in self .args # noqa: B023
194- ]) # false positive
195- == == == =
196191 return self .func (* [getattr (a , '_eval_at' , lambda x , ** kw : a )(func , ** kwargs )
197192 for a in self .args ])
198- >> >> >> > 198 aedce4 (api : add mul interp mode )
199193
200194 def _subs (self , old , new , ** hints ):
201195 if old == self :
@@ -497,14 +491,19 @@ def has_free(self, *patterns):
497491 return all (i in self .free_symbols for i in patterns )
498492
499493
494+ < << << << HEAD
500495def highest_priority (diff_op ):
501496 if not diff_op ._args_diff :
502497 return diff_op
503498
499+ == == == =
500+ def highest_priority (DiffOp , ref = None ):
501+ > >> >> >> c34bae6f0 (tweak mul mode )
504502 # We want to get the object with highest priority
505503 # We also need to make sure that the object with the largest
506504 # set of dimensions is used when multiple ones with the same
507505 # priority appear
506+ < << << << HEAD
508507 prio = lambda x : (getattr (x , '_fd_priority' , 0 ), len (x .dimensions ))
509508< << << << HEAD
510509 prio_func = sorted (diff_op ._args_diff , key = prio , reverse = True )[0 ]
@@ -514,6 +513,14 @@ def highest_priority(diff_op):
514513 return highest_priority (prio_func )
515514 return prio_func
516515== == == =
516+ == == == =
517+ def stagg (x ):
518+ try :
519+ return int (x .staggered == ref .staggered )
520+ except AttributeError :
521+ return 0
522+ prio = lambda x : (stagg (x ), getattr (x , '_fd_priority' , 0 ), len (x .dimensions ))
523+ >> >> >> > c34bae6f0 (tweak mul mode )
517524 args = DiffOp ._args_diff
518525 if not args :
519526 return DiffOp
@@ -692,50 +699,62 @@ def _eval_at(self, func, mul_first=False, **kwargs):
692699 if not mul_first :
693700 return super ()._eval_at (func , mul_first = mul_first )
694701
695- # Not a basic a*b*c... expression, just defer to superclass
702+ # Same staggering, no need to interpolate
703+ if self .staggered == func .staggered :
704+ return self
705+
706+ # Get highest priority function for evaluation
707+ func0 = highest_priority (self , ref = func )
708+
709+ # Not a basic a*b*c... expression, expand
696710 if any (isinstance (f , DifferentiableOp ) for f in self .args ):
697- return super ( )._eval_at (func , mul_first = mul_first )
711+ return diffify ( self . _eval_expand_mul () )._eval_at (func , mul_first = mul_first )
698712
699713 # Split Derivative and Differentiable args
700714 derivs , other = split (self .args , lambda e : isinstance (e , sympy .Derivative ))
701715
716+ # Evaluate all at highest priority function
702717 if derivs :
703- derivs = Differentiable ._eval_at (self .func (* derivs ), func ,
704- mul_first = mul_first )
718+ derivs = self .func (* [d ._eval_at (func , mul_first = mul_first ) for d in derivs ])
705719 else :
706720 derivs = 1
707721
708722 if not other :
709723 return derivs
710- elif len (other ) > 1 :
711- expr = self .func (* other )._gather_for_diff
712- else :
713- expr = other [0 ]
724+ expr = self .func (* other )
714725
715726 # Non differentiable expr (e.g., number)
716727 if not isinstance (expr , Differentiable ):
717728 return self .func (derivs , expr )
718729
719- # Build mapper for dimensions that need to be interpolated
720- mapper = {}
721- for d in self .dimensions :
722- try :
723- if self .indices_ref [d ] is not func .indices_ref [d ]:
724- mapper [d ] = func .indices_ref [d ]
725- except KeyError :
726- pass
727-
728- # Nothing to interpolate
729- if not mapper :
730- return super ()._eval_at (func , mul_first = mul_first )
731-
732- # Interpolate expr at the required indices
733- interp = expr .diff (* mapper .keys (), deriv_order = [0 for _ in mapper ],
734- fd_order = [self .interp_order for _ in mapper ],
735- x0 = mapper )
736-
737- # Return the full expression with Derivatives
738- return self .func (derivs , interp )
730+ # Evaluate expression at func_args
731+ print (f"\n Evaluating expr { expr } at func0 { func0 } for func { func } from { self } " )
732+ expr = Differentiable ._eval_at (expr , func0 , mul_first = False )
733+
734+ # Interpolate derivatives at func0
735+ x0 = {d : v for d , v in func0 .indices_ref .getters .items ()
736+ if not d .is_Time and v is not func .indices_ref .getters .get (d , d )}
737+ if x0 and not derivs == 1 :
738+ print (f"Interpolating derivs { derivs } x0={ derivs .x0 } at { x0 } " )
739+ derivs = derivs .diff (* x0 .keys (), deriv_order = (0 ,)* len (x0 ),
740+ fd_order = (self .interp_order ,)* len (x0 ),
741+ x0 = x0 )
742+ newexpr = self .func (derivs , expr )
743+
744+ # Finally at func
745+ if not func .staggered == func0 .staggered :
746+ x0_f = {d : v for d , v in func .indices_ref .getters .items ()
747+ if not d .is_Time and v is not func0 .indices_ref .getters .get (d )}
748+ if x0_f :
749+ print (f"Final interpolation of derivs { self .func (derivs , expr )} at func { x0_f } " )
750+ return newexpr .diff (* x0_f .keys (), deriv_order = (0 ,)* len (x0_f ),
751+ fd_order = (self .interp_order ,)* len (x0_f ),
752+ x0 = x0_f )
753+ else :
754+ return newexpr
755+ else :
756+ # Return the full expression with Derivatives
757+ return newexpr
739758
740759
741760class Pow (DifferentiableOp , sympy .Pow ):
0 commit comments