From 90f3de7f158679fd61df65e6738aaf341f314843 Mon Sep 17 00:00:00 2001 From: Nikhil Dev Goyal Date: Thu, 19 Mar 2026 07:53:40 -0700 Subject: [PATCH] Use parallel blend chain path in FastSigmoid on architectures having >=32 registers PiperOrigin-RevId: 886178215 --- ops/fast_ops-inl.h | 195 +++++++++++++++++++++++++++++++-------------- 1 file changed, 135 insertions(+), 60 deletions(-) diff --git a/ops/fast_ops-inl.h b/ops/fast_ops-inl.h index a3c2051b..2bfb8891 100644 --- a/ops/fast_ops-inl.h +++ b/ops/fast_ops-inl.h @@ -146,66 +146,141 @@ HWY_INLINE hn::Vec FastSigmoid(D d, hn::Vec val) { const auto t5 = hn::Set(d, static_cast(3.288402547357102)); const auto t6 = hn::Set(d, static_cast(5.271780018997146)); - // Start with highest index (7) - b = hn::Set(d, static_cast(-4.688832585616333)); - c = hn::Set(d, static_cast(1.9985234759675707)); - d_coef = hn::Set(d, static_cast(-9.357047249878605)); - - // If y < t6 (idx 6) - auto mask = hn::Lt(y, t6); - b = hn::IfThenElse(mask, hn::Set(d, static_cast(-2.0824831112860647)), - b); - c = hn::IfThenElse(mask, hn::Set(d, static_cast(1.9363640518503402)), c); - d_coef = hn::IfThenElse( - mask, hn::Set(d, static_cast(-3.767209600467866)), d_coef); - - // If y < t5 (idx 5) - mask = hn::Lt(y, t5); - b = hn::IfThenElse(mask, hn::Set(d, static_cast(-0.9556349519550872)), - b); - c = hn::IfThenElse(mask, hn::Set(d, static_cast(1.757582383623199)), c); - d_coef = hn::IfThenElse( - mask, hn::Set(d, static_cast(-0.7540678688218365)), d_coef); - - // If y < t4 (idx 4) - mask = hn::Lt(y, t4); - b = hn::IfThenElse(mask, hn::Set(d, static_cast(-0.42437007298661206)), - b); - c = hn::IfThenElse(mask, hn::Set(d, static_cast(1.4909222917402543)), c); - d_coef = hn::IfThenElse( - mask, hn::Set(d, static_cast(1.1565092321921886)), d_coef); - - // If y < t3 (idx 3) - mask = hn::Lt(y, t3); - b = hn::IfThenElse(mask, hn::Set(d, static_cast(-0.16943664192343108)), - b); - c = hn::IfThenElse(mask, hn::Set(d, static_cast(1.1775629610724903)), c); 
- d_coef = hn::IfThenElse( - mask, hn::Set(d, static_cast(2.4240814251983283)), d_coef); - - // If y < t2 (idx 2) - mask = hn::Lt(y, t2); - b = hn::IfThenElse(mask, hn::Set(d, static_cast(-0.05367999021047822)), - b); - c = hn::IfThenElse(mask, hn::Set(d, static_cast(0.8423809865207907)), c); - d_coef = hn::IfThenElse(mask, hn::Set(d, static_cast(3.253860706225495)), - d_coef); - - // If y < t1 (idx 1) - mask = hn::Lt(y, t1); - b = hn::IfThenElse(mask, hn::Set(d, static_cast(-0.010315055591476996)), - b); - c = hn::IfThenElse(mask, hn::Set(d, static_cast(0.5069204289218385)), c); - d_coef = hn::IfThenElse( - mask, hn::Set(d, static_cast(3.7450486139396544)), d_coef); - - // If y < t0 (idx 0) - mask = hn::Lt(y, t0); - b = hn::IfThenElse(mask, hn::Set(d, static_cast(-0.0006967055197996615)), - b); - c = hn::IfThenElse(mask, hn::Set(d, static_cast(0.220551955463595)), c); - d_coef = hn::IfThenElse( - mask, hn::Set(d, static_cast(3.9548607753775276)), d_coef); + if constexpr (HWY_REGISTERS >= 32) { + // Split into two parallel chains to reduce dependency latency. 
+ + // -- Chain 1: Indices 0 to 3 (Evaluated starting from t3 down to t0) + auto b_low = hn::Set(d, static_cast(-0.16943664192343108)); // idx 3 + auto c_low = hn::Set(d, static_cast(1.1775629610724903)); + auto d_low = hn::Set(d, static_cast(2.4240814251983283)); + + auto mask = hn::Lt(y, t2); + b_low = hn::IfThenElse( + mask, hn::Set(d, static_cast(-0.05367999021047822)), b_low); + c_low = hn::IfThenElse( + mask, hn::Set(d, static_cast(0.8423809865207907)), c_low); + d_low = hn::IfThenElse( + mask, hn::Set(d, static_cast(3.253860706225495)), d_low); + + mask = hn::Lt(y, t1); + b_low = hn::IfThenElse( + mask, hn::Set(d, static_cast(-0.010315055591476996)), b_low); + c_low = hn::IfThenElse( + mask, hn::Set(d, static_cast(0.5069204289218385)), c_low); + d_low = hn::IfThenElse( + mask, hn::Set(d, static_cast(3.7450486139396544)), d_low); + + mask = hn::Lt(y, t0); + b_low = hn::IfThenElse( + mask, hn::Set(d, static_cast(-0.0006967055197996615)), b_low); + c_low = hn::IfThenElse( + mask, hn::Set(d, static_cast(0.220551955463595)), c_low); + d_low = hn::IfThenElse( + mask, hn::Set(d, static_cast(3.9548607753775276)), d_low); + + // -- Chain 2: Indices 4 to 7 (Evaluated starting from t6 down to t4) + auto b_high = hn::Set(d, static_cast(-4.688832585616333)); // idx 7 + auto c_high = hn::Set(d, static_cast(1.9985234759675707)); + auto d_high = hn::Set(d, static_cast(-9.357047249878605)); + + mask = hn::Lt(y, t6); + b_high = hn::IfThenElse( + mask, hn::Set(d, static_cast(-2.0824831112860647)), b_high); + c_high = hn::IfThenElse( + mask, hn::Set(d, static_cast(1.9363640518503402)), c_high); + d_high = hn::IfThenElse( + mask, hn::Set(d, static_cast(-3.767209600467866)), d_high); + + mask = hn::Lt(y, t5); + b_high = hn::IfThenElse( + mask, hn::Set(d, static_cast(-0.9556349519550872)), b_high); + c_high = hn::IfThenElse( + mask, hn::Set(d, static_cast(1.757582383623199)), c_high); + d_high = hn::IfThenElse( + mask, hn::Set(d, static_cast(-0.7540678688218365)), d_high); + + 
mask = hn::Lt(y, t4); + b_high = hn::IfThenElse( + mask, hn::Set(d, static_cast(-0.42437007298661206)), b_high); + c_high = hn::IfThenElse( + mask, hn::Set(d, static_cast(1.4909222917402543)), c_high); + d_high = hn::IfThenElse( + mask, hn::Set(d, static_cast(1.1565092321921886)), d_high); + + // -- Merge the two chains + auto merge_mask = hn::Lt(y, t3); + b = hn::IfThenElse(merge_mask, b_low, b_high); + c = hn::IfThenElse(merge_mask, c_low, c_high); + d_coef = hn::IfThenElse(merge_mask, d_low, d_high); + } else { + // Start with highest index (7) + b = hn::Set(d, static_cast(-4.688832585616333)); + c = hn::Set(d, static_cast(1.9985234759675707)); + d_coef = hn::Set(d, static_cast(-9.357047249878605)); + + // If y < t6 (idx 6) + auto mask = hn::Lt(y, t6); + b = hn::IfThenElse(mask, hn::Set(d, static_cast(-2.0824831112860647)), + b); + c = hn::IfThenElse(mask, hn::Set(d, static_cast(1.9363640518503402)), + c); + d_coef = hn::IfThenElse( + mask, hn::Set(d, static_cast(-3.767209600467866)), d_coef); + + // If y < t5 (idx 5) + mask = hn::Lt(y, t5); + b = hn::IfThenElse(mask, hn::Set(d, static_cast(-0.9556349519550872)), + b); + c = hn::IfThenElse(mask, hn::Set(d, static_cast(1.757582383623199)), + c); + d_coef = hn::IfThenElse( + mask, hn::Set(d, static_cast(-0.7540678688218365)), d_coef); + + // If y < t4 (idx 4) + mask = hn::Lt(y, t4); + b = hn::IfThenElse(mask, hn::Set(d, static_cast(-0.42437007298661206)), + b); + c = hn::IfThenElse(mask, hn::Set(d, static_cast(1.4909222917402543)), + c); + d_coef = hn::IfThenElse( + mask, hn::Set(d, static_cast(1.1565092321921886)), d_coef); + + // If y < t3 (idx 3) + mask = hn::Lt(y, t3); + b = hn::IfThenElse(mask, hn::Set(d, static_cast(-0.16943664192343108)), + b); + c = hn::IfThenElse(mask, hn::Set(d, static_cast(1.1775629610724903)), + c); + d_coef = hn::IfThenElse( + mask, hn::Set(d, static_cast(2.4240814251983283)), d_coef); + + // If y < t2 (idx 2) + mask = hn::Lt(y, t2); + b = hn::IfThenElse(mask, hn::Set(d, 
static_cast(-0.05367999021047822)), + b); + c = hn::IfThenElse(mask, hn::Set(d, static_cast(0.8423809865207907)), + c); + d_coef = hn::IfThenElse( + mask, hn::Set(d, static_cast(3.253860706225495)), d_coef); + + // If y < t1 (idx 1) + mask = hn::Lt(y, t1); + b = hn::IfThenElse(mask, + hn::Set(d, static_cast(-0.010315055591476996)), b); + c = hn::IfThenElse(mask, hn::Set(d, static_cast(0.5069204289218385)), + c); + d_coef = hn::IfThenElse( + mask, hn::Set(d, static_cast(3.7450486139396544)), d_coef); + + // If y < t0 (idx 0) + mask = hn::Lt(y, t0); + b = hn::IfThenElse(mask, + hn::Set(d, static_cast(-0.0006967055197996615)), b); + c = hn::IfThenElse(mask, hn::Set(d, static_cast(0.220551955463595)), + c); + d_coef = hn::IfThenElse( + mask, hn::Set(d, static_cast(3.9548607753775276)), d_coef); + } } // Math: 0.5 * tanh(y/2) = (y + b)/(cy + d_coef)