From 91320f6c8ec91fb1c50391e2fc6caeb555f343a5 Mon Sep 17 00:00:00 2001 From: Nanook Claw Date: Tue, 19 May 2026 04:41:27 +0000 Subject: [PATCH] fix: use subtle for constant-time comparison --- crates/daphne/src/messages/mod.rs | 21 ++++++++++----------- crates/daphne/src/vdaf/prio3.rs | 16 +++++++++++++++- 2 files changed, 25 insertions(+), 12 deletions(-) diff --git a/crates/daphne/src/messages/mod.rs b/crates/daphne/src/messages/mod.rs index 462f34bfc..821101f69 100644 --- a/crates/daphne/src/messages/mod.rs +++ b/crates/daphne/src/messages/mod.rs @@ -22,6 +22,7 @@ use std::{ fmt, io::{Cursor, Read}, }; +use subtle::ConstantTimeEq; // Batch modes const BATCH_MODE_TIME_INTERVAL: u8 = 0x01; @@ -1485,18 +1486,8 @@ impl ParameterizedDecode for PlaintextInputShare { } } -// NOTE ring provides a similar function, but as of version 0.16.20, it doesn't compile to -// wasm32-unknown-unknown. pub fn constant_time_eq(left: &[u8], right: &[u8]) -> bool { - if left.len() != right.len() { - return false; - } - - let mut r = 0; - for (x, y) in left.iter().zip(right) { - r |= x ^ y; - } - r == 0 + bool::from(left.ct_eq(right)) } pub(crate) fn encode_u16_bytes(bytes: &mut Vec, input: &[u8]) -> Result<(), CodecError> { @@ -1705,6 +1696,14 @@ mod test { )); } + #[test] + fn constant_time_eq_matches_slice_equality() { + assert!(constant_time_eq(b"", b"")); + assert!(constant_time_eq(b"same bytes", b"same bytes")); + assert!(!constant_time_eq(b"same bytes", b"same bytez")); + assert!(!constant_time_eq(b"short", b"shorter")); + } + fn partial_batch_selector_encode_decode(version: DapVersion) { const TEST_DATA_DRAFT09: &[u8] = &[1]; const TEST_DATA_LATEST: &[u8] = &[1, 0, 0]; diff --git a/crates/daphne/src/vdaf/prio3.rs b/crates/daphne/src/vdaf/prio3.rs index 461f8af95..263e39b03 100644 --- a/crates/daphne/src/vdaf/prio3.rs +++ b/crates/daphne/src/vdaf/prio3.rs @@ -22,6 +22,7 @@ use prio::{ }, }; use std::io::Cursor; +use subtle::ConstantTimeEq; impl Prio3Config { pub(crate) fn shard( @@ -35,10 +36,11 @@ impl Prio3Config { (DapVersion::Latest, Prio3Config::Count, DapMeasurement::U64(measurement)) if measurement < 2 => { + let measurement = bool::from(measurement.ct_eq(&1)); let vdaf = Prio3::new_count(2).map_err(|e| { VdafError::Dap(fatal_error!(err = ?e, "initializing {self:?} failed")) })?; - shard_then_encode(&vdaf, task_id, &(measurement != 0), nonce) + shard_then_encode(&vdaf, task_id, &measurement, nonce) } ( DapVersion::Latest, @@ -612,6 +614,18 @@ mod test { assert_eq!(got, DapAggregateResult::U64(3)); } + #[test] + fn count_rejects_non_binary_measurement() { + let got = Prio3Config::Count.shard( + DapVersion::Latest, + DapMeasurement::U64(2), + &[0; 16], + crate::messages::TaskId([0; 32]), + ); + + assert!(got.is_err()); + } + #[test] fn roundtrip_sum() { let mut t = AggregationJobTest::new(