@@ -231,6 +231,11 @@ pub fn vector_norm(vec: &[f64], max: f64, found_nan: bool) -> f64 {
231231///
232232/// The points are given as sequences of coordinates.
233233/// Uses high-precision vector_norm algorithm.
234+ ///
235+ /// Panics if `p` and `q` have different lengths. CPython raises ValueError
236+ /// for mismatched dimensions, but in this Rust API the caller is expected
237+ /// to guarantee equal-length slices. A length mismatch is a programming
238+ /// error, not a runtime condition.
234239pub fn dist ( p : & [ f64 ] , q : & [ f64 ] ) -> f64 {
235240 assert_eq ! (
236241 p. len( ) ,
@@ -261,24 +266,52 @@ pub fn dist(p: &[f64], q: &[f64]) -> f64 {
261266
262267/// Return the sum of products of values from two sequences (float version).
263268///
264- /// Uses TripleLength arithmetic for high precision.
265- /// Equivalent to sum(p[i] * q[i] for i in range(len(p))).
266- pub fn sumprod ( p : & [ f64 ] , q : & [ f64 ] ) -> f64 {
267- assert_eq ! ( p. len( ) , q. len( ) , "Inputs are not the same length" ) ;
269+ /// Uses TripleLength arithmetic for the fast path, then falls back to
270+ /// ordinary floating-point multiply/add starting at the first unsupported
271+ /// pair, matching Python's staged `math.sumprod` behavior for float inputs.
272+ ///
273+ /// CPython's math_sumprod_impl is a 3-stage state machine that handles
274+ /// int/float/generic Python objects. This function only covers the float
275+ /// path (`&[f64]`). The int accumulation and generic PyNumber fallback
276+ /// stages are Python type-system concerns and should be handled by the
277+ /// caller (e.g. RustPython) before delegating here.
278+ ///
279+ /// Returns EDOM if the inputs are not the same length.
280+ pub fn sumprod ( p : & [ f64 ] , q : & [ f64 ] ) -> crate :: Result < f64 > {
281+ if p. len ( ) != q. len ( ) {
282+ return Err ( crate :: Error :: EDOM ) ;
283+ }
268284
285+ let mut total = 0.0 ;
269286 let mut flt_total = TL_ZERO ;
287+ let mut flt_path_enabled = true ;
288+ let mut i = 0 ;
270289
271- for ( & pi, & qi) in p. iter ( ) . zip ( q. iter ( ) ) {
272- let new_flt_total = tl_fma ( pi, qi, flt_total) ;
273- if new_flt_total. hi . is_finite ( ) {
274- flt_total = new_flt_total;
275- } else {
276- // Overflow or special value, fall back to simple sum
277- return p. iter ( ) . zip ( q. iter ( ) ) . map ( |( a, b) | a * b) . sum ( ) ;
290+ while i < p. len ( ) {
291+ let pi = p[ i] ;
292+ let qi = q[ i] ;
293+
294+ if flt_path_enabled {
295+ let new_flt_total = tl_fma ( pi, qi, flt_total) ;
296+ if new_flt_total. hi . is_finite ( ) {
297+ flt_total = new_flt_total;
298+ i += 1 ;
299+ continue ;
300+ }
301+
302+ flt_path_enabled = false ;
303+ total += tl_to_d ( flt_total) ;
278304 }
305+
306+ total += pi * qi;
307+ i += 1 ;
279308 }
280309
281- tl_to_d ( flt_total)
310+ Ok ( if flt_path_enabled {
311+ tl_to_d ( flt_total)
312+ } else {
313+ total
314+ } )
282315}
283316
284317/// Return the sum of products of values from two sequences (integer version).
@@ -427,14 +460,27 @@ mod tests {
427460 crate :: test:: with_py_math ( |py, math| {
428461 let py_p = pyo3:: types:: PyList :: new ( py, p) . unwrap ( ) ;
429462 let py_q = pyo3:: types:: PyList :: new ( py, q) . unwrap ( ) ;
430- let py: f64 = math
431- . getattr ( "sumprod" )
432- . unwrap ( )
433- . call1 ( ( py_p, py_q) )
434- . unwrap ( )
435- . extract ( )
436- . unwrap ( ) ;
437- crate :: test:: assert_f64_eq ( py, rs, format_args ! ( "sumprod({p:?}, {q:?})" ) ) ;
463+ let py_result = math. getattr ( "sumprod" ) . unwrap ( ) . call1 ( ( py_p, py_q) ) ;
464+ match py_result {
465+ Ok ( py_val) => {
466+ let py: f64 = py_val. extract ( ) . unwrap ( ) ;
467+ let rs = rs. unwrap_or_else ( |e| {
468+ panic ! ( "sumprod({p:?}, {q:?}): py={py} but rs returned error {e:?}" )
469+ } ) ;
470+ crate :: test:: assert_f64_eq ( py, rs, format_args ! ( "sumprod({p:?}, {q:?})" ) ) ;
471+ }
472+ Err ( e) => {
473+ if e. is_instance_of :: < pyo3:: exceptions:: PyValueError > ( py) {
474+ assert_eq ! (
475+ rs. as_ref( ) . err( ) ,
476+ Some ( & crate :: Error :: EDOM ) ,
477+ "sumprod({p:?}, {q:?}): py raised ValueError but rs={rs:?}"
478+ ) ;
479+ } else {
480+ panic ! ( "sumprod({p:?}, {q:?}): py raised unexpected error {e}" ) ;
481+ }
482+ }
483+ }
438484 } ) ;
439485 }
440486
@@ -444,6 +490,9 @@ mod tests {
444490 test_sumprod_impl ( & [ ] , & [ ] ) ;
445491 test_sumprod_impl ( & [ 1.0 ] , & [ 2.0 ] ) ;
446492 test_sumprod_impl ( & [ 1e100 , 1e100 ] , & [ 1e100 , -1e100 ] ) ;
493+ test_sumprod_impl ( & [ 1.0 , 1e308 , -1e308 ] , & [ 1.0 , 2.0 , 2.0 ] ) ;
494+ test_sumprod_impl ( & [ 1e-16 , 1e308 , -1e308 ] , & [ 1.0 , 2.0 , 2.0 ] ) ;
495+ test_sumprod_impl ( & [ 1.0 ] , & [ ] ) ;
447496 }
448497
449498 fn test_prod_impl ( values : & [ f64 ] , start : Option < f64 > ) {
0 commit comments