diff --git a/crates/backend/air/src/lib.rs b/crates/backend/air/src/lib.rs index 3a6a416a6..7f22b792b 100644 --- a/crates/backend/air/src/lib.rs +++ b/crates/backend/air/src/lib.rs @@ -14,6 +14,19 @@ pub trait Air: Send + Sync + 'static { fn degree_air(&self) -> usize; + /// True max constraint degree along a fold line. Defaults to `degree_air()`. + /// Override when `degree_air` overstates the true degree: the prover skips + /// the redundant top eval pass. Wire format unchanged (sized by `degree_air`). + fn degree_z(&self) -> usize { + self.degree_air() + } + + /// Whether the C2 per-pair constraint cache pays for itself. Both paths + /// are bit-identical — this flag only selects the faster one per table. + fn c2_table_profitable(&self) -> bool { + true + } + fn n_columns(&self) -> usize; fn n_constraints(&self) -> usize; @@ -24,6 +37,13 @@ pub trait Air: Send + Sync + 'static { fn n_shift_columns(&self) -> usize; fn eval(&self, builder: &mut AB, extra_data: &Self::ExtraData); + + /// Emit only bus constraints (for the C2 seed round). On valid rows all + /// non-bus constraints vanish, so the bus-only accumulator equals the full + /// one. Default = full eval (bit-identical, just slower). + fn eval_bus_only(&self, builder: &mut AB, extra_data: &Self::ExtraData) { + self.eval(builder, extra_data); + } } pub trait AirBuilder: Sized { diff --git a/crates/backend/air/src/symbolic.rs b/crates/backend/air/src/symbolic.rs index cbc5283a3..63a42116f 100644 --- a/crates/backend/air/src/symbolic.rs +++ b/crates/backend/air/src/symbolic.rs @@ -259,6 +259,7 @@ struct SymbolicAirBuilder { constraints: Vec>, bus_multiplicity_value: Option>, bus_data_values: Option>>, + extra_declared_groups: Vec>>, } impl SymbolicAirBuilder { @@ -276,6 +277,7 @@ impl SymbolicAirBuilder { constraints: Vec::new(), bus_multiplicity_value: None, bus_data_values: None, + extra_declared_groups: Vec::new(), } } @@ -309,9 +311,10 @@ impl AirBuilder for SymbolicAirBuilder { if self.bus_multiplicity_value.is_none() { assert_eq!(values.len(), 1); self.bus_multiplicity_value = Some(values[0]); - } else { - assert!(self.bus_data_values.is_none()); + } else if self.bus_data_values.is_none() { self.bus_data_values = Some(values.to_vec()); + } else { + self.extra_declared_groups.push(values.to_vec()); } } } @@ -320,6 +323,7 @@ pub type SymbolicAirData = ( Vec>, SymbolicExpression, Vec>, + Vec>>, ); pub fn get_symbolic_constraints_and_bus_data_values(air: &A) -> SymbolicAirData @@ -332,5 +336,6 @@ where builder.constraints(), builder.bus_multiplicity_value.unwrap(), builder.bus_data_values.unwrap(), + builder.extra_declared_groups, ) } diff --git a/crates/backend/field/src/field.rs b/crates/backend/field/src/field.rs index f396cdda4..c3b005de2 100644 --- a/crates/backend/field/src/field.rs +++ b/crates/backend/field/src/field.rs @@ -384,6 +384,66 @@ pub trait PrimeCharacteristicRing: fn zero_vec(len: usize) -> Vec { vec![Self::ZERO; len] } + + /// Deferred multiply-accumulate protocol: unreduced product of two ring elements. + /// + /// The four slots of the `[Self; 4]` accumulator values used by `unreduced_mul` / + /// `lazy_acc_*` are an implementation-defined unreduced representation; only the + /// protocol methods of the SAME type may create or inspect them. The default + /// implementations are eager (slot 0 carries a canonical running value, slots 1-3 + /// stay `ZERO`), so for types that do not override the protocol it is exactly + /// equivalent to `acc + a * b` / `acc - a * b` chains — pure sugar, no behavioral + /// change. Overrides (e.g. 64-bit prime fields with widening multipliers) may keep + /// products unreduced across many accumulation steps and reduce once in + /// [`Self::lazy_acc_finish`]. + /// + /// # Contract + /// - At most `2^20` accumulation calls (each adding/subtracting one + /// [`Self::unreduced_mul`] product) between [`Self::lazy_acc_zero`] and + /// [`Self::lazy_acc_finish`]; callers chunk longer loops. + /// - `n_sub` passed to [`Self::lazy_acc_finish`] must equal the number of + /// [`Self::lazy_acc_sub`] calls since [`Self::lazy_acc_zero`]; implementations + /// whose subtraction is exact ignore it. + /// - Accumulator values must not be mixed across types or fed to ordinary + /// ring arithmetic. + #[must_use] + #[inline] + fn unreduced_mul(a: Self, b: Self) -> [Self; 4] { + [a * b, Self::ZERO, Self::ZERO, Self::ZERO] + } + + /// Fresh accumulator representing zero. See [`Self::unreduced_mul`] for the protocol. + #[must_use] + #[inline] + fn lazy_acc_zero() -> [Self; 4] { + [Self::ZERO; 4] + } + + /// Accumulate `+t` (a value produced by [`Self::unreduced_mul`]) into `acc`. + /// See [`Self::unreduced_mul`] for the protocol. + #[must_use] + #[inline] + fn lazy_acc_add(acc: [Self; 4], t: [Self; 4]) -> [Self; 4] { + [acc[0] + t[0], Self::ZERO, Self::ZERO, Self::ZERO] + } + + /// Accumulate `-t` (a value produced by [`Self::unreduced_mul`]) into `acc`. + /// See [`Self::unreduced_mul`] for the protocol. + #[must_use] + #[inline] + fn lazy_acc_sub(acc: [Self; 4], t: [Self; 4]) -> [Self; 4] { + [acc[0] - t[0], Self::ZERO, Self::ZERO, Self::ZERO] + } + + /// Collapse an accumulator to a canonical ring element. `n_sub` must equal the + /// number of [`Self::lazy_acc_sub`] calls on this accumulator (protocol docs at + /// [`Self::unreduced_mul`]). + #[must_use] + #[inline] + fn lazy_acc_finish(acc: [Self; 4], n_sub: u64) -> Self { + let _ = n_sub; + acc[0] + } } /// A vector space `V` over `F` with a fixed basis. Fixing the basis allows elements of `V` to be diff --git a/crates/backend/goldilocks/src/goldilocks.rs b/crates/backend/goldilocks/src/goldilocks.rs index 595aecee0..f794f981f 100644 --- a/crates/backend/goldilocks/src/goldilocks.rs +++ b/crates/backend/goldilocks/src/goldilocks.rs @@ -285,6 +285,74 @@ impl PrimeCharacteristicRing for Goldilocks { // SAFETY: `#[repr(transparent)]` means `Goldilocks` and `u64` share layout. unsafe { flatten_to_base(vec![0u64; len]) } } + + // Deferred multiply-accumulate protocol. + // + // Accumulator layout: 192-bit little-endian limbs in slots 0..3: + // slot 0 = l0 (bits 0..64), slot 1 = l1 (bits 64..128), slot 2 = l2 (bits 128..192), + // slot 3 unused. Seeded with OFFSET192 = P * 2^126 (a multiple of P, ~2^190) so + // mixed-sign accumulation never wraps the 192-bit domain: + // each term is < 2^128, the contract caps terms at 5 * 2^20 < 2^23, and + // 2^23 * 2^128 = 2^151 << 2^190 (headroom above: 2^192 - 2^190 = 3 * 2^190). + // The limb values are raw u64 bit patterns wrapped in `Goldilocks` (same trick as + // `zero_vec`'s transparent layout); they are never used as field elements. + // + // Same family as the `dot_product` override above (goldilocks.rs:248) and the + // poseidon1.rs:85-113 scalar MDS u128 accumulation (read-only precedent). + + #[inline] + fn unreduced_mul(a: Self, b: Self) -> [Self; 4] { + let wide = (a.value as u128) * (b.value as u128); + [ + Self::new(wide as u64), + Self::new((wide >> 64) as u64), + Self::ZERO, + Self::ZERO, + ] + } + + #[inline] + fn lazy_acc_zero() -> [Self; 4] { + // OFFSET192 = P << 126 = (P >> 2) * 2^128 + (P & 0b11 = 1) * 2^126: + // l0 = 0, l1 = 1 << 62, l2 = P >> 2. + [Self::ZERO, Self::new(1u64 << 62), Self::new(P >> 2), Self::ZERO] + } + + #[inline] + fn lazy_acc_add(acc: [Self; 4], t: [Self; 4]) -> [Self; 4] { + let (l0, c0) = acc[0].value.overflowing_add(t[0].value); + let (l1a, c1a) = acc[1].value.overflowing_add(t[1].value); + let (l1, c1b) = l1a.overflowing_add(c0 as u64); + let l2 = acc[2].value.wrapping_add(c1a as u64).wrapping_add(c1b as u64); + [Self::new(l0), Self::new(l1), Self::new(l2), Self::ZERO] + } + + #[inline] + fn lazy_acc_sub(acc: [Self; 4], t: [Self; 4]) -> [Self; 4] { + let (l0, b0) = acc[0].value.overflowing_sub(t[0].value); + let (l1a, b1a) = acc[1].value.overflowing_sub(t[1].value); + let (l1, b1b) = l1a.overflowing_sub(b0 as u64); + let l2 = acc[2].value.wrapping_sub(b1a as u64).wrapping_sub(b1b as u64); + [Self::new(l0), Self::new(l1), Self::new(l2), Self::ZERO] + } + + #[inline] + fn lazy_acc_finish(acc: [Self; 4], n_sub: u64) -> Self { + // Subtraction above is exact (borrow chains); no NOT-deficit to repay. + let _ = n_sub; + let l0 = acc[0].value; + let l1 = acc[1].value; + let l2 = acc[2].value; + // l2 must stay below 2^63 (deferred reduction invariant). + debug_assert!(l2 < (1u64 << 63), "lazy accumulator l2 overflow"); + // V = l2*2^128 + l1*2^64 + l0 ≡ l0 + l1*eps - l2*2^32 (mod P) + // (2^64 ≡ eps, 2^128 ≡ -2^32). Add P*2^33 (multiple of P, ~2^97 > l2*2^32 < 2^95) + // to keep the u128 arithmetic borrow-free, then one reduce128: + // r < 2^64 + 2^96 + 2^97 < 2^98. + const P_SHL33: u128 = (P as u128) << 33; + let r = (l0 as u128) + (l1 as u128) * (Self::NEG_ORDER as u128) + P_SHL33 - (l2 as u128) * (1u128 << 32); + reduce128(r) + } } /// `p - 1 = 2^32 * 3 * 5 * 17 * 257 * 65537`. The smallest `D` with `gcd(p - 1, D) = 1` is 7. diff --git a/crates/backend/goldilocks/src/x86_64_avx512/packing.rs b/crates/backend/goldilocks/src/x86_64_avx512/packing.rs index 93852147e..8a9aac6f4 100644 --- a/crates/backend/goldilocks/src/x86_64_avx512/packing.rs +++ b/crates/backend/goldilocks/src/x86_64_avx512/packing.rs @@ -126,6 +126,136 @@ impl PrimeCharacteristicRing for PackedGoldilocksAVX512 { // SAFETY: this is a repr(transparent) wrapper around an array. unsafe { reconstitute_from_base(Goldilocks::zero_vec(len * WIDTH)) } } + + // Deferred multiply-accumulate protocol. + // + // Accumulator layout: [L, H, W, K], one zmm each: + // L = wrapping sum of term lows; K = count of L-wraps (each worth 2^64) + // H = wrapping sum of term highs; W = count of H-wraps (each worth 2^128) + // so the represented value is V = L + (H + K)*2^64 + W*2^128. + // + // Subtraction is NOT-based: adding (~t_hi, ~t_lo) contributes + // 2^128 - 1 - t = -t - (2^32 + 1) (mod P), since 2^128 = -2^32. Each sub + // therefore leaves a constant deficit of (2^32 + 1), repaid once at finish + // via `n_sub` (T1 protocol contract). + // + // The separate K counter is mandatory: folding the lo-carry into t_hi via a + // masked +1 is only safe for t_hi <= 2^64 - 2 (true for raw products) but NOT + // for NOT-ed terms, where ~t_hi = 2^64 - 1 whenever t_hi = 0 — and zero + // products are reachable in real eval tables + // cleverness). + // + // Finish: merge K into H (carry into W), then one folding step + // V = L + H'_lo*eps - (H'_hi + W*2^32) (mod P) + // using 2^64 = eps, 2^96 = -1, 2^128 = -2^32, with the same + // sub_no_double/add_no_double tail as `reduce128` below. Bounds: + // the T1 contract caps accumulation at 5*2^20 terms, so W, K < 2^23 and + // s = H'_hi + W*2^32 < 2^32 + 2^55 < P (tail precondition). + + #[inline] + fn unreduced_mul(a: Self, b: Self) -> [Self; 4] { + let (hi, lo) = mul64_64(a.to_vector(), b.to_vector()); + [Self::from_vector(lo), Self::from_vector(hi), Self::ZERO, Self::ZERO] + } + + #[inline] + fn lazy_acc_zero() -> [Self; 4] { + [Self::ZERO; 4] + } + + #[inline] + fn lazy_acc_add(acc: [Self; 4], t: [Self; 4]) -> [Self; 4] { + unsafe { + let (l, h, w, k) = ( + acc[0].to_vector(), + acc[1].to_vector(), + acc[2].to_vector(), + acc[3].to_vector(), + ); + let (t_lo, t_hi) = (t[0].to_vector(), t[1].to_vector()); + let one = _mm512_set1_epi64(1); + + let l2 = _mm512_add_epi64(l, t_lo); + let carry_l = _mm512_cmplt_epu64_mask(l2, t_lo); + let k2 = _mm512_mask_add_epi64(k, carry_l, k, one); + + let h2 = _mm512_add_epi64(h, t_hi); + let carry_h = _mm512_cmplt_epu64_mask(h2, t_hi); + let w2 = _mm512_mask_add_epi64(w, carry_h, w, one); + + [ + Self::from_vector(l2), + Self::from_vector(h2), + Self::from_vector(w2), + Self::from_vector(k2), + ] + } + } + + #[inline] + fn lazy_acc_sub(acc: [Self; 4], t: [Self; 4]) -> [Self; 4] { + unsafe { + // NOT both halves (vpternlogq imm 0x55 = NOT a), then add as usual. + let t_lo = t[0].to_vector(); + let t_hi = t[1].to_vector(); + let nt_lo = _mm512_ternarylogic_epi64::<0x55>(t_lo, t_lo, t_lo); + let nt_hi = _mm512_ternarylogic_epi64::<0x55>(t_hi, t_hi, t_hi); + Self::lazy_acc_add( + acc, + [ + Self::from_vector(nt_lo), + Self::from_vector(nt_hi), + Self::ZERO, + Self::ZERO, + ], + ) + } + } + + #[inline] + fn lazy_acc_finish(acc: [Self; 4], n_sub: u64) -> Self { + unsafe { + let (l, h, w, k) = ( + acc[0].to_vector(), + acc[1].to_vector(), + acc[2].to_vector(), + acc[3].to_vector(), + ); + let one = _mm512_set1_epi64(1); + + // Merge the lo-wrap counter into H, carrying into W. + let h2 = _mm512_add_epi64(h, k); + let carry = _mm512_cmplt_epu64_mask(h2, k); + let w2 = _mm512_mask_add_epi64(w, carry, w, one); + + #[cfg(debug_assertions)] + { + let w_arr: [u64; WIDTH] = transmute(w2); + let k_arr: [u64; WIDTH] = transmute(k); + for i in 0..WIDTH { + debug_assert!( + w_arr[i] < (1 << 23) && k_arr[i] < (1 << 23), + "lazy accumulator wrap counters overflow" + ); + } + } + + // One folding step: V = L + H'_lo*eps - (H'_hi + W*2^32) (mod P). + let e = _mm512_mul_epu32(h2, EPSILON); + let s = _mm512_add_epi64(_mm512_srli_epi64::<32>(h2), _mm512_slli_epi64::<32>(w2)); + let t0 = sub_no_double_overflow_64_64(l, s); + let r = add_no_double_overflow_64_64(t0, e); + let r = Self::from_vector(r); + + if n_sub == 0 { + r + } else { + // Repay the NOT deficit: n_sub * (2^32 + 1), n_sub < 2^23 so no overflow. + let deficit = Goldilocks::new(n_sub * ((1u64 << 32) + 1)); + r + Self::broadcast(deficit) + } + } + } } impl_add_base_field!(PackedGoldilocksAVX512, Goldilocks); @@ -390,9 +520,7 @@ unsafe fn mds_output(s: &[__m512i; 8], s_hi: &[__m512i; 8]) -> _ /// `vpmuludq`, while the `vpaddq` it replaces was happily dual-issuing on the /// add ports. Kept the `vpmuludq + vpaddq` form. #[inline(always)] -pub(crate) fn mds_mul_simd( - state: [PackedGoldilocksAVX512; POSEIDON1_WIDTH], -) -> [PackedGoldilocksAVX512; POSEIDON1_WIDTH] { +pub fn mds_mul_simd(state: [PackedGoldilocksAVX512; POSEIDON1_WIDTH]) -> [PackedGoldilocksAVX512; POSEIDON1_WIDTH] { unsafe { let s: [__m512i; 8] = [ state[0].to_vector(), diff --git a/crates/backend/goldilocks/tests/lazy_acc.rs b/crates/backend/goldilocks/tests/lazy_acc.rs new file mode 100644 index 000000000..a6b83bd82 --- /dev/null +++ b/crates/backend/goldilocks/tests/lazy_acc.rs @@ -0,0 +1,338 @@ +//! Tests for the deferred multiply-accumulate protocol. +//! scalar Goldilocks override (T2) and the eager default path. +//! +//! Oracle: plain canonical field arithmetic (`acc + a*b` / `acc - a*b` chains). +//! "Debug with code": failures print operand history and expected-vs-actual. + +use field::{BasedVectorSpace, PrimeCharacteristicRing, PrimeField64}; +use goldilocks::{CubicExtensionFieldGL, Goldilocks, P}; + +/// Deterministic xorshift64* PRNG (no new deps; seeded for reproducibility). +struct XorShift(u64); +impl XorShift { + fn next(&mut self) -> u64 { + let mut x = self.0; + x ^= x << 13; + x ^= x >> 7; + x ^= x << 17; + self.0 = x; + x.wrapping_mul(0x2545_F491_4F6C_DD1D) + } +} + +/// Boundary operand set (includes non-canonical representatives). +const EPS: u64 = 0xFFFF_FFFF; // 2^32 - 1 = NEG_ORDER +fn boundary_vals() -> Vec { + vec![ + 0, + 1, + EPS, + P - 1, + P, + P + 1, + 1u64 << 32, + (1u64 << 33) - 1, + 1u64 << 63, + u64::MAX, + ] +} + +/// Eager oracle: canonical accumulate of signed products. +fn eager_chain(terms: &[(u64, u64, bool)]) -> Goldilocks { + let mut acc = Goldilocks::ZERO; + for &(a, b, is_sub) in terms { + let prod = Goldilocks::new(a) * Goldilocks::new(b); + if is_sub { + acc -= prod; + } else { + acc += prod; + } + } + acc +} + +/// Lazy path under test. +fn lazy_chain(terms: &[(u64, u64, bool)]) -> Goldilocks { + let mut acc = Goldilocks::lazy_acc_zero(); + let mut n_sub = 0u64; + for &(a, b, is_sub) in terms { + let t = Goldilocks::unreduced_mul(Goldilocks::new(a), Goldilocks::new(b)); + if is_sub { + acc = Goldilocks::lazy_acc_sub(acc, t); + n_sub += 1; + } else { + acc = Goldilocks::lazy_acc_add(acc, t); + } + } + Goldilocks::lazy_acc_finish(acc, n_sub) +} + +fn assert_chains_equal(terms: &[(u64, u64, bool)], ctx: &str) { + let expected = eager_chain(terms); + let actual = lazy_chain(terms); + assert_eq!( + expected.as_canonical_u64(), + actual.as_canonical_u64(), + "{ctx}: lazy != eager.\n terms (a, b, is_sub) = {terms:?}\n expected {} actual {}", + expected.as_canonical_u64(), + actual.as_canonical_u64() + ); +} + +#[test] +fn scalar_boundary_pairs_add_sub_mixed() { + let vals = boundary_vals(); + for &a in &vals { + for &b in &vals { + assert_chains_equal(&[(a, b, false)], "single add"); + assert_chains_equal(&[(a, b, true)], "single sub"); + for &c in &vals { + // add then sub, sub then add — exercises mixed-sign limb traffic. + assert_chains_equal(&[(a, b, false), (b, c, true)], "add+sub pair"); + assert_chains_equal(&[(a, b, true), (b, c, false)], "sub+add pair"); + assert_chains_equal(&[(a, b, true), (b, c, true)], "sub+sub pair"); + } + } + } +} + +#[test] +fn scalar_randomized_equivalence_10m() { + // 10^7 randomized scalar iterations against the eager oracle. + // Checked in chunks of <= 5 terms (the protocol's per-iteration term shape). + let mut rng = XorShift(0x5EED_0001); + let mut total = 0usize; + while total < 10_000_000 { + let len = 1 + (rng.next() % 5) as usize; + let terms: Vec<(u64, u64, bool)> = (0..len) + .map(|_| (rng.next(), rng.next(), rng.next() & 1 == 1)) + .collect(); + assert_chains_equal(&terms, "randomized chunk"); + total += len; + } +} + +#[test] +fn scalar_mixed_walks_with_prefix_checks() { + // mixed add/sub random walks, lengths 1..8192, checking prefixes. + let mut rng = XorShift(0x5EED_0002); + for walk in 0..64 { + let len = 1 + (rng.next() % 8192) as usize; + let terms: Vec<(u64, u64, bool)> = (0..len) + .map(|_| (rng.next(), rng.next(), rng.next() & 1 == 1)) + .collect(); + // check every prefix for short walks, end-only for long ones + if len <= 256 { + for p in 1..=len { + assert_chains_equal(&terms[..p], &format!("walk {walk} prefix {p}")); + } + } else { + assert_chains_equal(&terms, &format!("walk {walk} full")); + } + } +} + +#[test] +fn scalar_worst_case_counter_growth() { + // 2^20 iterations x 5 max-magnitude terms; exercises the overflow + // l2-drift bound (debug_assert in finish fires if violated) and equivalence. + let max = u64::MAX; + let t = Goldilocks::unreduced_mul(Goldilocks::new(max), Goldilocks::new(max)); + let n: u64 = 5 * (1 << 20); + + // all-adds + let mut acc = Goldilocks::lazy_acc_zero(); + for _ in 0..n { + acc = Goldilocks::lazy_acc_add(acc, t); + } + let lazy_adds = Goldilocks::lazy_acc_finish(acc, 0); + let prod = Goldilocks::new(max) * Goldilocks::new(max); + let eager_adds = prod * Goldilocks::new(n); + assert_eq!(lazy_adds.as_canonical_u64(), eager_adds.as_canonical_u64()); + + // max-sub pattern + let mut acc = Goldilocks::lazy_acc_zero(); + for _ in 0..n { + acc = Goldilocks::lazy_acc_sub(acc, t); + } + let lazy_subs = Goldilocks::lazy_acc_finish(acc, n); + let eager_subs = -eager_adds; + assert_eq!(lazy_subs.as_canonical_u64(), eager_subs.as_canonical_u64()); + + // alternating (worst mixed-sign limb churn around the OFFSET192 seed) + let mut acc = Goldilocks::lazy_acc_zero(); + let mut n_sub = 0; + for i in 0..n { + if i & 1 == 0 { + acc = Goldilocks::lazy_acc_add(acc, t); + } else { + acc = Goldilocks::lazy_acc_sub(acc, t); + n_sub += 1; + } + } + let lazy_alt = Goldilocks::lazy_acc_finish(acc, n_sub); + let eager_alt = prod * Goldilocks::new(n / 2 + (n & 1)) - prod * Goldilocks::new(n / 2); + assert_eq!(lazy_alt.as_canonical_u64(), eager_alt.as_canonical_u64()); +} + +#[test] +fn scalar_zero_products_interleaved() { + // Zero products inside sub patterns (relevant to the packed ~t_hi edge in T3; + // pinned here for the scalar path too). + let mut rng = XorShift(0x5EED_0003); + for _ in 0..1000 { + let mut terms: Vec<(u64, u64, bool)> = (0..16).map(|_| (rng.next(), rng.next(), rng.next() & 1 == 1)).collect(); + terms[3] = (0, 0, true); + terms[7] = (0, rng.next(), true); + terms[11] = (rng.next(), 0, false); + assert_chains_equal(&terms, "zero-product interleave"); + } +} + +/// T3: packed AVX-512 override — lane-by-lane equivalence against the scalar +/// T2 override (itself pinned against the eager oracle above). +#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))] +mod packed { + use super::*; + use goldilocks::PackedGoldilocksAVX512 as PG; + + const W: usize = 8; + + /// Run the same (a, b, is_sub) chain on all 8 lanes (lane i gets chain i) + /// through the packed protocol; compare each lane's finish vs scalar lazy. + #[allow(clippy::needless_range_loop)] // cross-indexing chains[i][j] reads clearer as index loops + fn assert_packed_matches_scalar(chains: &[Vec<(u64, u64, bool)>; W], ctx: &str) { + let len = chains[0].len(); + assert!(chains.iter().all(|c| c.len() == len)); + let mut acc = PG::lazy_acc_zero(); + let mut n_sub = 0u64; + for j in 0..len { + let a = PG(core::array::from_fn(|i| Goldilocks::new(chains[i][j].0))); + let b = PG(core::array::from_fn(|i| Goldilocks::new(chains[i][j].1))); + let t = PG::unreduced_mul(a, b); + // protocol restriction: one op per step across all lanes — callers + // (the T4 kernel) always do uniform ops, so chains share is_sub at j. + if chains[0][j].2 { + acc = PG::lazy_acc_sub(acc, t); + n_sub += 1; + } else { + acc = PG::lazy_acc_add(acc, t); + } + } + let packed = PG::lazy_acc_finish(acc, n_sub); + for i in 0..W { + let expected = lazy_chain(&chains[i]); + assert_eq!( + expected.as_canonical_u64(), + packed.0[i].as_canonical_u64(), + "{ctx}: lane {i} mismatch.\n chain = {:?}\n expected {} actual {}", + chains[i], + expected.as_canonical_u64(), + packed.0[i].as_canonical_u64() + ); + } + } + + fn chains_from(rng: &mut XorShift, len: usize, sub_pattern: impl Fn(usize) -> bool) -> [Vec<(u64, u64, bool)>; W] { + core::array::from_fn(|_| (0..len).map(|j| (rng.next(), rng.next(), sub_pattern(j))).collect()) + } + + #[test] + fn packed_boundary_lanes() { + // Pack the boundary set across lanes; all add, all sub, alternating. + let vals = boundary_vals(); + for &b in &vals { + for is_sub in [false, true] { + let chains: [Vec<(u64, u64, bool)>; W] = + core::array::from_fn(|i| vec![(vals[i % vals.len()], b, is_sub)]); + assert_packed_matches_scalar(&chains, "boundary single"); + } + } + } + + #[test] + fn packed_randomized_walks() { + let mut rng = XorShift(0x5EED_1001); + for walk in 0..200 { + let len = 1 + (rng.next() % 512) as usize; + // uniform sub pattern per step (kernel-shaped); random pattern seed + let pat = rng.next(); + let chains = chains_from(&mut rng, len, |j| (pat >> (j % 64)) & 1 == 1); + assert_packed_matches_scalar(&chains, &format!("walk {walk}")); + } + } + + #[test] + fn packed_zero_products_in_sub_patterns() { + // The ~t_hi = 2^64-1 edge: zero (and small) products inside subs — + // exercises the mandatory separate-K-counter design. + let mut rng = XorShift(0x5EED_1002); + for _ in 0..200 { + let mut chains = chains_from(&mut rng, 32, |j| j % 3 == 1); + for c in chains.iter_mut() { + c[1] = (0, 0, true); + c[4] = (0, rng.next(), true); + c[7] = (rng.next(), 0, true); + c[10] = (1, 1, true); + } + assert_packed_matches_scalar(&chains, "zero-product subs"); + } + } + + #[test] + fn packed_worst_case_counter_growth() { + // 5*2^20 max-magnitude terms, all-add then + // max-sub pattern; equivalence against the scalar override result. + let n: u64 = 5 * (1 << 20); + let max = PG(core::array::from_fn(|_| Goldilocks::new(u64::MAX))); + let t = PG::unreduced_mul(max, max); + + let mut acc = PG::lazy_acc_zero(); + for _ in 0..n { + acc = PG::lazy_acc_add(acc, t); + } + let packed_adds = PG::lazy_acc_finish(acc, 0); + + let mut acc = PG::lazy_acc_zero(); + for _ in 0..n { + acc = PG::lazy_acc_sub(acc, t); + } + let packed_subs = PG::lazy_acc_finish(acc, n); + + let prod = Goldilocks::new(u64::MAX) * Goldilocks::new(u64::MAX); + let eager_adds = prod * Goldilocks::new(n); + for i in 0..W { + assert_eq!(packed_adds.0[i].as_canonical_u64(), eager_adds.as_canonical_u64()); + assert_eq!(packed_subs.0[i].as_canonical_u64(), (-eager_adds).as_canonical_u64()); + } + } +} + +#[test] +fn default_path_is_pure_sugar_cubic_extension() { + // CubicExtensionFieldGL does NOT override the protocol -> exercises the + // eager defaults; pins that the trait surface is value-preserving sugar. + type EF = CubicExtensionFieldGL; + let mut rng = XorShift(0x5EED_0004); + let rand_ef = |rng: &mut XorShift| -> EF { EF::from_basis_coefficients_fn(|_| Goldilocks::new(rng.next())) }; + for _ in 0..1000 { + let terms: Vec<(EF, EF, bool)> = (0..8) + .map(|_| (rand_ef(&mut rng), rand_ef(&mut rng), rng.next() & 1 == 1)) + .collect(); + let mut eager = EF::ZERO; + let mut acc = EF::lazy_acc_zero(); + let mut n_sub = 0; + for &(a, b, s) in &terms { + let t = EF::unreduced_mul(a, b); + if s { + eager -= a * b; + acc = EF::lazy_acc_sub(acc, t); + n_sub += 1; + } else { + eager += a * b; + acc = EF::lazy_acc_add(acc, t); + } + } + assert_eq!(eager, EF::lazy_acc_finish(acc, n_sub)); + } +} diff --git a/crates/backend/sumcheck/src/product_computation.rs b/crates/backend/sumcheck/src/product_computation.rs index df1248c81..9e1bf31fd 100644 --- a/crates/backend/sumcheck/src/product_computation.rs +++ b/crates/backend/sumcheck/src/product_computation.rs @@ -43,14 +43,20 @@ pub fn run_product_sumcheck>>( pow_bits: usize, ) -> (MultilinearPoint, EF, MleOwned, MleOwned) { assert!(n_rounds >= 1); - let first_sumcheck_poly = match (pol_a, pol_b) { - (MleRef::BasePacked(evals), MleRef::ExtensionPacked(weights)) => { - if EF::DIMENSION == 3 { - compute_product_sumcheck_polynomial_base_ext_packed::<3, _, _, _, EF>(evals, weights, sum) - } else { - unimplemented!() + if let (MleRef::BasePacked(evals), MleRef::ExtensionPacked(weights)) = (pol_a, pol_b) { + if EF::DIMENSION == 3 { + // Lazy path for large instances only; the eager path stays the + // fallback (and keeps bytecode_claims' small-domain call site on + // its existing code). + let n_vars = (evals.len() * PFPacking::::WIDTH).trailing_zeros() as usize; + if n_vars >= 18 && n_rounds >= 2 { + return run_product_sumcheck_base_lazy::<3, EF>(evals, weights, prover_state, sum, n_rounds, pow_bits); } + return run_product_sumcheck_base_eager::<3, EF>(evals, weights, prover_state, sum, n_rounds, pow_bits); } + unimplemented!() + } + let first_sumcheck_poly = match (pol_a, pol_b) { (MleRef::ExtensionPacked(evals), MleRef::ExtensionPacked(weights)) => { compute_product_sumcheck_polynomial(evals, weights, sum, |e| EFPacking::::to_ext_iter([e]).collect()) } @@ -73,13 +79,6 @@ pub fn run_product_sumcheck>>( } let (second_sumcheck_poly, folded) = match (pol_a, pol_b) { - (MleRef::BasePacked(evals), MleRef::ExtensionPacked(weights)) => { - let (second_sumcheck_poly, folded) = - fold_and_compute_product_sumcheck_polynomial(evals, weights, r1, sum, |e| { - EFPacking::::to_ext_iter([e]).collect() - }); - (second_sumcheck_poly, MleGroupOwned::ExtensionPacked(folded)) - } (MleRef::ExtensionPacked(evals), MleRef::ExtensionPacked(weights)) => { let (second_sumcheck_poly, folded) = fold_and_compute_product_sumcheck_polynomial(evals, weights, r1, sum, |e| { @@ -124,6 +123,371 @@ pub fn run_product_sumcheck>>( (challenges, sum, pol_a, pol_b) } +/// Eager path for the (BasePacked, ExtensionPacked) arm, extracted verbatim from +/// `run_product_sumcheck`. Kept as the equality oracle for the lazy path and as +/// the fallback for small instances. Generic over `DIM` so the full path is +/// exercisable by the KoalaBear (DIM = 5) test harness. +pub fn run_product_sumcheck_base_eager>>( + evals: &[PFPacking], + weights: &[EFPacking], + prover_state: &mut impl FSProver, + mut sum: EF, + n_rounds: usize, + pow_bits: usize, +) -> (MultilinearPoint, EF, MleOwned, MleOwned) { + assert!(n_rounds >= 1); + let first_sumcheck_poly = + compute_product_sumcheck_polynomial_base_ext_packed::(evals, weights, sum); + + prover_state.add_sumcheck_polynomial(&first_sumcheck_poly.coeffs, None); + prover_state.pow_grinding(pow_bits); + let r1: EF = prover_state.sample(); + sum = first_sumcheck_poly.evaluate(r1); + + if n_rounds == 1 { + return ( + MultilinearPoint(vec![r1]), + sum, + MleRef::::BasePacked(evals).fold(r1), + MleRef::::ExtensionPacked(weights).fold(r1), + ); + } + + let (second_sumcheck_poly, folded) = { + let (second_sumcheck_poly, folded) = + fold_and_compute_product_sumcheck_polynomial(evals, weights, r1, sum, |e| { + EFPacking::::to_ext_iter([e]).collect() + }); + (second_sumcheck_poly, MleGroupOwned::ExtensionPacked(folded)) + }; + + prover_state.add_sumcheck_polynomial(&second_sumcheck_poly.coeffs, None); + prover_state.pow_grinding(pow_bits); + let r2: EF = prover_state.sample(); + sum = second_sumcheck_poly.evaluate(r2); + + let (mut challenges, folds, sum) = sumcheck_prove_many_rounds( + folded, + Some(r2), + &ProductComputation {}, + &vec![], + None, + prover_state, + sum, + None, + n_rounds - 2, + false, + pow_bits, + ); + + challenges.splice(0..0, [r1, r2]); + let [pol_a, pol_b] = folds.split().try_into().unwrap(); + (challenges, sum, pol_a, pol_b) +} + +/// Lazy-fold path for the (BasePacked, ExtensionPacked) arm. +/// +/// Round 1 never materializes the folded eval array: evals stay as two implicit +/// base streams `a_i = P[i]`, `b_i = P[half+i] - P[i]` (the fold by `r1` is kept +/// challenge-symbolic and bound once per accumulator), while the weights are +/// folded eagerly in the same pass. Round 2 entry then materializes the +/// double-folded evals directly from `P` via the 4-term basis comb +/// `E''[i] = a_i + r1 b_i + r2 da_i + r1 r2 db_i` fused with the W''-fold and the +/// round-2 quadratic. All reorderings are exact field identities, so the +/// transcript is bit-identical to the eager path (enforced by the equality +/// tests below and the e2e proof byte-diff gate). +pub fn run_product_sumcheck_base_lazy>>( + evals: &[PFPacking], + weights: &[EFPacking], + prover_state: &mut impl FSProver, + mut sum: EF, + n_rounds: usize, + pow_bits: usize, +) -> (MultilinearPoint, EF, MleOwned, MleOwned) { + assert!(n_rounds >= 2); + let first_sumcheck_poly = + compute_product_sumcheck_polynomial_base_ext_packed::(evals, weights, sum); + + prover_state.add_sumcheck_polynomial(&first_sumcheck_poly.coeffs, None); + prover_state.pow_grinding(pow_bits); + let r1: EF = prover_state.sample(); + sum = first_sumcheck_poly.evaluate(r1); + + let (second_sumcheck_poly, w_folded) = + fold_weights_and_compute_lazy_round1::(evals, weights, r1, sum); + + prover_state.add_sumcheck_polynomial(&second_sumcheck_poly.coeffs, None); + prover_state.pow_grinding(pow_bits); + let r2: EF = prover_state.sample(); + sum = second_sumcheck_poly.evaluate(r2); + + let decompose = |e| EFPacking::::to_ext_iter([e]).collect::>(); + + if n_rounds == 2 { + let (_, e2, w2) = materialize_double_fold_and_compute_round2::( + evals, &w_folded, r1, r2, sum, false, decompose, + ); + return ( + MultilinearPoint(vec![r1, r2]), + sum, + MleOwned::ExtensionPacked(e2), + MleOwned::ExtensionPacked(w2), + ); + } + + let (third_sumcheck_poly, e2, w2) = + materialize_double_fold_and_compute_round2::(evals, &w_folded, r1, r2, sum, true, decompose); + let third_sumcheck_poly = third_sumcheck_poly.unwrap(); + + prover_state.add_sumcheck_polynomial(&third_sumcheck_poly.coeffs, None); + prover_state.pow_grinding(pow_bits); + let r3: EF = prover_state.sample(); + sum = third_sumcheck_poly.evaluate(r3); + + let (mut challenges, folds, sum) = sumcheck_prove_many_rounds( + MleGroupOwned::ExtensionPacked(vec![e2, w2]), + Some(r3), + &ProductComputation {}, + &vec![], + None, + prover_state, + sum, + None, + n_rounds - 3, + false, + pow_bits, + ); + + challenges.splice(0..0, [r1, r2, r3]); + let [pol_a, pol_b] = folds.split().try_into().unwrap(); + (challenges, sum, pol_a, pol_b) +} + +/// Round-1 kernel of the lazy path: folds the weights by `r1` (the only array +/// that must be materialized) and computes the round-1 quadratic with the eval +/// fold kept challenge-symbolic: writing `E'[j] = a_j + r1 b_j` (`a_j = P[j]`, +/// `b_j = P[half+j] - P[j]`, both base field) and pairing `(j, q+j)`, +/// +/// c0 = sum_j E'[j] W'[j] = S_c0a + r1 S_c0b +/// c2 = sum_j (E'[q+j]-E'[j])(W'[q+j]-W'[j]) = S_c2a + r1 S_c2b +/// +/// with the four S-streams pure base-x-ext lazy-accumulator sums; `r1` is bound +/// once per stream after the horizontal lane sum. Per L2-resident chunk the +/// work is five passes (W'-fold + one per stream), each holding DIM lazy +/// accumulators (3 groups x 4 zmm — the register budget proven by the round-0 +/// kernel above). +#[allow(clippy::needless_range_loop)] +pub fn fold_weights_and_compute_lazy_round1< + const DIM: usize, + F: PrimeField64, + PF: PackedField, + EFP: BasedVectorSpace + Algebra + From + Copy + Send + Sync + 'static, + EF: Field + BasedVectorSpace, +>( + pol_0: &[PF], + pol_1: &[EFP], + r1: EF, + sum: EF, +) -> (DensePolynomial, ArenaVec) { + assert_eq!(DIM, EF::DIMENSION); + let n = pol_0.len(); + assert_eq!(n, pol_1.len()); + assert!(n.is_power_of_two()); + assert!(n >= 4); + let half = n / 2; + let q = half / 2; + let r1_packed = EFP::from(r1); + + let mut w_folded = unsafe { ArenaVec::::uninitialized(half) }; + let wf_ptr = parallel::SendPtr(w_folded.as_mut_ptr()); + + let chunk_size = 1024; + let n_chunks = q.div_ceil(chunk_size); + + // Four unreduced streams (n_sub = 0: every term is a raw positive product + // of canonical elements; chunk length 1024 <= 2^20 honors the protocol + // bound, see the round-0 kernel comment). + type Streams = ([PF; DIM], [PF; DIM], [PF; DIM], [PF; DIM]); + let (c0a_acc, c0b_acc, c2a_acc, c2b_acc): Streams = parallel::map_reduce( + n_chunks, + || ([PF::ZERO; DIM], [PF::ZERO; DIM], [PF::ZERO; DIM], [PF::ZERO; DIM]), + |chunk| { + let start = chunk * chunk_size; + let end = (start + chunk_size).min(q); + + // Pass A: fold the weights for both halves of the pair range and + // store into W' (the only materialized array). + for i in start..end { + let w_lo = r1_packed * (pol_1[half + i] - pol_1[i]) + pol_1[i]; + let w_hi = r1_packed * (pol_1[half + q + i] - pol_1[q + i]) + pol_1[q + i]; + unsafe { + *wf_ptr.add(i) = w_lo; + *wf_ptr.add(q + i) = w_hi; + } + } + let w_lo_chunk: &[EFP] = unsafe { core::slice::from_raw_parts(wf_ptr.add(start), end - start) }; + let w_hi_chunk: &[EFP] = unsafe { core::slice::from_raw_parts(wf_ptr.add(q + start), end - start) }; + + // Pass B: S_c0a += a_lo * w_lo. + let mut c0a: [[PF; 4]; DIM] = core::array::from_fn(|_| PF::lazy_acc_zero()); + for i in 0..(end - start) { + let a = pol_0[start + i]; + let w = w_lo_chunk[i].as_basis_coefficients_slice(); + for j in 0..DIM { + c0a[j] = PF::lazy_acc_add(c0a[j], PF::unreduced_mul(w[j], a)); + } + } + // Pass C: S_c0b += b_lo * w_lo. + let mut c0b: [[PF; 4]; DIM] = core::array::from_fn(|_| PF::lazy_acc_zero()); + for i in 0..(end - start) { + let b = pol_0[half + start + i] - pol_0[start + i]; + let w = w_lo_chunk[i].as_basis_coefficients_slice(); + for j in 0..DIM { + c0b[j] = PF::lazy_acc_add(c0b[j], PF::unreduced_mul(w[j], b)); + } + } + // Pass D1: S_c2a += (a_hi - a_lo) * (w_hi - w_lo). + let mut c2a: [[PF; 4]; DIM] = core::array::from_fn(|_| PF::lazy_acc_zero()); + for i in 0..(end - start) { + let da = pol_0[q + start + i] - pol_0[start + i]; + let dw = w_hi_chunk[i] - w_lo_chunk[i]; + let dw_coords = dw.as_basis_coefficients_slice(); + for j in 0..DIM { + c2a[j] = PF::lazy_acc_add(c2a[j], PF::unreduced_mul(dw_coords[j], da)); + } + } + // Pass D2: S_c2b += (b_hi - b_lo) * (w_hi - w_lo). + let mut c2b: [[PF; 4]; DIM] = core::array::from_fn(|_| PF::lazy_acc_zero()); + for i in 0..(end - start) { + let b_lo = pol_0[half + start + i] - pol_0[start + i]; + let b_hi = pol_0[half + q + start + i] - pol_0[q + start + i]; + let db = b_hi - b_lo; + let dw = w_hi_chunk[i] - w_lo_chunk[i]; + let dw_coords = dw.as_basis_coefficients_slice(); + for j in 0..DIM { + c2b[j] = PF::lazy_acc_add(c2b[j], PF::unreduced_mul(dw_coords[j], db)); + } + } + ( + core::array::from_fn(|j| PF::lazy_acc_finish(c0a[j], 0)), + core::array::from_fn(|j| PF::lazy_acc_finish(c0b[j], 0)), + core::array::from_fn(|j| PF::lazy_acc_finish(c2a[j], 0)), + core::array::from_fn(|j| PF::lazy_acc_finish(c2b[j], 0)), + ) + }, + |(mut a0, mut a1, mut a2, mut a3): Streams, (b0, b1, b2, b3): Streams| { + for j in 0..DIM { + a0[j] += b0[j]; + a1[j] += b1[j]; + a2[j] += b2[j]; + a3[j] += b3[j]; + } + (a0, a1, a2, a3) + }, + ); + + let lane_sum = |p: PF| { + let mut s = F::ZERO; + for &v in p.as_slice() { + s += v; + } + s + }; + let s_c0a = EF::from_basis_coefficients_fn(|j| lane_sum(c0a_acc[j])); + let s_c0b = EF::from_basis_coefficients_fn(|j| lane_sum(c0b_acc[j])); + let s_c2a = EF::from_basis_coefficients_fn(|j| lane_sum(c2a_acc[j])); + let s_c2b = EF::from_basis_coefficients_fn(|j| lane_sum(c2b_acc[j])); + + // Bind r1 once per stream. + let c0 = s_c0a + r1 * s_c0b; + let c2 = s_c2a + r1 * s_c2b; + let c1 = sum - c0.double() - c2; + + (DensePolynomial::new(vec![c0, c1, c2]), w_folded) +} + +/// Round-2 entry of the lazy path: materializes the double-folded evals +/// directly from the original base array via the 4-term basis comb, folds the +/// weights by `r2`, and (unless `compute_quad` is false, i.e. `n_rounds == 2`) +/// computes the round-2 quadratic — all in one parallel pass. +#[allow(clippy::too_many_arguments)] +pub fn materialize_double_fold_and_compute_round2< + const DIM: usize, + F: PrimeField64, + PF: PackedField, + EFP: BasedVectorSpace + Algebra + From + Copy + Send + Sync + 'static, + EF: Field + BasedVectorSpace, +>( + pol_0: &[PF], + w_folded: &[EFP], + r1: EF, + r2: EF, + sum: EF, + compute_quad: bool, + decompose: impl Fn(EFP) -> Vec, +) -> (Option>, ArenaVec, ArenaVec) { + assert_eq!(DIM, EF::DIMENSION); + let n = pol_0.len(); + assert!(n.is_power_of_two()); + assert!(n >= 8); + let half = n / 2; + let q = half / 2; + assert_eq!(w_folded.len(), half); + let e = q / 2; + + let r1_packed = EFP::from(r1); + let r2_packed = EFP::from(r2); + let r12_packed = EFP::from(r1 * r2); + + let mut e2 = unsafe { ArenaVec::::uninitialized(q) }; + let mut w2 = unsafe { ArenaVec::::uninitialized(q) }; + let e2_ptr = parallel::SendPtr(e2.as_mut_ptr()); + let w2_ptr = parallel::SendPtr(w2.as_mut_ptr()); + + // E''[j] = a_j + r1 b_j + r2 da_j + r1 r2 db_j, from P only. + let comb = |j: usize| -> EFP { + let a = pol_0[j]; + let b = pol_0[half + j] - pol_0[j]; + let da = pol_0[q + j] - pol_0[j]; + let db = (pol_0[half + q + j] - pol_0[q + j]) - b; + r1_packed * b + (r2_packed * da + (r12_packed * db + a)) + }; + + let (c0_packed, c2_packed) = parallel::map_reduce( + e, + || (EFP::ZERO, EFP::ZERO), + |i| { + let e_lo = comb(i); + let e_hi = comb(e + i); + let w_lo = r2_packed * (w_folded[q + i] - w_folded[i]) + w_folded[i]; + let w_hi = r2_packed * (w_folded[q + e + i] - w_folded[e + i]) + w_folded[e + i]; + unsafe { + *e2_ptr.add(i) = e_lo; + *e2_ptr.add(e + i) = e_hi; + *w2_ptr.add(i) = w_lo; + *w2_ptr.add(e + i) = w_hi; + } + if compute_quad { + (w_lo * e_lo, (w_hi - w_lo) * (e_hi - e_lo)) + } else { + (EFP::ZERO, EFP::ZERO) + } + }, + |(a0, a2), (b0, b2)| (a0 + b0, a2 + b2), + ); + + let poly = if compute_quad { + let c0 = decompose(c0_packed).into_iter().sum::(); + let c2 = decompose(c2_packed).into_iter().sum::(); + let c1 = sum - c0.double() - c2; + Some(DensePolynomial::new(vec![c0, c1, c2])) + } else { + None + }; + + (poly, e2, w2) +} + pub fn compute_product_sumcheck_polynomial< F: PrimeCharacteristicRing + Copy + Send + Sync, EF: Field, @@ -189,9 +553,14 @@ pub fn compute_product_sumcheck_polynomial_base_ext_packed< let chunk_size = 1024; let n_chunks = half.div_ceil(chunk_size); + // Deferred-reduction packed accumulation (lazy-accumulator protocol, see + // PrimeCharacteristicRing::unreduced_mul): each chunk keeps 2*DIM unreduced + // accumulators and reduces once at chunk exit. All terms are raw products + // added positively, so n_sub = 0. Chunk length 1024 <= 2^20 honors the + // protocol bound (1 term per accumulator per iteration). let (c0_acc, c2_acc) = parallel::map_reduce( n_chunks, - || ([F::ZERO; DIM], [F::ZERO; DIM]), + || ([PF::ZERO; DIM], [PF::ZERO; DIM]), |chunk| { let start = chunk * chunk_size; let end = (start + chunk_size).min(half); @@ -199,29 +568,31 @@ pub fn compute_product_sumcheck_polynomial_base_ext_packed< let b_hi = &pol_0[half + start..half + end]; let e_lo = &pol_1[start..end]; let e_hi = &pol_1[half + start..half + end]; - let mut c0 = [F::ZERO; DIM]; - let mut c2 = [F::ZERO; DIM]; + // Two passes: 6 accumulators x 4 zmm = 24 spills; split to 3 x 4 = 12. + let mut c0_lazy: [[PF; 4]; DIM] = core::array::from_fn(|_| PF::lazy_acc_zero()); + for i in 0..b_lo.len() { + let x0 = b_lo[i]; + let y0_coords = e_lo[i].as_basis_coefficients_slice(); + for j in 0..DIM { + c0_lazy[j] = PF::lazy_acc_add(c0_lazy[j], PF::unreduced_mul(y0_coords[j], x0)); + } + } + let mut c2_lazy: [[PF; 4]; DIM] = core::array::from_fn(|_| PF::lazy_acc_zero()); for i in 0..b_lo.len() { - let x0_lanes = b_lo[i].as_slice(); - let x1_lanes = b_hi[i].as_slice(); + let dx = b_hi[i] - b_lo[i]; let y0_coords = e_lo[i].as_basis_coefficients_slice(); let y1_coords = e_hi[i].as_basis_coefficients_slice(); for j in 0..DIM { - let y0_j = y0_coords[j].as_slice(); - let y1_j = y1_coords[j].as_slice(); - for lane in 0..PF::WIDTH { - let x0 = x0_lanes[lane]; - let x1 = x1_lanes[lane]; - let y0 = y0_j[lane]; - let y1 = y1_j[lane]; - c0[j] += y0 * x0; - c2[j] += (y1 - y0) * (x1 - x0); - } + let dy = y1_coords[j] - y0_coords[j]; + c2_lazy[j] = PF::lazy_acc_add(c2_lazy[j], PF::unreduced_mul(dy, dx)); } } - (c0, c2) + ( + core::array::from_fn(|j| PF::lazy_acc_finish(c0_lazy[j], 0)), + core::array::from_fn(|j| PF::lazy_acc_finish(c2_lazy[j], 0)), + ) }, - |(mut a0, mut a2): ([F; DIM], [F; DIM]), (b0, b2): ([F; DIM], [F; DIM])| { + |(mut a0, mut a2): ([PF; DIM], [PF; DIM]), (b0, b2): ([PF; DIM], [PF; DIM])| { for j in 0..DIM { a0[j] += b0[j]; a2[j] += b2[j]; @@ -230,8 +601,16 @@ pub fn compute_product_sumcheck_polynomial_base_ext_packed< }, ); - let c0 = EF::from_basis_coefficients_fn(|j| c0_acc[j]); - let c2 = EF::from_basis_coefficients_fn(|j| c2_acc[j]); + // Horizontal lane sum once at the very end. + let lane_sum = |p: PF| { + let mut s = F::ZERO; + for &v in p.as_slice() { + s += v; + } + s + }; + let c0 = EF::from_basis_coefficients_fn(|j| lane_sum(c0_acc[j])); + let c2 = EF::from_basis_coefficients_fn(|j| lane_sum(c2_acc[j])); let c1 = sum - c0.double() - c2; DensePolynomial::new(vec![c0, c1, c2]) diff --git a/crates/backend/sumcheck/tests/product_sumcheck_equality.rs b/crates/backend/sumcheck/tests/product_sumcheck_equality.rs new file mode 100644 index 000000000..b4781606c --- /dev/null +++ b/crates/backend/sumcheck/tests/product_sumcheck_equality.rs @@ -0,0 +1,376 @@ +use field::*; +use fiat_shamir::*; +use koala_bear::{KoalaBear, PackedQuinticExtensionFieldKB, QuinticExtensionFieldKB}; +use poly::*; +use sumcheck::*; + +type EFKb = QuinticExtensionFieldKB; +type PFKb = ::Packing; +const DIM_KB: usize = 5; + +struct XorShift(u64); +impl XorShift { + fn next_u64(&mut self) -> u64 { + let mut x = self.0; + x ^= x << 13; + x ^= x >> 7; + x ^= x << 17; + self.0 = x; + x + } +} + +// ---- base_ext_packed kernel tests ---- + +fn reference_base_ext_packed< + const DIM: usize, + F: PrimeField64, + PF: PackedField, + EFP: BasedVectorSpace + Copy + Send + Sync, + EF: Field + BasedVectorSpace, +>( + pol_0: &[PF], + pol_1: &[EFP], + sum: EF, +) -> DensePolynomial { + assert_eq!(DIM, EF::DIMENSION); + let n = pol_0.len(); + assert_eq!(n, pol_1.len()); + assert!(n.is_power_of_two()); + let half = n / 2; + let mut c0_acc = [F::ZERO; DIM]; + let mut c2_acc = [F::ZERO; DIM]; + for i in 0..half { + let x0_lanes = pol_0[i].as_slice(); + let x1_lanes = pol_0[half + i].as_slice(); + let y0_coords = pol_1[i].as_basis_coefficients_slice(); + let y1_coords = pol_1[half + i].as_basis_coefficients_slice(); + for j in 0..DIM { + let y0_j = y0_coords[j].as_slice(); + let y1_j = y1_coords[j].as_slice(); + for lane in 0..PF::WIDTH { + let x0 = x0_lanes[lane]; + let x1 = x1_lanes[lane]; + let y0 = y0_j[lane]; + let y1 = y1_j[lane]; + c0_acc[j] += y0 * x0; + c2_acc[j] += (y1 - y0) * (x1 - x0); + } + } + } + let c0 = EF::from_basis_coefficients_fn(|j| c0_acc[j]); + let c2 = EF::from_basis_coefficients_fn(|j| c2_acc[j]); + let c1 = sum - c0.double() - c2; + DensePolynomial::new(vec![c0, c1, c2]) +} + +fn random_kernel_inputs(seed: u64, log_n: usize) -> (Vec, Vec) { + let mut rng = XorShift(seed | 1); + let n = 1 << log_n; + let base: Vec = (0..n) + .map(|_| PFKb::from_fn(|_| KoalaBear::from_u64(rng.next_u64()))) + .collect(); + let ext: Vec = (0..n) + .map(|_| { + PackedQuinticExtensionFieldKB::from_basis_coefficients_fn(|_| { + PFKb::from_fn(|_| KoalaBear::from_u64(rng.next_u64())) + }) + }) + .collect(); + (base, ext) +} + +#[test] +fn rewritten_kernel_matches_reference_randomized() { + for (seed, log_n) in [(1u64, 1usize), (2, 2), (3, 4), (4, 7), (5, 11), (6, 12)] { + let (base, ext) = random_kernel_inputs(seed, log_n); + let sum = EFKb::from_basis_coefficients_fn(|j| KoalaBear::from_u64(seed + j as u64)); + let new = compute_product_sumcheck_polynomial_base_ext_packed::( + &base, &ext, sum, + ); + let reference = reference_base_ext_packed::(&base, &ext, sum); + assert_eq!(new.coeffs, reference.coeffs, "seed={seed} log_n={log_n}"); + } +} + +#[test] +fn rewritten_kernel_matches_reference_boundary() { + let n = 1 << 8; + let mut rng = XorShift(0xDEAD_BEEF); + let base: Vec = (0..n) + .map(|i| match i % 4 { + 0 => PFKb::ZERO, + 1 => PFKb::ONE, + 2 => PFKb::from_fn(|_| KoalaBear::from_u64(u64::MAX)), + _ => PFKb::from_fn(|_| KoalaBear::from_u64(rng.next_u64())), + }) + .collect(); + let ext: Vec = (0..n) + .map(|i| match i % 3 { + 0 => PackedQuinticExtensionFieldKB::ZERO, + 1 => PackedQuinticExtensionFieldKB::ONE, + _ => PackedQuinticExtensionFieldKB::from_basis_coefficients_fn(|_| { + PFKb::from_fn(|_| KoalaBear::from_u64(rng.next_u64())) + }), + }) + .collect(); + let sum = EFKb::ONE; + let new = compute_product_sumcheck_polynomial_base_ext_packed::( + &base, &ext, sum, + ); + let reference = reference_base_ext_packed::(&base, &ext, sum); + assert_eq!(new.coeffs, reference.coeffs); +} + +// ---- full path equality tests ---- + +struct RecordingFs { + rng: XorShift, + polys: Vec>, + challenges: Vec, +} + +impl RecordingFs { + fn new(seed: u64) -> Self { + Self { + rng: XorShift(seed | 1), + polys: vec![], + challenges: vec![], + } + } +} + +impl ChallengeSampler for RecordingFs { + fn sample_vec(&mut self, len: usize) -> Vec { + (0..len) + .map(|_| { + let c = EFKb::from_basis_coefficients_fn(|_| KoalaBear::from_u64(self.rng.next_u64())); + self.challenges.push(c); + c + }) + .collect() + } + fn sample_in_range(&mut self, _bits: usize, _n_samples: usize) -> Vec { + unimplemented!("not used by run_product_sumcheck") + } +} + +impl FSProver for RecordingFs { + fn state(&self) -> String { + format!("recording[{}]", self.polys.len()) + } + fn add_base_scalars(&mut self, _scalars: &[KoalaBear]) {} + fn observe_scalars(&mut self, _scalars: &[KoalaBear]) {} + fn duplex(&mut self) {} + fn pow_grinding(&mut self, _bits: usize) {} + fn hint_merkle_paths_base(&mut self, _paths: Vec>) {} + fn add_sumcheck_polynomial(&mut self, coeffs: &[EFKb], _eq_alpha: Option) { + self.polys.push(coeffs.to_vec()); + } +} + +fn random_full_inputs(seed: u64, n_vars: usize) -> (Vec, Vec) { + let mut rng = XorShift(seed.wrapping_mul(0x9E37_79B9_7F4A_7C15) | 1); + let n_packed = (1usize << n_vars) / PFKb::WIDTH; + let base: Vec = (0..n_packed) + .map(|_| PFKb::from_fn(|_| KoalaBear::from_u64(rng.next_u64()))) + .collect(); + let ext: Vec = (0..n_packed) + .map(|_| { + PackedQuinticExtensionFieldKB::from_basis_coefficients_fn(|_| { + PFKb::from_fn(|_| KoalaBear::from_u64(rng.next_u64())) + }) + }) + .collect(); + (base, ext) +} + +fn true_sum(base: &[PFKb], ext: &[PackedQuinticExtensionFieldKB]) -> EFKb { + let mut acc = PackedQuinticExtensionFieldKB::ZERO; + for (b, e) in base.iter().zip(ext.iter()) { + acc += *e * *b; + } + >::to_ext_iter([acc]).sum::() +} + +fn mle_owned_to_ext_vec(m: &MleOwned) -> Vec { + match m { + MleOwned::Base(v) => v.iter().map(|&x| EFKb::from(x)).collect(), + MleOwned::Extension(v) => v.to_vec(), + MleOwned::BasePacked(v) => v.iter().flat_map(|p| p.as_slice().to_vec()).map(EFKb::from).collect(), + MleOwned::ExtensionPacked(v) => { + >::to_ext_iter(v.iter().copied()) + .collect() + } + } +} + +#[test] +fn eager_full_path_deterministic_under_mock() { + for (seed, n_vars, n_rounds) in [(1u64, 8usize, 2usize), (2, 10, 3), (3, 12, 6)] { + let (base, ext) = random_full_inputs(seed, n_vars); + let sum = true_sum(&base, &ext); + + let mut fs_a = RecordingFs::new(seed); + let out_a = run_product_sumcheck_base_eager::(&base, &ext, &mut fs_a, sum, n_rounds, 0); + let mut fs_b = RecordingFs::new(seed); + let out_b = run_product_sumcheck_base_eager::(&base, &ext, &mut fs_b, sum, n_rounds, 0); + + assert_eq!(fs_a.polys, fs_b.polys, "seed={seed}"); + assert_eq!(fs_a.challenges, fs_b.challenges, "seed={seed}"); + assert_eq!(out_a.0.0, out_b.0.0, "challenge points seed={seed}"); + assert_eq!(out_a.1, out_b.1, "final sum seed={seed}"); + assert_eq!(mle_owned_to_ext_vec(&out_a.2), mle_owned_to_ext_vec(&out_b.2)); + assert_eq!(mle_owned_to_ext_vec(&out_a.3), mle_owned_to_ext_vec(&out_b.3)); + } +} + +#[test] +fn eager_full_path_final_claim_consistent() { + for (seed, n_vars, n_rounds) in [(7u64, 9usize, 2usize), (8, 11, 4)] { + let (base, ext) = random_full_inputs(seed, n_vars); + let sum = true_sum(&base, &ext); + let mut fs = RecordingFs::new(seed); + let (_point, final_sum, pol_a, pol_b) = + run_product_sumcheck_base_eager::(&base, &ext, &mut fs, sum, n_rounds, 0); + let a = mle_owned_to_ext_vec(&pol_a); + let b = mle_owned_to_ext_vec(&pol_b); + let recomposed: EFKb = a.iter().zip(b.iter()).map(|(&x, &y)| x * y).sum(); + assert_eq!(recomposed, final_sum, "seed={seed} n_vars={n_vars} n_rounds={n_rounds}"); + } +} + +// ---- lazy path equality tests ---- + +fn ef_from(seed: u64, salt: u64) -> EFKb { + EFKb::from_basis_coefficients_fn(|j| KoalaBear::from_u64(seed.wrapping_mul(salt + 1 + j as u64) | 1)) +} + +#[test] +fn e1_lazy_round1_matches_eager_oracle() { + for (seed, n_vars) in [(11u64, 6usize), (12, 7), (13, 9), (14, 11), (15, 14)] { + let (base, ext) = random_full_inputs(seed, n_vars); + let r1 = ef_from(seed, 0xA5); + let sum = ef_from(seed, 0x5A); + + let (poly_lazy, w_lazy) = + fold_weights_and_compute_lazy_round1::(&base, &ext, r1, sum); + let (poly_eager, folded_eager) = fold_and_compute_product_sumcheck_polynomial(&base, &ext, r1, sum, |e| { + >::to_ext_iter([e]) + .collect::>() + }); + + assert_eq!(poly_lazy.coeffs, poly_eager.coeffs, "seed={seed} n_vars={n_vars}"); + assert_eq!( + w_lazy.as_slice(), + folded_eager[1].as_slice(), + "W' seed={seed} n_vars={n_vars}" + ); + } +} + +#[test] +fn e2_fused_round2_entry_matches_fold_twice_oracle() { + for (seed, n_vars) in [(21u64, 8usize), (22, 9), (23, 10), (24, 13)] { + let (base, ext) = random_full_inputs(seed, n_vars); + let r1 = ef_from(seed, 0x11); + let r2 = ef_from(seed, 0x22); + let sum = ef_from(seed, 0x33); + + let (_p1, folded1) = fold_and_compute_product_sumcheck_polynomial(&base, &ext, r1, sum, |e| { + >::to_ext_iter([e]) + .collect::>() + }); + let (p2_oracle, folded2) = fold_and_compute_product_sumcheck_polynomial( + folded1[0].as_slice(), + folded1[1].as_slice(), + r2, + sum, + |e| { + >::to_ext_iter([e]) + .collect::>() + }, + ); + + let (_p1_lazy, w_folded) = + fold_weights_and_compute_lazy_round1::(&base, &ext, r1, sum); + let (p2_lazy, e2, w2) = materialize_double_fold_and_compute_round2::( + &base, + &w_folded, + r1, + r2, + sum, + true, + |e| { + >::to_ext_iter([e]) + .collect::>() + }, + ); + + assert_eq!( + p2_lazy.unwrap().coeffs, + p2_oracle.coeffs, + "poly seed={seed} n_vars={n_vars}" + ); + assert_eq!(e2.as_slice(), folded2[0].as_slice(), "E'' seed={seed} n_vars={n_vars}"); + assert_eq!(w2.as_slice(), folded2[1].as_slice(), "W'' seed={seed} n_vars={n_vars}"); + + let (none_poly, e2b, w2b) = materialize_double_fold_and_compute_round2::( + &base, + &w_folded, + r1, + r2, + sum, + false, + |e| { + >::to_ext_iter([e]) + .collect::>() + }, + ); + assert!(none_poly.is_none()); + assert_eq!(e2b.as_slice(), e2.as_slice()); + assert_eq!(w2b.as_slice(), w2.as_slice()); + } +} + +#[test] +fn e3_full_lazy_vs_eager_transcript_identical() { + for n_vars in [8usize, 9, 11, 14] { + for n_rounds in [2usize, 3, 6] { + if n_rounds > n_vars { + continue; + } + let seed = (n_vars * 100 + n_rounds) as u64; + let (base, ext) = random_full_inputs(seed, n_vars); + let sum = true_sum(&base, &ext); + + let mut fs_eager = RecordingFs::new(seed); + let out_eager = + run_product_sumcheck_base_eager::(&base, &ext, &mut fs_eager, sum, n_rounds, 0); + let mut fs_lazy = RecordingFs::new(seed); + let out_lazy = + run_product_sumcheck_base_lazy::(&base, &ext, &mut fs_lazy, sum, n_rounds, 0); + + assert_eq!( + fs_lazy.polys, fs_eager.polys, + "polys n_vars={n_vars} n_rounds={n_rounds}" + ); + assert_eq!( + fs_lazy.challenges, fs_eager.challenges, + "challenges n_vars={n_vars} n_rounds={n_rounds}" + ); + assert_eq!(out_lazy.0.0, out_eager.0.0, "point n_vars={n_vars} n_rounds={n_rounds}"); + assert_eq!(out_lazy.1, out_eager.1, "sum n_vars={n_vars} n_rounds={n_rounds}"); + assert_eq!( + mle_owned_to_ext_vec(&out_lazy.2), + mle_owned_to_ext_vec(&out_eager.2), + "fold A n_vars={n_vars} n_rounds={n_rounds}" + ); + assert_eq!( + mle_owned_to_ext_vec(&out_lazy.3), + mle_owned_to_ext_vec(&out_eager.3), + "fold B n_vars={n_vars} n_rounds={n_rounds}" + ); + } + } +} diff --git a/crates/lean_prover/src/prove_execution.rs b/crates/lean_prover/src/prove_execution.rs index b046c57e6..a08cdbdae 100644 --- a/crates/lean_prover/src/prove_execution.rs +++ b/crates/lean_prover/src/prove_execution.rs @@ -176,17 +176,18 @@ pub fn prove_execution( let log_n_rows = tables_log_heights[table]; let n_constraints = table.n_constraints(); let bus_numerator_value = logup_statements.bus_numerators_values[table]; - let bus_denominator_value = logup_statements.bus_denominators_values[table]; + let bus_denominator_values = &logup_statements.bus_denominators_values[table]; let signed_numerator = bus_numerator_value * match table.bus_interactions()[0].direction { BusDirection::Pull => EF::NEG_ONE, BusDirection::Push => EF::ONE, }; - // Each table consumes a disjoint range of alpha powers; alpha^offset weights the bus - // numerator (multiplicity), alpha^{offset+1} weights the bus fingerprint, alpha^{offset+2..} - // weight the remaining AIR constraints. - let bus_final_value = air_alpha_powers[alpha_offset] * signed_numerator - + air_alpha_powers[alpha_offset + 1] * (logup_c - bus_denominator_value); + // Alpha layout: slot 0 = multiplicity, slot 1 = Column-bus fingerprint, + // then deferred-claim slots, then algebra constraints. + let mut bus_final_value = air_alpha_powers[alpha_offset] * signed_numerator; + for (i, den) in bus_denominator_values.iter().enumerate() { + bus_final_value += air_alpha_powers[alpha_offset + 1 + i] * (logup_c - *den); + } let eq_suffix = from_end(gkr_point, log_n_rows).to_vec(); diff --git a/crates/lean_prover/src/trace_gen.rs b/crates/lean_prover/src/trace_gen.rs index 5f630df67..4f8e8b2a5 100644 --- a/crates/lean_prover/src/trace_gen.rs +++ b/crates/lean_prover/src/trace_gen.rs @@ -163,6 +163,7 @@ pub fn get_execution_trace( padding_zero_vec_ptr, null_poseidon_16_hash_ptr, bytecode.ending_pc(), + memory_padded[0], floor, ); } @@ -180,6 +181,7 @@ fn pad_table( zero_vec_ptr: usize, null_poseidon_16_hash_ptr: usize, ending_pc: usize, + mem0: F, min_log_n_rows: usize, ) { let trace = traces.get_mut(table).unwrap(); @@ -193,7 +195,7 @@ fn pad_table( trace.non_padded_n_rows = h; trace.log_n_rows = log2_ceil_usize(h + 1).max(min_log_n_rows); let n_rows = 1 << trace.log_n_rows; - let padding_row = table.padding_row(zero_vec_ptr, null_poseidon_16_hash_ptr, ending_pc); + let padding_row = table.padding_row(zero_vec_ptr, null_poseidon_16_hash_ptr, ending_pc, mem0); parallel::par_for_each_mut(&mut trace.columns, |i, col| { assert!(col.len() <= h); // potentially some columns have not been filled (in Poseidon -> we fill it later with SIMD + parallelism), but the first one should always be representative col.resize(n_rows, padding_row[i]); diff --git a/crates/lean_prover/src/verify_execution.rs b/crates/lean_prover/src/verify_execution.rs index 3c16b412a..7502b4ff3 100644 --- a/crates/lean_prover/src/verify_execution.rs +++ b/crates/lean_prover/src/verify_execution.rs @@ -115,14 +115,22 @@ pub fn verify_execution( for table in ALL_TABLES { let n_constraints = table.n_constraints(); let bus_numerator_value = logup_statements.bus_numerators_values[&table]; - let bus_denominator_value = logup_statements.bus_denominators_values[&table]; + let bus_denominator_values = &logup_statements.bus_denominators_values[&table]; let signed_numerator = bus_numerator_value * match table.bus_interactions()[0].direction { BusDirection::Pull => EF::NEG_ONE, BusDirection::Push => EF::ONE, }; - initial_sum += air_alpha_powers[alpha_offset] * signed_numerator - + air_alpha_powers[alpha_offset + 1] * (logup_c - bus_denominator_value); + // slot 0 = multiplicity, slot 1 = Column-bus fingerprint, then deferred-claim slots. + initial_sum += air_alpha_powers[alpha_offset] * signed_numerator; + let n_claim_slots = bus_denominator_values.len(); + debug_assert_eq!( + n_claim_slots, + 1 + table.bus_interactions().iter().filter(|b| b.deferred_claim).count() + ); + for (i, den) in bus_denominator_values.iter().enumerate() { + initial_sum += air_alpha_powers[alpha_offset + 1 + i] * (logup_c - *den); + } let alpha_slice = air_alpha_powers[alpha_offset..alpha_offset + n_constraints].to_vec(); verify_data.push(TableVerifyData { diff --git a/crates/lean_vm/src/tables/execution/air.rs b/crates/lean_vm/src/tables/execution/air.rs index eaca5f015..01e72db97 100644 --- a/crates/lean_vm/src/tables/execution/air.rs +++ b/crates/lean_vm/src/tables/execution/air.rs @@ -1,36 +1,37 @@ -use crate::{EF, ExecutionTable, ExtraDataForBuses, eval_bus_virtual}; +use crate::{EF, ExecutionTable, ExtraDataForBuses, LOGUP_MEMORY_DOMAINSEP, eval_bus_data_only, eval_bus_virtual}; use backend::*; -pub const N_RUNTIME_COLUMNS: usize = 8; +pub const N_RUNTIME_COLUMNS: usize = 5; pub const N_INSTRUCTION_COLUMNS: usize = 12; pub const N_TOTAL_EXECUTION_COLUMNS: usize = N_INSTRUCTION_COLUMNS + N_RUNTIME_COLUMNS; // Committed columns (IMPORTANT: they must be the first columns) +// ADDR_A/B/C are virtual (temporary) — derived from committed columns below. pub const EXEC_COL_PC: usize = 0; pub const EXEC_COL_FP: usize = 1; -pub const EXEC_COL_ADDR_A: usize = 2; -pub const EXEC_COL_ADDR_B: usize = 3; -pub const EXEC_COL_ADDR_C: usize = 4; -pub const EXEC_COL_VALUE_A: usize = 5; -pub const EXEC_COL_VALUE_B: usize = 6; -pub const EXEC_COL_VALUE_C: usize = 7; +pub const EXEC_COL_VALUE_A: usize = 2; +pub const EXEC_COL_VALUE_B: usize = 3; +pub const EXEC_COL_VALUE_C: usize = 4; // Decoded instruction columns -pub const EXEC_COL_OPERAND_A: usize = 8; -pub const EXEC_COL_OPERAND_B: usize = 9; -pub const EXEC_COL_OPERAND_C: usize = 10; -pub const EXEC_COL_FLAG_A: usize = 11; -pub const EXEC_COL_FLAG_B: usize = 12; -pub const EXEC_COL_FLAG_C: usize = 13; -pub const EXEC_COL_FLAG_C_FP: usize = 14; -pub const EXEC_COL_FLAG_AB_FP: usize = 15; -pub const EXEC_COL_FLAG_MUL: usize = 16; -pub const EXEC_COL_FLAG_JUMP: usize = 17; -pub const EXEC_COL_AUX_1: usize = 18; -pub const EXEC_COL_AUX_2: usize = 19; - -// Temporary columns (stored to avoid duplicate computations) -pub const N_TEMPORARY_EXEC_COLUMNS: usize = 4; +pub const EXEC_COL_OPERAND_A: usize = 5; +pub const EXEC_COL_OPERAND_B: usize = 6; +pub const EXEC_COL_OPERAND_C: usize = 7; +pub const EXEC_COL_FLAG_A: usize = 8; +pub const EXEC_COL_FLAG_B: usize = 9; +pub const EXEC_COL_FLAG_C: usize = 10; +pub const EXEC_COL_FLAG_C_FP: usize = 11; +pub const EXEC_COL_FLAG_AB_FP: usize = 12; +pub const EXEC_COL_FLAG_MUL: usize = 13; +pub const EXEC_COL_FLAG_JUMP: usize = 14; +pub const EXEC_COL_AUX_1: usize = 15; +pub const EXEC_COL_AUX_2: usize = 16; + +// Temporary columns (stored to avoid duplicate computations; NOT committed) +pub const N_TEMPORARY_EXEC_COLUMNS: usize = 7; +pub const EXEC_COL_ADDR_A: usize = 17; +pub const EXEC_COL_ADDR_B: usize = 18; +pub const EXEC_COL_ADDR_C: usize = 19; pub const EXEC_COL_FLAG_PRECOMPILE: usize = 20; pub const EXEC_COL_NU_A: usize = 21; pub const EXEC_COL_NU_B: usize = 22; @@ -45,11 +46,15 @@ impl Air for ExecutionTable { fn degree_air(&self) -> usize { 5 } + // C2 not profitable: 14-constraint eval is too cheap to amortize the cache overhead. + fn c2_table_profitable(&self) -> bool { + false + } fn n_shift_columns(&self) -> usize { 2 } fn n_constraints(&self) -> usize { - 14 + 13 } #[inline] @@ -76,7 +81,6 @@ impl Air for ExecutionTable { let (value_a, value_b, value_c) = (flat[EXEC_COL_VALUE_A], flat[EXEC_COL_VALUE_B], flat[EXEC_COL_VALUE_C]); let pc = flat[EXEC_COL_PC]; let fp = flat[EXEC_COL_FP]; - let (addr_a, addr_b, addr_c) = (flat[EXEC_COL_ADDR_A], flat[EXEC_COL_ADDR_B], flat[EXEC_COL_ADDR_C]); let one_minus_flag_a_and_flag_ab_fp = -(flag_a + flag_ab_fp - AB::F::ONE); let one_minus_flag_b_and_flag_ab_fp = -(flag_b + flag_ab_fp - AB::F::ONE); @@ -96,22 +100,27 @@ impl Air for ExecutionTable { let flag_deref = (aux_1 * (aux_1 - AB::F::ONE)).halve(); let flag_precompile = -(flag_add + flag_mul + flag_deref + flag_jump - AB::F::ONE); + // Virtual addresses: closed-form expressions of committed columns. + let addr_a = one_minus_flag_a_and_flag_ab_fp * fp_plus_operand_a; + let addr_b = one_minus_flag_b_and_flag_ab_fp * fp_plus_operand_b + flag_deref * (value_a + operand_b); + let addr_c = one_minus_flag_c_and_flag_c_fp * fp_plus_operand_c; + if BUS { eval_bus_virtual::(builder, extra_data, flag_precompile, aux_2, &[nu_a, nu_b, nu_c]); + // Deferred memory-bus claims (order must match verifier's initial_sum loop). + eval_bus_data_only::(builder, extra_data, LOGUP_MEMORY_DOMAINSEP, &[addr_a, value_a]); + eval_bus_data_only::(builder, extra_data, LOGUP_MEMORY_DOMAINSEP, &[addr_b, value_b]); + eval_bus_data_only::(builder, extra_data, LOGUP_MEMORY_DOMAINSEP, &[addr_c, value_c]); } else { builder.declare_values(&[flag_precompile]); builder.declare_values(&[nu_a, nu_b, nu_c, aux_2]); + builder.declare_values(&[addr_a, addr_b, addr_c]); } - builder.assert_zero(one_minus_flag_a_and_flag_ab_fp * (addr_a - fp_plus_operand_a)); - builder.assert_zero(one_minus_flag_b_and_flag_ab_fp * (addr_b - fp_plus_operand_b)); - builder.assert_zero(one_minus_flag_c_and_flag_c_fp * (addr_c - fp_plus_operand_c)); - builder.assert_zero(flag_add * (nu_b - (nu_a + nu_c))); builder.assert_zero(flag_mul * (nu_b - nu_a * nu_c)); - // DEREF: addr_B = value_A + operand_B, result in value_B, compared to nu_C - builder.assert_zero(flag_deref * (addr_b - (value_a + operand_b))); + // DEREF: result in value_B, compared to nu_C builder.assert_zero(flag_deref * (value_b - nu_c)); let jump_and_condition = flag_jump * nu_a; diff --git a/crates/lean_vm/src/tables/execution/mod.rs b/crates/lean_vm/src/tables/execution/mod.rs index 471c1fdc6..789049ecd 100644 --- a/crates/lean_vm/src/tables/execution/mod.rs +++ b/crates/lean_vm/src/tables/execution/mod.rs @@ -34,6 +34,7 @@ impl TableT for ExecutionTable { .map(|i| BusData::Column(N_RUNTIME_COLUMNS + i)) .chain(std::iter::once(BusData::Column(EXEC_COL_PC))) .collect(), + deferred_claim: false, }; let precompile_bus = BusInteraction { direction: BusDirection::Push, @@ -44,17 +45,34 @@ impl TableT for ExecutionTable { BusData::Column(EXEC_COL_NU_B), BusData::Column(EXEC_COL_NU_C), ], + deferred_claim: false, }; // Convention shared with the other tables: the unique Multiplicity::Column bus // comes first; everything that follows is Multiplicity::One. let mut buses = vec![precompile_bus, bytecode_lookup]; - buses.extend(memory_lookups_consecutive(EXEC_COL_ADDR_A, EXEC_COL_VALUE_A, 1)); - buses.extend(memory_lookups_consecutive(EXEC_COL_ADDR_B, EXEC_COL_VALUE_B, 1)); - buses.extend(memory_lookups_consecutive(EXEC_COL_ADDR_C, EXEC_COL_VALUE_C, 1)); + // Deferred-claim memory buses (virtual address columns, order: A, B, C). + buses.extend(memory_lookups_consecutive_with_claim( + EXEC_COL_ADDR_A, + EXEC_COL_VALUE_A, + 1, + true, + )); + buses.extend(memory_lookups_consecutive_with_claim( + EXEC_COL_ADDR_B, + EXEC_COL_VALUE_B, + 1, + true, + )); + buses.extend(memory_lookups_consecutive_with_claim( + EXEC_COL_ADDR_C, + EXEC_COL_VALUE_C, + 1, + true, + )); buses } - fn padding_row(&self, zero_vec_ptr: usize, _null_hash_ptr: usize, ending_pc: usize) -> Vec { + fn padding_row(&self, _zero_vec_ptr: usize, _null_hash_ptr: usize, ending_pc: usize, mem0: F) -> Vec { let mut padding_row = vec![F::ZERO; N_TOTAL_EXECUTION_COLUMNS + N_TEMPORARY_EXEC_COLUMNS]; padding_row[EXEC_COL_PC] = F::from_usize(ending_pc); padding_row[EXEC_COL_FLAG_JUMP] = F::ONE; @@ -65,9 +83,10 @@ impl TableT for ExecutionTable { padding_row[EXEC_COL_FLAG_C_FP] = F::ONE; // this is kind of arbitrary padding_row[EXEC_COL_NU_A] = F::ONE; // we always jump here (self-loop, so condition = nu_a = 1) padding_row[EXEC_COL_NU_B] = F::from_usize(ending_pc); // nu_b = jump dest = ending_pc - padding_row[EXEC_COL_ADDR_A] = F::from_usize(zero_vec_ptr); - padding_row[EXEC_COL_ADDR_B] = F::from_usize(zero_vec_ptr); - padding_row[EXEC_COL_ADDR_C] = F::from_usize(zero_vec_ptr); + // Virtual addresses evaluate to 0 on padding rows, so VALUE_* must = memory[0]. + padding_row[EXEC_COL_VALUE_A] = mem0; + padding_row[EXEC_COL_VALUE_B] = mem0; + padding_row[EXEC_COL_VALUE_C] = mem0; padding_row } @@ -83,3 +102,50 @@ impl TableT for ExecutionTable { unreachable!() } } + +#[cfg(test)] +mod h9_layout_tests { + use super::*; + + /// Padding row: virtual addresses all zero, VALUE_* = mem0. + #[test] + fn padding_row_satisfies_virtual_addresses() { + let mem0 = F::from_usize(424242); + let row = ExecutionTable::.padding_row(999, 998, 7, mem0); + assert_eq!(row.len(), N_TOTAL_EXECUTION_COLUMNS + N_TEMPORARY_EXEC_COLUMNS); + assert_eq!(N_TOTAL_EXECUTION_COLUMNS, 17); + assert_eq!(N_TEMPORARY_EXEC_COLUMNS, 7); + + let f = |i: usize| row[i]; + let one = F::ONE; + let aux_1 = f(EXEC_COL_AUX_1); + let flag_deref = (aux_1 * (aux_1 - one)).halve(); + let addr_a = (one - f(EXEC_COL_FLAG_A) - f(EXEC_COL_FLAG_AB_FP)) * (f(EXEC_COL_FP) + f(EXEC_COL_OPERAND_A)); + let addr_b = (one - f(EXEC_COL_FLAG_B) - f(EXEC_COL_FLAG_AB_FP)) * (f(EXEC_COL_FP) + f(EXEC_COL_OPERAND_B)) + + flag_deref * (f(EXEC_COL_VALUE_A) + f(EXEC_COL_OPERAND_B)); + let addr_c = (one - f(EXEC_COL_FLAG_C) - f(EXEC_COL_FLAG_C_FP)) * (f(EXEC_COL_FP) + f(EXEC_COL_OPERAND_C)); + assert_eq!(addr_a, F::ZERO); + assert_eq!(addr_b, F::ZERO); + assert_eq!(addr_c, F::ZERO); + // temporaries mirror the closed forms on the padding row + assert_eq!(f(EXEC_COL_ADDR_A), F::ZERO); + assert_eq!(f(EXEC_COL_ADDR_B), F::ZERO); + assert_eq!(f(EXEC_COL_ADDR_C), F::ZERO); + // bus pushes are (0, mem0) + assert_eq!(f(EXEC_COL_VALUE_A), mem0); + assert_eq!(f(EXEC_COL_VALUE_B), mem0); + assert_eq!(f(EXEC_COL_VALUE_C), mem0); + // memory buses reference the temporary addr columns + committed value columns + let buses = ExecutionTable::.bus_interactions(); + let mem_groups = memory_lookup_groups(&buses); + assert_eq!(mem_groups.len(), 3); + for (g, (a, v)) in mem_groups.iter().zip([ + (EXEC_COL_ADDR_A, EXEC_COL_VALUE_A), + (EXEC_COL_ADDR_B, EXEC_COL_VALUE_B), + (EXEC_COL_ADDR_C, EXEC_COL_VALUE_C), + ]) { + assert_eq!(g.idx_col, a); + assert_eq!(g.value_cols, vec![v]); + } + } +} diff --git a/crates/lean_vm/src/tables/extension_op/mod.rs b/crates/lean_vm/src/tables/extension_op/mod.rs index 6f085cd7e..e7f2868c7 100644 --- a/crates/lean_vm/src/tables/extension_op/mod.rs +++ b/crates/lean_vm/src/tables/extension_op/mod.rs @@ -97,6 +97,7 @@ impl TableT for ExtensionOpPrecompile { BusData::Column(COL_IDX_B), BusData::Column(COL_IDX_RES), ], + deferred_claim: false, }]; buses.extend(memory_lookups_consecutive(COL_IDX_A, COL_V_A, DIMENSION)); buses.extend(memory_lookups_consecutive(COL_IDX_B, COL_V_B, DIMENSION)); @@ -108,7 +109,7 @@ impl TableT for ExtensionOpPrecompile { self.n_columns() + 2 // +2 for COL_MULTIPLICITY_EXTENSION_OP and COL_DOMAINSEP_EXTENSION_OP (non-AIR, used in bus logup) } - fn padding_row(&self, zero_vec_ptr: usize, _null_hash_ptr: usize, _ending_pc: usize) -> Vec { + fn padding_row(&self, zero_vec_ptr: usize, _null_hash_ptr: usize, _ending_pc: usize, _mem0: F) -> Vec { let mut row = vec![F::ZERO; self.n_columns_total()]; row[COL_FLAG_START] = F::ONE; row[COL_LEN] = F::ONE; diff --git a/crates/lean_vm/src/tables/poseidon/mod.rs b/crates/lean_vm/src/tables/poseidon/mod.rs index 0caa7dfa6..9b41c2626 100644 --- a/crates/lean_vm/src/tables/poseidon/mod.rs +++ b/crates/lean_vm/src/tables/poseidon/mod.rs @@ -118,6 +118,23 @@ const AUX_COLS_PER_ROW: usize = num_cols_poseidon_8() - POSEIDON_8_COL_ROUND_STA // decomposition so only 2 cols/round are emitted. fn mds_vec_mul(state: &[F; WIDTH]) -> [F; WIDTH] { + // u128 accumulation: MDS coefficients <= 9, so sum < 2^71. Reduce once. + let s: [u128; WIDTH] = std::array::from_fn(|j| state[j].as_canonical_u64() as u128); + let mut out = [F::ZERO; WIDTH]; + for i in 0..WIDTH { + let mut acc: u128 = 0; + for j in 0..WIDTH { + acc += MDS8_ROW[(j + WIDTH - i) % WIDTH] as u128 * s[j]; + } + let lo = acc as u64; + let hi = (acc >> 64) as u64; + out[i] = F::from_u64(lo) + F::from_u64(hi * 0xFFFF_FFFF); + } + out +} + +#[cfg(test)] +fn mds_vec_mul_oracle(state: &[F; WIDTH]) -> [F; WIDTH] { let mut out = [F::ZERO; WIDTH]; for i in 0..WIDTH { let mut acc = state[0] * F::from_u64(MDS8_ROW[(WIDTH - i) % WIDTH] as u64); @@ -213,6 +230,106 @@ pub fn compute_poseidon8_witness(input: [F; WIDTH]) -> (Vec, [F; WIDTH]) { (aux, state) } +#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))] +mod packed_witness { + use super::*; + use backend::{PackedGoldilocksAVX512 as P, mds_mul_simd}; + + /// Rows processed per packed call — one row per AVX512 lane. + pub const WITNESS_LANES: usize = 8; + + #[inline(always)] + fn sbox7_p(x: P) -> P { + let x2 = x * x; + let x4 = x2 * x2; + x4 * x2 * x + } + + #[inline(always)] + fn bcast(c: F) -> P { + P([c; WITNESS_LANES]) + } + + /// Row-major convenience wrapper over [`compute_poseidon8_witness_packed`]: + /// transposes 8 scalar input rows into packed lanes first. + pub fn compute_poseidon8_witness_x8(inputs: &[[F; WIDTH]; WITNESS_LANES]) -> ([P; AUX_COLS_PER_ROW], [P; WIDTH]) { + let state: [P; WIDTH] = std::array::from_fn(|j| P(std::array::from_fn(|l| inputs[l][j]))); + compute_poseidon8_witness_packed(state) + } + + /// 8-lane packed variant of [`compute_poseidon8_witness`]. + pub fn compute_poseidon8_witness_packed(mut state: [P; WIDTH]) -> ([P; AUX_COLS_PER_ROW], [P; WIDTH]) { + let c = get_partial_constants(); + let mut aux = [P::ZERO; AUX_COLS_PER_ROW]; + let mut k = 0; + + // Initial full rounds. + for rc in GOLDILOCKS_POSEIDON1_RC_8.iter().take(POSEIDON1_HALF_FULL_ROUNDS) { + for (j, s) in state.iter_mut().enumerate() { + *s = sbox7_p(*s + bcast(rc[j])); + } + state = mds_mul_simd(state); + aux[k..k + WIDTH].copy_from_slice(&state); + k += WIDTH; + } + + // Partial phase: absorb first_round_constants, apply m_i, then sparse rounds. + for (j, s) in state.iter_mut().enumerate() { + *s += bcast(c.first_round_constants[j]); + } + { + let mut after = [P::ZERO; WIDTH]; + for (i, dst) in after.iter_mut().enumerate() { + let mut acc = P::ZERO; + for (j, sj) in state.iter().enumerate() { + acc += *sj * bcast(c.m_i[i][j]); + } + *dst = acc; + } + state = after; + } + + for r in 0..SPARSE_PARTIAL_ROUNDS { + let post_sbox = sbox7_p(state[0]); + aux[k] = post_sbox; + k += 1; + + state[0] = if r < SPARSE_PARTIAL_ROUNDS - 1 { + post_sbox + bcast(c.round_constants[r]) + } else { + post_sbox + }; + + let old_s0 = state[0]; + let mut new_s0 = P::ZERO; + for (j, sj) in state.iter().enumerate() { + new_s0 += *sj * bcast(c.sparse_first_row[r][j]); + } + state[0] = new_s0; + for (i, s) in state.iter_mut().enumerate().skip(1) { + *s += old_s0 * bcast(c.v[r][i - 1]); + } + } + + // Terminal full rounds. + for round in 0..POSEIDON1_HALF_FULL_ROUNDS { + let abs = POSEIDON1_HALF_FULL_ROUNDS + POSEIDON1_PARTIAL_ROUNDS + round; + for (j, s) in state.iter_mut().enumerate() { + *s = sbox7_p(*s + bcast(GOLDILOCKS_POSEIDON1_RC_8[abs][j])); + } + state = mds_mul_simd(state); + aux[k..k + WIDTH].copy_from_slice(&state); + k += WIDTH; + } + + debug_assert_eq!(k, AUX_COLS_PER_ROW); + (aux, state) + } +} + +#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))] +pub use packed_witness::{WITNESS_LANES, compute_poseidon8_witness_packed, compute_poseidon8_witness_x8}; + #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct Poseidon8Precompile; @@ -239,6 +356,7 @@ impl TableT for Poseidon8Precompile { BusData::Column(POSEIDON_8_COL_NU_B), BusData::Column(POSEIDON_8_COL_NU_C), ], + deferred_claim: false, }]; buses.extend(memory_lookups_consecutive( POSEIDON_8_COL_ADDR_LEFT_LO, @@ -263,7 +381,7 @@ impl TableT for Poseidon8Precompile { buses } - fn padding_row(&self, zero_vec_ptr: usize, null_hash_ptr: usize, _ending_pc: usize) -> Vec { + fn padding_row(&self, zero_vec_ptr: usize, null_hash_ptr: usize, _ending_pc: usize, _mem0: F) -> Vec { let mut row = vec![F::ZERO; num_cols_total_poseidon_8()]; row[POSEIDON_8_COL_MULTIPLICITY] = F::ZERO; row[POSEIDON_8_COL_NU_B] = F::from_usize(zero_vec_ptr); @@ -336,7 +454,8 @@ impl TableT for Poseidon8Precompile { .get_slice_into(left_second_addr, &mut input[HALF_DIGEST_LEN..DIGEST])?; ctx.memory.get_slice_into(arg_b.to_usize(), &mut input[DIGEST..])?; - let (aux, perm_state) = compute_poseidon8_witness(input); + // Round witness columns filled later by fill_trace_poseidon_8. + let perm_state = poseidon8_permute(input); // `output_cols` are the WIDTH output trace columns. For permute rows they // hold the raw permutation state; for compression rows `out_lo` @@ -382,9 +501,8 @@ impl TableT for Poseidon8Precompile { for (i, value) in output_cols.iter().enumerate() { trace.columns[POSEIDON_8_COL_OUT_LO + i].push(*value); } - for (i, value) in aux.iter().enumerate() { - trace.columns[POSEIDON_8_COL_ROUND_START + i].push(*value); - } + // The aux columns (POSEIDON_8_COL_ROUND_START..) stay empty here — + // `fill_trace_poseidon_8` recomputes them from the input columns. // Non-committed columns trace.columns[POSEIDON_8_COL_NU_A].push(arg_a); let domainsep = POSEIDON_DOMAINSEP_BASE @@ -427,12 +545,49 @@ impl Air for Poseidon8Precompile { // gates are at most degree 2; the round gates dominate at degree 7. 8 } + fn degree_z(&self) -> usize { + // True max is 7 (S-box x^7; all column refs linear). See degree_census test. + 7 + } fn n_shift_columns(&self) -> usize { 0 } fn n_constraints(&self) -> usize { poseidon8_n_constraints(BUS) } + // Bus-only eval for the C2 seed round (see c2_bus_seed.rs test). + fn eval_bus_only(&self, builder: &mut AB, extra_data: &Self::ExtraData) { + if !BUS { + self.eval(builder, extra_data); + return; + } + let flat = builder.flat(); + let multiplicity = flat[POSEIDON_8_COL_MULTIPLICITY]; + let nu_b = flat[POSEIDON_8_COL_NU_B]; + let nu_c = flat[POSEIDON_8_COL_NU_C]; + let flag_out4 = flat[POSEIDON_8_COL_FLAG_OUT4]; + let flag_left = flat[POSEIDON_8_COL_FLAG_LEFT]; + let flag_permute = flat[POSEIDON_8_COL_FLAG_PERMUTE]; + let offset_left = flat[POSEIDON_8_COL_OFFSET_LEFT]; + let addr_left_hi = flat[POSEIDON_8_COL_ADDR_LEFT_HI]; + + let domainsep_reconstructed = AB::IF::from_usize(POSEIDON_DOMAINSEP_BASE) + + flag_permute * AB::F::from_usize(POSEIDON_FLAG_PERMUTE_SHIFT) + + flag_out4 * AB::F::from_usize(POSEIDON_FLAG_OUT4_SHIFT) + + flag_left * AB::F::from_usize(POSEIDON_FLAG_LEFT_SHIFT) + + flag_left * offset_left * AB::F::from_usize(POSEIDON_OFFSET_LEFT_SHIFT); + let one_minus_flag_left = AB::IF::ONE - flag_left; + let nu_a = addr_left_hi - one_minus_flag_left * AB::F::from_usize(HALF_DIGEST_LEN); + + eval_bus_virtual::( + builder, + extra_data, + multiplicity, + domainsep_reconstructed, + &[nu_a, nu_b, nu_c], + ); + } + fn eval(&self, builder: &mut AB, extra_data: &Self::ExtraData) { let c = get_partial_constants(); @@ -638,3 +793,60 @@ impl Air for Poseidon8Precompile { } } } + +#[cfg(test)] +mod mds_u128_tests { + use super::*; + + fn xorshift(s: &mut u64) -> u64 { + *s ^= *s << 13; + *s ^= *s >> 7; + *s ^= *s << 17; + *s + } + + #[test] + fn mds_u128_equals_oracle() { + let mut seed = 0x9E3779B97F4A7C15u64; + // boundary patterns incl. non-canonical representatives + let p = 0xFFFF_FFFF_0000_0001u64; + let specials = [0u64, 1, p - 1, p, p + 1, u64::MAX, 0xFFFF_FFFF, 1 << 63]; + for &v in &specials { + let st: [F; WIDTH] = std::array::from_fn(|_| F::from_u64(v)); + assert_eq!(mds_vec_mul(&st), mds_vec_mul_oracle(&st), "special {v:#x}"); + } + for _ in 0..100_000 { + let st: [F; WIDTH] = std::array::from_fn(|_| F::from_u64(xorshift(&mut seed))); + let a = mds_vec_mul(&st); + let b = mds_vec_mul_oracle(&st); + assert_eq!(a, b); + } + } + + #[test] + #[ignore] + fn mds_kernel_ab() { + let mut seed = 7u64; + let states: Vec<[F; WIDTH]> = (0..200_000) + .map(|_| std::array::from_fn(|_| F::from_u64(xorshift(&mut seed)))) + .collect(); + let t = std::time::Instant::now(); + let mut sink = F::ZERO; + for st in &states { + sink += mds_vec_mul(st)[0]; + } + let new_dt = t.elapsed(); + let t = std::time::Instant::now(); + for st in &states { + sink += mds_vec_mul_oracle(st)[0]; + } + let old_dt = t.elapsed(); + println!( + "mds new {:?} old {:?} speedup {:.2}x sink={sink:?}", + new_dt, + old_dt, + old_dt.as_secs_f64() / new_dt.as_secs_f64() + ); + } + +} diff --git a/crates/lean_vm/src/tables/poseidon/trace_gen.rs b/crates/lean_vm/src/tables/poseidon/trace_gen.rs index c59fdcc5a..7c38f0340 100644 --- a/crates/lean_vm/src/tables/poseidon/trace_gen.rs +++ b/crates/lean_vm/src/tables/poseidon/trace_gen.rs @@ -1,18 +1,63 @@ use tracing::instrument; -use crate::F; -use backend::{ArenaVec, PrimeCharacteristicRing}; +use super::*; -// `execute()` writes every column for each active row (I/O + all per-round -// witness cols), and `padding_row()` supplies the zero-input witness for -// trailing padding. This pass just equalises column lengths in case any got -// out of sync during trace construction. +// Recompute per-round witness columns from the input columns, 8 rows per +// packed call, parallel over disjoint row ranges. #[instrument(name = "generate Poseidon8 AIR trace", skip_all)] pub fn fill_trace_poseidon_8(trace: &mut [ArenaVec]) { let n = trace.iter().map(|col| col.len()).max().unwrap_or(0); - for col in trace.iter_mut() { + parallel::par_for_each_mut(trace, |_, col| { if col.len() != n { col.resize(n, F::ZERO); } + }); + if n == 0 { + return; + } + + let (head, tail) = trace.split_at_mut(POSEIDON_8_COL_ROUND_START); + let inputs: [&[F]; WIDTH] = std::array::from_fn(|j| &head[POSEIDON_8_COL_INPUT_START + j][..n]); + // Disjoint-row-range writes into the aux columns, partitioned by task + // (same pattern as the segment replay in runner.rs). + let aux: Vec> = tail[..AUX_COLS_PER_ROW] + .iter_mut() + .map(|col| parallel::SendPtr(col.as_mut_ptr())) + .collect(); + + #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))] + { + use backend::PackedGoldilocksAVX512 as P; + let n_blocks = n / WITNESS_LANES; + parallel::for_each_chunk(n_blocks, |start, end| { + for blk in start..end { + let i = blk * WITNESS_LANES; + let state: [P; WIDTH] = std::array::from_fn(|j| P(inputs[j][i..i + WITNESS_LANES].try_into().unwrap())); + let (vals, _) = compute_poseidon8_witness_packed(state); + for (k, v) in vals.iter().enumerate() { + // SAFETY: row range [i, i+WITNESS_LANES) belongs to this + // task alone; every aux column has length n. + unsafe { std::ptr::copy_nonoverlapping(v.0.as_ptr(), aux[k].0.add(i), WITNESS_LANES) }; + } + } + }); + fill_aux_scalar(&inputs, &aux, n_blocks * WITNESS_LANES, n); + } + + #[cfg(not(all(target_arch = "x86_64", target_feature = "avx512f")))] + parallel::for_each_chunk(n, |start, end| fill_aux_scalar(&inputs, &aux, start, end)); +} + +// Rows are addressed by index across 8 input columns and 86 raw column +// pointers at once — iterator forms don't apply. +#[allow(clippy::needless_range_loop)] +fn fill_aux_scalar(inputs: &[&[F]; WIDTH], aux: &[parallel::SendPtr], start: usize, end: usize) { + for i in start..end { + let input: [F; WIDTH] = std::array::from_fn(|j| inputs[j][i]); + let (vals, _) = compute_poseidon8_witness(input); + for (k, v) in vals.iter().enumerate() { + // SAFETY: row `i` is written by exactly one task; columns have length >= end. + unsafe { *aux[k].0.add(i) = *v }; + } } } diff --git a/crates/lean_vm/src/tables/table_enum.rs b/crates/lean_vm/src/tables/table_enum.rs index 75013f552..f0127b952 100644 --- a/crates/lean_vm/src/tables/table_enum.rs +++ b/crates/lean_vm/src/tables/table_enum.rs @@ -67,8 +67,8 @@ impl TableT for Table { fn bus_interactions(&self) -> Vec { delegate_to_inner!(self, bus_interactions) } - fn padding_row(&self, zero_vec_ptr: usize, null_hash_ptr: usize, ending_pc: usize) -> Vec> { - delegate_to_inner!(self, padding_row, zero_vec_ptr, null_hash_ptr, ending_pc) + fn padding_row(&self, zero_vec_ptr: usize, null_hash_ptr: usize, ending_pc: usize, mem0: PF) -> Vec> { + delegate_to_inner!(self, padding_row, zero_vec_ptr, null_hash_ptr, ending_pc, mem0) } fn execute( &self, diff --git a/crates/lean_vm/src/tables/table_trait.rs b/crates/lean_vm/src/tables/table_trait.rs index 54d931db4..7a16712b8 100644 --- a/crates/lean_vm/src/tables/table_trait.rs +++ b/crates/lean_vm/src/tables/table_trait.rs @@ -53,6 +53,10 @@ pub struct BusInteraction { pub multiplicity: BusMultiplicity, pub domainsep: BusData, pub data: Vec, + /// When true, the bus data references virtual columns: the prover sends one + /// composite denominator eval, and the AIR sumcheck re-derives the fingerprint. + /// Only Multiplicity::One buses may be deferred. + pub deferred_claim: bool, } impl BusInteraction { @@ -62,6 +66,16 @@ impl BusInteraction { } pub fn memory_lookups_consecutive(idx_col: ColIndex, values_start: ColIndex, n: usize) -> Vec { + memory_lookups_consecutive_with_claim(idx_col, values_start, n, false) +} + +/// `deferred: true` marks the lookups as deferred-claim buses (see `BusInteraction::deferred_claim`). +pub fn memory_lookups_consecutive_with_claim( + idx_col: ColIndex, + values_start: ColIndex, + n: usize, + deferred: bool, +) -> Vec { (0..n) .map(|i| BusInteraction { direction: BusDirection::Push, @@ -71,6 +85,7 @@ pub fn memory_lookups_consecutive(idx_col: ColIndex, values_start: ColIndex, n: BusData::ColumnPlusConstant(idx_col, i), BusData::Column(values_start + i), ], + deferred_claim: deferred, }) .collect() } @@ -205,7 +220,7 @@ pub trait TableT: Air { fn name(&self) -> &'static str; fn table(&self) -> Table; fn bus_interactions(&self) -> Vec; - fn padding_row(&self, zero_vec_ptr: usize, null_hash_ptr: usize, ending_pc: usize) -> Vec; + fn padding_row(&self, zero_vec_ptr: usize, null_hash_ptr: usize, ending_pc: usize, mem0: F) -> Vec; fn execute( &self, arg_a: F, diff --git a/crates/lean_vm/src/tables/utils.rs b/crates/lean_vm/src/tables/utils.rs index ca1594190..c3837f439 100644 --- a/crates/lean_vm/src/tables/utils.rs +++ b/crates/lean_vm/src/tables/utils.rs @@ -2,6 +2,26 @@ use backend::*; use crate::ExtraDataForBuses; +/// Fingerprint constraint for deferred-claim buses (virtual address columns). +/// Emits one encoded constraint; no multiplicity assert (One-numerator is +/// reconstructed verifier-side). +pub(crate) fn eval_bus_data_only>>( + builder: &mut AB, + extra_data: &ExtraDataForBuses, + domainsep: usize, + data: &[AB::IF], +) { + let logup_alphas_eq_poly = extra_data.transmute_bus_data::(); + assert!(data.len() < logup_alphas_eq_poly.len()); + let encoded = logup_alphas_eq_poly + .iter() + .zip(data) + .map(|(c, d)| *c * *d) + .sum::() + + *logup_alphas_eq_poly.last().unwrap() * AB::F::from_usize(domainsep); + builder.assert_zero_ef(encoded); +} + pub(crate) fn eval_bus_virtual>>( builder: &mut AB, extra_data: &ExtraDataForBuses, diff --git a/crates/lean_vm/tests/poseidon_witness.rs b/crates/lean_vm/tests/poseidon_witness.rs new file mode 100644 index 000000000..fcd3d80ed --- /dev/null +++ b/crates/lean_vm/tests/poseidon_witness.rs @@ -0,0 +1,118 @@ +use backend::*; +use lean_vm::*; + +const WIDTH: usize = 8; + +fn xorshift(s: &mut u64) -> u64 { + *s ^= *s << 13; + *s ^= *s >> 7; + *s ^= *s << 17; + *s +} + +#[test] +fn deferred_aux_fill_equals_scalar() { + let mut seed = 0xDEADBEEF_u64; + let n = 1003; + let n_cols = num_cols_total_poseidon_8(); + let aux_start = POSEIDON_8_COL_ROUND_START; + let aux_end = aux_start + (num_cols_poseidon_8() - POSEIDON_8_COL_ROUND_START); + let mut trace: Vec> = (0..n_cols).map(|_| ArenaVec::new()).collect(); + for (c, col) in trace.iter_mut().enumerate() { + if (aux_start..aux_end).contains(&c) { + continue; + } + for _ in 0..n { + col.push(F::from_u64(xorshift(&mut seed))); + } + } + fill_trace_poseidon_8(&mut trace); + + assert!(trace.iter().all(|col| col.len() == n)); + for i in 0..n { + let input: [F; WIDTH] = std::array::from_fn(|j| trace[POSEIDON_8_COL_INPUT_START + j][i]); + let (aux, _) = compute_poseidon8_witness(input); + for (k, v) in aux.iter().enumerate() { + assert_eq!(trace[aux_start + k][i], *v, "aux[{k}] row {i}"); + } + } +} + +#[test] +#[ignore] +fn mds_witness_microbench() { + let mut seed = 1u64; + let inputs: Vec<[F; 8]> = (0..10_000) + .map(|_| std::array::from_fn(|_| F::from_u64(xorshift(&mut seed)))) + .collect(); + let t = std::time::Instant::now(); + let mut sink = F::ZERO; + for inp in &inputs { + let (_aux, out) = compute_poseidon8_witness(*inp); + sink += out[0]; + } + let dt = t.elapsed(); + println!( + "compute_poseidon8_witness: {:?} for 10k perms ({:.0} ns/perm), sink={sink:?}", + dt, + dt.as_nanos() as f64 / 10_000.0 + ); +} + +#[test] +#[ignore] +#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))] +fn packed_witness_microbench() { + const BATCHES: usize = 16_384; + + let mut seed = 0xC0FFEE_u64; + let inputs: Vec<[[F; WIDTH]; WITNESS_LANES]> = (0..BATCHES) + .map(|_| std::array::from_fn(|_| std::array::from_fn(|_| F::from_u64(xorshift(&mut seed))))) + .collect(); + + let n_aux = num_cols_poseidon_8() - POSEIDON_8_COL_ROUND_START; + + for batch in inputs.iter().take(64) { + let (aux_p, out_p) = compute_poseidon8_witness_x8(batch); + for l in 0..WITNESS_LANES { + let (aux_s, out_s) = compute_poseidon8_witness(batch[l]); + for (k, a) in aux_s.iter().enumerate().take(n_aux) { + assert_eq!(aux_p[k].0[l], *a, "aux[{k}] lane {l}"); + } + for j in 0..WIDTH { + assert_eq!(out_p[j].0[l], out_s[j], "out[{j}] lane {l}"); + } + } + } + + let mut scalar_best = std::time::Duration::MAX; + let mut packed_best = std::time::Duration::MAX; + let mut sink = F::ZERO; + for _ in 0..3 { + let t = std::time::Instant::now(); + for batch in &inputs { + for inp in batch { + let (aux, out) = compute_poseidon8_witness(*inp); + sink += out[0] + aux[n_aux - 1]; + } + } + scalar_best = scalar_best.min(t.elapsed()); + + let t = std::time::Instant::now(); + for batch in &inputs { + let (aux, out) = compute_poseidon8_witness_x8(batch); + sink += out[0].0[0] + aux[n_aux - 1].0[WITNESS_LANES - 1]; + } + packed_best = packed_best.min(t.elapsed()); + } + + let n_perms = (BATCHES * WITNESS_LANES) as f64; + println!( + "scalar {:?} ({:.0} ns/perm) | packed {:?} ({:.0} ns/perm) | speedup {:.2}x | sink={sink:?}", + scalar_best, + scalar_best.as_nanos() as f64 / n_perms, + packed_best, + packed_best.as_nanos() as f64 / n_perms, + scalar_best.as_secs_f64() / packed_best.as_secs_f64(), + ); +} diff --git a/crates/rec_aggregation/src/compilation.rs b/crates/rec_aggregation/src/compilation.rs index ae40fc8b6..d2a0c1856 100644 --- a/crates/rec_aggregation/src/compilation.rs +++ b/crates/rec_aggregation/src/compilation.rs @@ -258,11 +258,17 @@ fn build_replacements(log_inner_bytecode: usize, bytecode_zero_eval: F) -> BTree let mut n_air_shift_columns = vec![]; let mut n_air_constraints = vec![]; let mut one_buses_all_cols = vec![]; + let mut one_buses_deferred = vec![]; + let mut one_buses_deferred_idx = vec![]; + let mut n_deferred_buses = vec![]; for table in ALL_TABLES { let mut table_domseps = vec![]; let mut table_data_cols = vec![]; let mut table_data_offsets = vec![]; let mut table_new_cols = vec![]; + let mut table_deferred = vec![]; + let mut table_deferred_idx = vec![]; + let mut n_deferred: usize = 0; let mut seen_cols: HashSet = HashSet::new(); for bus in table.bus_interactions() { if !matches!(bus.multiplicity, BusMultiplicity::One) { @@ -271,6 +277,19 @@ fn build_replacements(log_inner_bytecode: usize, bytecode_zero_eval: F) -> BTree let BusData::Constant(domsep) = bus.domainsep else { panic!("Multiplicity::One bus domsep must be a constant"); }; + // Deferred-claim bus: no per-column evals, one composite denominator. + if bus.deferred_claim { + table_domseps.push(domsep.to_string()); + table_data_cols.push("[]".to_string()); + table_data_offsets.push("[]".to_string()); + table_new_cols.push("[]".to_string()); + table_deferred.push("1".to_string()); + table_deferred_idx.push(n_deferred.to_string()); + n_deferred += 1; + continue; + } + table_deferred.push("0".to_string()); + table_deferred_idx.push("0".to_string()); // unused for non-deferred buses let mut data_cols = vec![]; let mut data_offsets = vec![]; let mut new_cols = vec![]; @@ -304,6 +323,9 @@ fn build_replacements(log_inner_bytecode: usize, bytecode_zero_eval: F) -> BTree one_buses_data_cols.push(format!("[{}]", table_data_cols.join(", "))); one_buses_data_offsets.push(format!("[{}]", table_data_offsets.join(", "))); one_buses_new_cols.push(format!("[{}]", table_new_cols.join(", "))); + one_buses_deferred.push(format!("[{}]", table_deferred.join(", "))); + one_buses_deferred_idx.push(format!("[{}]", table_deferred_idx.join(", "))); + n_deferred_buses.push(n_deferred.to_string()); let mut sorted_seen: Vec = seen_cols.iter().copied().collect(); sorted_seen.sort(); @@ -339,6 +361,29 @@ fn build_replacements(log_inner_bytecode: usize, bytecode_zero_eval: F) -> BTree "ONE_BUSES_NEW_COLS_PLACEHOLDER".to_string(), format!("[{}]", one_buses_new_cols.join(", ")), ); + // Deferred-claim bus metadata (flags, per-bus slot indices, counts). + replacements.insert( + "ONE_BUSES_DEFERRED_PLACEHOLDER".to_string(), + format!("[{}]", one_buses_deferred.join(", ")), + ); + replacements.insert( + "ONE_BUSES_DEFERRED_IDX_PLACEHOLDER".to_string(), + format!("[{}]", one_buses_deferred_idx.join(", ")), + ); + replacements.insert( + "N_DEFERRED_BUSES_PLACEHOLDER".to_string(), + format!("[{}]", n_deferred_buses.join(", ")), + ); + let max_deferred_buses = n_deferred_buses + .iter() + .map(|s| s.parse::().unwrap()) + .max() + .unwrap_or(0) + .max(1); + replacements.insert( + "MAX_DEFERRED_BUSES_PLACEHOLDER".to_string(), + max_deferred_buses.to_string(), + ); replacements.insert( "NUM_COLS_AIR_PLACEHOLDER".to_string(), format!("[{}]", num_cols_air.join(", ")), @@ -496,7 +541,9 @@ fn air_eval_in_zk_dsl(table: T) -> String where T::ExtraData: Default, { - let (constraints, bus_multiplicity, bus_data) = get_symbolic_constraints_and_bus_data_values::(&table); + // 4th element = deferred-claim virtual-column expressions (e.g. addr_a/b/c). + let (constraints, bus_multiplicity, bus_data, deferred_groups) = + get_symbolic_constraints_and_bus_data_values::(&table); // `bus_data`'s last entry is the domainsep (logup domain separation). let (bus_domainsep, bus_real_data) = bus_data.split_last().unwrap(); let mut ctx = AirCodegenCtx::new(); @@ -542,9 +589,75 @@ where ); res += "\n sum: Mut = add_extension_ret(bus_res, weighted_multiplicity)"; + // Deferred-claim fingerprints (alpha slots 2..2+n_deferred). + let mut virtual_expr_queue: Vec<_> = deferred_groups.iter().flatten().copied().collect(); + virtual_expr_queue.reverse(); // consume from the back via pop() + let n_committed = table.n_columns(); + let deferred_buses: Vec<_> = table + .bus_interactions() + .into_iter() + .filter(|b| b.deferred_claim) + .collect(); + let n_deferred = deferred_buses.len(); + for (i, bus) in deferred_buses.iter().enumerate() { + let BusData::Constant(domsep) = bus.domainsep else { + panic!("deferred-claim bus domainsep must be a constant"); + }; + res += &format!("\n dbuff_{} = Array(DIM * {})", i, bus.data.len()); + for (j, entry) in bus.data.iter().enumerate() { + let (col, offset) = match entry { + BusData::Column(c) => (*c, 0usize), + BusData::ColumnPlusConstant(c, k) => (*c, *k), + BusData::Constant(_) => panic!("constant data entries unsupported in deferred buses"), + }; + let src = if col >= n_committed { + // temporary/virtual column: next declared expression + let expr = virtual_expr_queue + .pop() + .expect("declared virtual-expression queue exhausted"); + eval_air_constraint(expr, None, &mut ctx, &mut res) + } else { + format!("{} + DIM * {}", AIR_INNER_VALUES_VAR, col) + }; + if offset == 0 { + res += &format!("\n copy_ef({}, dbuff_{} + DIM * {})", src, i, j); + } else { + res += &format!( + "\n copy_ef(add_base_extension_ret({}, {}), dbuff_{} + DIM * {})", + offset, src, i, j + ); + } + } + res += &format!("\n denc_init_{} = Array(DIM)", i); + res += &format!( + "\n dot_product_ee(dbuff_{}, logup_beta_eq_poly, denc_init_{}, {})", + i, + i, + bus.data.len() + ); + let domsep_const = ctx.write_embedded_constant(domsep as u64, &mut res); + res += &format!( + "\n denc_{}: Mut = add_extension_ret(mul_extension_ret({}, logup_beta_eq_poly + {} * DIM), denc_init_{})", + i, + domsep_const, + (1 << LOG_MAX_BUS_WIDTH) - 1, + i + ); + res += &format!( + "\n sum = add_extension_ret(sum, mul_extension_ret(denc_{}, air_alpha_powers + {} * DIM))", + i, + 2 + i + ); + } + assert!( + virtual_expr_queue.is_empty(), + "unconsumed declared virtual expressions: deferred-bus data layout mismatch" + ); + res += "\n weighted_constraints = Array(DIM)"; res += &format!( - "\n dot_product_ee(air_alpha_powers + 2 * DIM, constraints_buf, weighted_constraints, {})", + "\n dot_product_ee(air_alpha_powers + {} * DIM, constraints_buf, weighted_constraints, {})", + 2 + n_deferred, n_constraints ); res += "\n sum = add_extension_ret(sum, weighted_constraints)"; diff --git a/crates/rec_aggregation/tests/deferred_claim_adversarial.rs b/crates/rec_aggregation/tests/deferred_claim_adversarial.rs new file mode 100644 index 000000000..d8cd6aa30 --- /dev/null +++ b/crates/rec_aggregation/tests/deferred_claim_adversarial.rs @@ -0,0 +1,79 @@ +//! Adversarial binding test for deferred-claim bus denominators: every transcript +//! scalar in the logup/deferred-claim region must be binding — flipping any one +//! byte must make verification reject. + +use backend::*; +use lean_compiler::*; +use lean_prover::{default_whir_config, prove_execution::prove_execution, verify_execution::verify_execution}; +use lean_vm::*; + +const PROGRAM: &str = r#" +def main(): + x: Mut = 0 + y: Mut = 1 + for i in unroll(0, 200): + z = x + y + x = y + y = z + assert y != 0 + return +"#; + +fn build_proof() -> (Bytecode, [F; PUBLIC_INPUT_LEN], Proof) { + let bytecode = compile_program_with_flags(&ProgramSource::Raw(PROGRAM.to_string()), CompilationFlags::default()); + let public_input = [F::ZERO; PUBLIC_INPUT_LEN]; + let witness = ExecutionWitness::default(); + let proof = prove_execution(&bytecode, &public_input, &witness, &default_whir_config(1), false).unwrap(); + (bytecode, public_input, proof.proof) +} + +#[test] +fn deferred_claims_are_binding() { + let (bytecode, public_input, proof) = build_proof(); + + // Honest proof verifies. + verify_execution(&bytecode, &public_input, proof.clone()).expect("honest proof must verify"); + + let bytes = postcard::to_allocvec(&proof).expect("serialize"); + + // The transcript Vec begins after its varint length header; each scalar is + // 8 bytes LE. Locate the header length by decoding the varint at offset 0. + let mut varint_len = 0usize; + let mut n_scalars = 0u64; + for (i, b) in bytes.iter().enumerate() { + n_scalars |= u64::from(b & 0x7f) << (7 * i); + varint_len += 1; + if b & 0x80 == 0 { + break; + } + } + let n_scalars = n_scalars as usize; + assert!(n_scalars > 256, "transcript unexpectedly short: {n_scalars}"); + + // Sweep the early transcript region (dims, commitment root, OOD, GKR rounds, + // logup claims incl. the deferred D̂_i, AIR rounds) plus a tail sample. + let mut tested = 0usize; + let mut rejected = 0usize; + let sweep: Vec = (0..n_scalars.min(1200)) + .step_by(13) + .chain([n_scalars - 1, n_scalars / 2]) + .collect(); + for scalar_idx in sweep { + let byte_idx = varint_len + scalar_idx * 8 + 3; // mid-limb byte: always value-changing + let mut tampered = bytes.clone(); + tampered[byte_idx] ^= 0x5a; + tested += 1; + match postcard::from_bytes::>(&tampered) { + Err(_) => rejected += 1, // deserialization caught it + Ok(p) => { + if verify_execution(&bytecode, &public_input, p).is_err() { + rejected += 1; + } else { + panic!("tampered transcript scalar {scalar_idx} ACCEPTED — binding failure"); + } + } + } + } + assert_eq!(tested, rejected); + println!("h9 adversarial sweep: {rejected}/{tested} tampered transcripts rejected"); +} diff --git a/crates/rec_aggregation/zkdsl_implem/recursion.py b/crates/rec_aggregation/zkdsl_implem/recursion.py index 530014cb2..738f5ee77 100644 --- a/crates/rec_aggregation/zkdsl_implem/recursion.py +++ b/crates/rec_aggregation/zkdsl_implem/recursion.py @@ -23,6 +23,11 @@ ONE_BUSES_DATA_COLS = ONE_BUSES_DATA_COLS_PLACEHOLDER # [[[_; num_data]; num_buses]; N_TABLES] ONE_BUSES_DATA_OFFSETS = ONE_BUSES_DATA_OFFSETS_PLACEHOLDER # [[[_; num_data]; num_buses]; N_TABLES] ONE_BUSES_NEW_COLS = ONE_BUSES_NEW_COLS_PLACEHOLDER # [[[_; n_new]; num_buses]; N_TABLES] +# Deferred-claim buses (virtual data columns; one composite denominator on the wire): +ONE_BUSES_DEFERRED = ONE_BUSES_DEFERRED_PLACEHOLDER # [[0/1; num_buses]; N_TABLES] +ONE_BUSES_DEFERRED_IDX = ONE_BUSES_DEFERRED_IDX_PLACEHOLDER # [[slot or 0; num_buses]; N_TABLES] +N_DEFERRED_BUSES = N_DEFERRED_BUSES_PLACEHOLDER # [_; N_TABLES] +MAX_DEFERRED_BUSES = MAX_DEFERRED_BUSES_PLACEHOLDER # max(1, max over tables) NUM_COLS_AIR = NUM_COLS_AIR_PLACEHOLDER MAX_NUM_COLS_AIR = MAX_NUM_COLS_AIR_PLACEHOLDER # max(NUM_COLS_AIR[t] for t in 0..N_TABLES) @@ -222,6 +227,9 @@ def recursion(inner_public_memory, initial_fiat_shamir_cap): # Per-table data accumulators (indexed by table_index). bus_numerators_values = Array(N_TABLES * DIM) bus_denominators_values = Array(N_TABLES * DIM) + # Composite denominator evals of deferred-claim buses, consumed by the AIR + # initial_sum (slots alpha_offset+2+i, mirroring verify_execution.rs). + deferred_bus_denominators = Array(N_TABLES * MAX_DEFERRED_BUSES * DIM) pcs_inner_points = Array(N_TABLES) pcs_vals_logup = Array(N_TABLES * MAX_NUM_COLS_AIR) pcs_vals_air = Array(N_TABLES * MAX_NUM_COLS_AIR) @@ -255,34 +263,57 @@ def recursion(inner_public_memory, initial_fiat_shamir_cap): # Multiplicity::One buses (bytecode lookup + memory lookups). for one_bus_idx in unroll(0, len(ONE_BUSES_DOMSEPS[table_index])): - domsep = ONE_BUSES_DOMSEPS[table_index][one_bus_idx] - n_new = len(ONE_BUSES_NEW_COLS[table_index][one_bus_idx]) - n_data = len(ONE_BUSES_DATA_COLS[table_index][one_bus_idx]) - - fs, new_evals = fs_receive_ef_inlined(fs, n_new) - - for i in unroll(0, n_new): - new_col = ONE_BUSES_NEW_COLS[table_index][one_bus_idx][i] - pcs_vals_logup[table_index * MAX_NUM_COLS_AIR + new_col] = new_evals + i * DIM - - data_evals = Array(n_data * DIM) - for i in unroll(0, n_data): - data_col = ONE_BUSES_DATA_COLS[table_index][one_bus_idx][i] - data_ofs = ONE_BUSES_DATA_OFFSETS[table_index][one_bus_idx][i] - src = pcs_vals_logup[table_index * MAX_NUM_COLS_AIR + data_col] - if data_ofs == 0: - copy_ef(src, data_evals + i * DIM) - if data_ofs != 0: - copy_ef(add_base_extension_ret(data_ofs, src), data_evals + i * DIM) - - pref = multilinear_location_prefix(offset / n_rows, n_vars_logup_gkr - log_n_rows, point_gkr) - retrieved_numerators_value = add_extension_ret(retrieved_numerators_value, pref) - fingerp = fingerprint_n(domsep, data_evals, n_data, logup_beta_eq_poly) - retrieved_denominators_value = add_extension_ret( - retrieved_denominators_value, - mul_extension_ret(pref, sub_extension_ret(logup_gamma, fingerp)), - ) - offset += n_rows + # Deferred-claim bus: one composite denominator eval on the wire + # (gamma - fingerprint-hat at the GKR point); the One-numerator is the prefix + # itself (Push direction). The composite is grounded by the table's + # eval_bus_data_only constraint slot in the batched AIR sumcheck (initial_sum + # below). The bus still occupies its GKR segment: offset advances as usual. + if ONE_BUSES_DEFERRED[table_index][one_bus_idx] != 0: + fs, deferred_eval = fs_receive_ef_inlined(fs, 1) + pref_deferred = multilinear_location_prefix( + offset / n_rows, n_vars_logup_gkr - log_n_rows, point_gkr + ) + retrieved_numerators_value = add_extension_ret(retrieved_numerators_value, pref_deferred) + retrieved_denominators_value = add_extension_ret( + retrieved_denominators_value, + mul_extension_ret(pref_deferred, deferred_eval), + ) + copy_ef( + deferred_eval, + deferred_bus_denominators + + (table_index * MAX_DEFERRED_BUSES + ONE_BUSES_DEFERRED_IDX[table_index][one_bus_idx]) + * DIM, + ) + offset += n_rows + if ONE_BUSES_DEFERRED[table_index][one_bus_idx] == 0: + domsep = ONE_BUSES_DOMSEPS[table_index][one_bus_idx] + n_new = len(ONE_BUSES_NEW_COLS[table_index][one_bus_idx]) + n_data = len(ONE_BUSES_DATA_COLS[table_index][one_bus_idx]) + + fs, new_evals = fs_receive_ef_inlined(fs, n_new) + + for i in unroll(0, n_new): + new_col = ONE_BUSES_NEW_COLS[table_index][one_bus_idx][i] + pcs_vals_logup[table_index * MAX_NUM_COLS_AIR + new_col] = new_evals + i * DIM + + data_evals = Array(n_data * DIM) + for i in unroll(0, n_data): + data_col = ONE_BUSES_DATA_COLS[table_index][one_bus_idx][i] + data_ofs = ONE_BUSES_DATA_OFFSETS[table_index][one_bus_idx][i] + src = pcs_vals_logup[table_index * MAX_NUM_COLS_AIR + data_col] + if data_ofs == 0: + copy_ef(src, data_evals + i * DIM) + if data_ofs != 0: + copy_ef(add_base_extension_ret(data_ofs, src), data_evals + i * DIM) + + pref = multilinear_location_prefix(offset / n_rows, n_vars_logup_gkr - log_n_rows, point_gkr) + retrieved_numerators_value = add_extension_ret(retrieved_numerators_value, pref) + fingerp = fingerprint_n(domsep, data_evals, n_data, logup_beta_eq_poly) + retrieved_denominators_value = add_extension_ret( + retrieved_denominators_value, + mul_extension_ret(pref, sub_extension_ret(logup_gamma, fingerp)), + ) + offset += n_rows # Final logup adjustment (padding) retrieved_denominators_value = add_extension_ret( @@ -319,6 +350,20 @@ def recursion(inner_public_memory, initial_fiat_shamir_cap): sub_extension_ret(logup_gamma, bus_denominator_value), ), ) + # One fingerprint alpha-slot per deferred-claim bus, after the + # Column-bus pair (slots alpha_offset+2+i), mirroring verify_execution.rs. The + # AIR final check grounds each composite via the eval_bus_data_only constraints. + for i in unroll(0, N_DEFERRED_BUSES[table_index]): + bus_final_value = add_extension_ret( + bus_final_value, + mul_extension_ret( + air_alpha_powers + (alpha_offset + 2 + i) * DIM, + sub_extension_ret( + logup_gamma, + deferred_bus_denominators + (table_index * MAX_DEFERRED_BUSES + i) * DIM, + ), + ), + ) initial_sum = add_extension_ret(initial_sum, bus_final_value) n_max = log_max_table_height diff --git a/crates/sub_protocols/src/air_sumcheck.rs b/crates/sub_protocols/src/air_sumcheck.rs index c1db8ef0b..acb619b4a 100644 --- a/crates/sub_protocols/src/air_sumcheck.rs +++ b/crates/sub_protocols/src/air_sumcheck.rs @@ -31,6 +31,10 @@ use tracing::{info_span, instrument}; const ENDIANNESS_PIVOT_AIR: usize = 12; +fn c2_class_enabled(computation: &A) -> bool { + computation.c2_table_profitable() +} + pub trait OuterSumcheckSession>>: Debug { fn initial_n_vars(&self) -> usize; fn sum(&self) -> EF; @@ -41,6 +45,25 @@ pub trait OuterSumcheckSession>>: Debug { fn final_column_evals(&self) -> Vec; } +/// Per-pair constraint cache `T_i[x] = C(r_0..r_{i-1}, x)`. Lives packed +/// through phase 1, unpacked at the same boundary. The padding tail is NOT +/// materialized (padding rows are constant under folding). +#[derive(Debug)] +enum C2Store>> { + Packed(Vec>), + Unpacked(Vec), +} + +/// Fresh per-pair constraint-eval vectors cached from the current round's +/// pass, used at challenge time to extrapolate `T_{i+1}`. Seed round (local +/// round 0): nodes z = 0, 1, 2, .., d_z. Table rounds: nodes z = 2, .., d_z +/// (z=0 and z=1 come from `T_i[i0]`, `T_i[i1]`). +#[derive(Debug)] +enum C2Cache>> { + Packed(Vec>>), + Unpacked(Vec>), +} + #[derive(Debug)] pub struct AirSumcheckSession<'a, EF: ExtensionField>, A: Air> where @@ -58,6 +81,9 @@ where initial_n_vars: usize, constraints_eval_at_padding: EF, rounds_done: usize, + c2_enabled: bool, + c2_table: Option>, + c2_cache: Option>, } impl<'a, EF: ExtensionField>, A: Air> AirSumcheckSession<'a, EF, A> @@ -108,6 +134,7 @@ where _ => unreachable!(), }; + let c2_enabled = c2_class_enabled(&computation); Self { multilinears, eq_factor, @@ -119,8 +146,19 @@ where initial_n_vars, constraints_eval_at_padding, rounds_done: 0, + c2_enabled, + c2_table: None, + c2_cache: None, } } + + /// Test-only hook: force the fresh-eval (pre-C2) path so equality tests can + /// drive both strategies on identical inputs. Both paths are bit-identical + /// by construction; this switch only selects the slower reference one. + #[doc(hidden)] + pub fn set_c2_enabled_for_tests(&mut self, enabled: bool) { + self.c2_enabled = enabled; + } } impl<'a, EF, A> AirSumcheckSession<'a, EF, A> @@ -221,7 +259,8 @@ where } fn compute_bare_round_poly(&mut self) -> DensePolynomial { - let split_eq = SplitEq::new(&self.permuted_alphas(self.initial_n_vars - self.rounds_done - 1)); + let split_eq = info_span!("split_eq_new") + .in_scope(|| SplitEq::new(&self.permuted_alphas(self.initial_n_vars - self.rounds_done - 1))); let active_count_pairs = self.active_count_pairs(); let storage_shift = if self.in_phase_1() { packing_log_width::() @@ -237,14 +276,133 @@ where EF::ZERO }; - let p_evals_raw = compute_raw_poly( - &self.multilinears.by_ref(), - &self.computation, - &self.extra_data, - &split_eq, - self.folding_bit_packed(), - active_count_pairs, - ); + // C2: seed round caches all node vectors; table rounds reuse z=0 from cache. + let fold_bit = self.folding_bit_packed(); + let is_seed = self.c2_table.is_none(); + #[derive(PartialEq)] + enum C2RoundMode { + SeedPacked, + TablePacked, + TableUnpacked, + Fallback, + } + let mode = if !self.c2_enabled { + C2RoundMode::Fallback + } else { + match (&self.multilinears.by_ref(), &self.c2_table) { + (MleGroupRef::BasePacked(_), None) if self.rounds_done == 0 && self.in_phase_1() => { + C2RoundMode::SeedPacked + } + (MleGroupRef::ExtensionPacked(_), Some(C2Store::Packed(_))) => C2RoundMode::TablePacked, + (MleGroupRef::Extension(_), Some(C2Store::Unpacked(_))) => C2RoundMode::TableUnpacked, + _ => C2RoundMode::Fallback, + } + }; + if mode == C2RoundMode::Fallback && self.c2_enabled { + // Non-conforming shape: disable C2 permanently for this session. + self.c2_enabled = false; + self.c2_table = None; + } + let (p_evals_raw, new_cache): (Vec, Option>) = match mode { + C2RoundMode::SeedPacked => { + let MleGroupRef::BasePacked(cols) = self.multilinears.by_ref() else { + unreachable!() + }; + // Bus-only evals for on-row nodes (z=0, z=1). + let eval_bus_01 = |a: &A, point: &[PFPacking], xd: &A::ExtraData| -> EFPacking { + let n_cols = a.n_columns(); + let mut folder = ConstraintFolderPacked::new(&point[..n_cols], &point[n_cols..], xd); + a.eval_bus_only(&mut folder, xd); + folder.accumulator + }; + let (accs, cache) = c2_pass::, EFPacking, _, _, _>( + &cols, + |j| split_eq.get_packed(j), + &self.computation, + &self.extra_data, + fold_bit, + active_count_pairs, + A::eval_packed_base, + eval_bus_01, + None, + ); + let accs = accs.into_iter().map(unpack_sum_packed::).collect(); + (accs, Some(C2Cache::Packed(cache))) + } + C2RoundMode::TablePacked => { + let MleGroupRef::ExtensionPacked(cols) = self.multilinears.by_ref() else { + unreachable!() + }; + let Some(C2Store::Packed(table)) = &self.c2_table else { + unreachable!() + }; + let (accs, cache) = c2_pass::, EFPacking, _, _, _>( + &cols, + |j| split_eq.get_packed(j), + &self.computation, + &self.extra_data, + fold_bit, + active_count_pairs, + A::eval_packed_extension, + A::eval_packed_extension, + Some(table), + ); + let accs = accs.into_iter().map(unpack_sum_packed::).collect(); + (accs, Some(C2Cache::Packed(cache))) + } + C2RoundMode::TableUnpacked => { + let MleGroupRef::Extension(cols) = self.multilinears.by_ref() else { + unreachable!() + }; + let Some(C2Store::Unpacked(table)) = &self.c2_table else { + unreachable!() + }; + let (accs, cache) = c2_pass::( + &cols, + |j| split_eq.get_unpacked(j), + &self.computation, + &self.extra_data, + fold_bit, + active_count_pairs, + A::eval_extension, + A::eval_extension, + Some(table), + ); + (accs, Some(C2Cache::Unpacked(cache))) + } + C2RoundMode::Fallback => { + let fresh = compute_raw_poly( + &self.multilinears.by_ref(), + &self.computation, + &self.extra_data, + &split_eq, + fold_bit, + active_count_pairs, + ); + (fresh, None) + } + }; + + // Debug: verify C2 path reproduces the fresh-eval accumulators exactly. + #[cfg(debug_assertions)] + if new_cache.is_some() { + let fresh = compute_raw_poly( + &self.multilinears.by_ref(), + &self.computation, + &self.extra_data, + &split_eq, + fold_bit, + active_count_pairs, + ); + debug_assert_eq!( + p_evals_raw, fresh, + "C2 dual-compute mismatch (seed={is_seed}, round={})", + self.rounds_done + ); + } + let _ = is_seed; + self.c2_cache = new_cache; + let mut p_evals: Vec = p_evals_raw .into_iter() .map(|v| (v + padding_contribution) * self.missing_mul_factor) @@ -272,15 +430,54 @@ where let was_in_phase_1 = self.in_phase_1(); let fold_bit = self.folding_bit_packed(); + // C2: extrapolate `T_{i+1}[j] = C_pair_j(challenge)` from the cached + // fresh node vectors (+ `T_i` lookups in table rounds) BEFORE the fold + // mutates the round geometry. Same pair enumeration as the round pass. + if self.c2_enabled { + let new_pairs = self.active_count_pairs(); + let weights = lagrange_weights_at(self.computation.degree_z(), challenge); + let next = match (self.c2_table.take(), self.c2_cache.take()) { + (None, Some(C2Cache::Packed(cache))) => Some(C2Store::Packed(c2_update_packed( + None, &cache, &weights, fold_bit, new_pairs, + ))), + (Some(C2Store::Packed(table)), Some(C2Cache::Packed(cache))) => Some(C2Store::Packed( + c2_update_packed(Some(&table), &cache, &weights, fold_bit, new_pairs), + )), + (Some(C2Store::Unpacked(table)), Some(C2Cache::Unpacked(cache))) => { + Some(C2Store::Unpacked(c2_update_unpacked( + &table, + &cache, + &weights, + fold_bit, + new_pairs, + self.constraints_eval_at_padding, + ))) + } + _ => None, + }; + match next { + Some(t) => self.c2_table = Some(t), + None => { + self.c2_enabled = false; + self.c2_table = None; + } + } + } + self.multilinears = self.multilinears.by_ref().fold_at_bit(challenge, fold_bit).into(); self.current_unpadded_len = self.current_unpadded_len.div_ceil(2); self.rounds_done += 1; self.eq_factor.pop(); - // Phase 1 → phase 2: unpack + // Phase 1 → phase 2: unpack (columns and the C2 table together — the + // table must track the columns' storage mode exactly) if was_in_phase_1 && !self.in_phase_1() { self.multilinears = self.multilinears.by_ref().unpack().as_owned_or_clone().into(); + if let Some(C2Store::Packed(t)) = &self.c2_table { + let unpacked: Vec = EFPacking::::to_ext_iter(t.iter().copied()).collect(); + self.c2_table = Some(C2Store::Unpacked(unpacked)); + } } } @@ -385,7 +582,10 @@ where GetEq: Fn(usize) -> EFT + Sync + Send, UnpackSum: Fn(EFT) -> EF + Sync + Send, { - let degree = computation.degree(); + // Fresh-eval count per pair: sized by the TRUE constraint degree (`degree_z`, + // see the Air trait doc) — the bare-poly interpolation and the wire format + // remain sized by the declared degree. + let degree = computation.degree_z(); let n_cols = cols.len(); let stride = 1usize << fold_bit; let lo_mask = stride - 1; @@ -431,6 +631,221 @@ where acc.into_iter().map(unpack_sum).collect() } +#[inline(always)] +fn unpack_sum_packed>>(s: EFPacking) -> EF { + EFPacking::::to_ext_iter([s]).sum::() +} + +/// C2 round pass. Two modes: +/// - `table = None` (seed, local round 0): fresh evals at z = 0, 1, 2, .., d_z; +/// ALL d_z+1 per-pair vectors cached; message accumulators as today +/// (z=0 weighted into `acc[0]`, z=2.. into `acc[1..]`; z=1 cached only). +/// - `table = Some(T_i)`: `acc[0] = Σ eq·T_i[i0]` (no constraint eval); fresh +/// evals at z = 2..d_z only, cached for the challenge-time extrapolation. +/// +/// Returns `(message accumulators, cached node vectors)`. Cached vectors are +/// indexed by `new_j` (same pair enumeration as the accumulators) and written +/// exactly once each via the raw-ptr-in-map_reduce precedent +/// (sc_computation.rs `sumcheck_fold_and_compute_core`). +#[allow(clippy::too_many_arguments)] +fn c2_pass( + cols: &[&[IF]], + get_split_eq: GetEq, + computation: &A, + extra_data: &A::ExtraData, + fold_bit: usize, + active_count_pairs: usize, + eval_fn: EvalFn, + eval_fn_01: EvalFn01, + table: Option<&[EFT]>, +) -> (Vec, Vec>) +where + EF: ExtensionField>, + A: Air + 'static, + A::ExtraData: AlphaPowers, + IF: Copy + Send + Sync + Sub + AddAssign + PrimeCharacteristicRing, + EFT: Copy + Send + Sync + Add + AddAssign + Mul + PrimeCharacteristicRing, + GetEq: Fn(usize) -> EFT + Sync + Send, + EvalFn: Fn(&A, &[IF], &A::ExtraData) -> EFT + Sync + Send, + EvalFn01: Fn(&A, &[IF], &A::ExtraData) -> EFT + Sync + Send, +{ + let degree = computation.degree_z(); + let n_cols = cols.len(); + let stride = 1usize << fold_bit; + let lo_mask = stride - 1; + let is_seed = table.is_none(); + let n_cached = if is_seed { degree + 1 } else { degree - 1 }; + + let cache: Vec> = (0..n_cached) + .map(|_| unsafe { uninitialized_vec::(active_count_pairs) }) + .collect(); + + let acc = parallel::map_reduce_with_state( + active_count_pairs, + || (Vec::::with_capacity(n_cols), Vec::::with_capacity(n_cols)), + || vec![EFT::ZERO; degree], + |(point, diff), acc, new_j| { + let i_hi = new_j >> fold_bit; + let i_lo = new_j & lo_mask; + let i0 = (i_hi << (fold_bit + 1)) | i_lo; + let i1 = i0 | stride; + let partial_eq = get_split_eq(new_j); + point.clear(); + diff.clear(); + for c in cols { + let lo = c[i0]; + let hi = c[i1]; + point.push(lo); + diff.push(hi - lo); + } + // SAFETY: each `new_j` is visited exactly once across all tasks; + // every cache slot is written exactly once before being read. + let write_cache = |vec_idx: usize, v: EFT| unsafe { + let ptr = cache[vec_idx].as_ptr() as *mut EFT; + *ptr.add(new_j) = v; + }; + match table { + None => { + // Seed: z = 0 (acc + cache), z = 1 (cache only), z = 2.. (acc + cache). + // z=0 / z=1 are evaluations on actual storage rows -> the + // z=0,1 use bus-only eval; z=2.. need full eval. + let v0 = eval_fn_01(computation, point, extra_data); + write_cache(0, v0); + acc[0] += v0 * partial_eq; + for k in 0..n_cols { + point[k] += diff[k]; + } + let v1 = eval_fn_01(computation, point, extra_data); + write_cache(1, v1); + for (zi, acc_z) in acc[1..].iter_mut().enumerate() { + for k in 0..n_cols { + point[k] += diff[k]; + } + let v = eval_fn(computation, point, extra_data); + write_cache(2 + zi, v); + *acc_z += v * partial_eq; + } + } + Some(t) => { + // Table: z = 0 from T_i (i0 is always inside the active + // prefix — see the padding analysis in the session docs). + acc[0] += t[i0] * partial_eq; + for k in 0..n_cols { + point[k] += diff[k]; + } + for (zi, acc_z) in acc[1..].iter_mut().enumerate() { + for k in 0..n_cols { + point[k] += diff[k]; + } + let v = eval_fn(computation, point, extra_data); + write_cache(zi, v); + *acc_z += v * partial_eq; + } + } + } + }, + |mut a, b| { + for i in 0..degree { + a[i] += b[i]; + } + a + }, + ); + + (acc, cache) +} + +/// Lagrange weights `w_k(r)` over the node set `{0, 1, .., d}`: +/// `p(r) = Σ_k w_k(r)·p(k)` for any univariate `p` of degree ≤ d. Exact field +/// arithmetic — the same interpolation the message path performs, evaluated. +fn lagrange_weights_at(d: usize, r: EF) -> Vec { + (0..=d) + .map(|k| { + let mut num = EF::ONE; + let mut den = EF::ONE; + for m in 0..=d { + if m != k { + num *= r - EF::from_usize(m); + den *= EF::from_usize(k) - EF::from_usize(m); + } + } + num * den.inverse() + }) + .collect() +} + +/// `T_{i+1}[new_j] = w_0·T_i[i0] + w_1·T_i[i1] + Σ_z w_z·v_z[new_j]` +/// (packed phase: chunk alignment guarantees `i0`, `i1` stay inside the +/// active prefix — no padding reads; debug-asserted). +fn c2_update_packed>>( + table: Option<&[EFPacking]>, + cache: &[Vec>], + weights: &[EF], + fold_bit: usize, + new_pairs: usize, +) -> Vec> { + let w_packed: Vec> = weights.iter().map(|&w| EFPacking::::from(w)).collect(); + let stride = 1usize << fold_bit; + let lo_mask = stride - 1; + let mut out = unsafe { uninitialized_vec::>(new_pairs) }; + const CHUNK_P: usize = 1 << 10; + parallel::par_chunks_mut(&mut out, CHUNK_P, |chunk_idx, chunk| { + for (off, slot) in chunk.iter_mut().enumerate() { + let new_j = chunk_idx * CHUNK_P + off; + let mut v = match table { + Some(t) => { + let i_hi = new_j >> fold_bit; + let i_lo = new_j & lo_mask; + let i0 = (i_hi << (fold_bit + 1)) | i_lo; + let i1 = i0 | stride; + debug_assert!(i1 < t.len(), "packed C2 update read past the active prefix"); + w_packed[0] * t[i0] + w_packed[1] * t[i1] + } + None => w_packed[0] * cache[0][new_j] + w_packed[1] * cache[1][new_j], + }; + let fresh_off = if table.is_some() { 0 } else { 2 }; + for (zi, w) in w_packed[2..].iter().enumerate() { + v += *w * cache[fresh_off + zi][new_j]; + } + *slot = v; + } + }); + out +} + +/// Unpacked-phase variant; a straddle pair may read `T_i[i1]` at/after the +/// active boundary, which is the (round-invariant) padding constant. +fn c2_update_unpacked>>( + table: &[EF], + cache: &[Vec], + weights: &[EF], + fold_bit: usize, + new_pairs: usize, + pad_value: EF, +) -> Vec { + let stride = 1usize << fold_bit; + let lo_mask = stride - 1; + let mut out = unsafe { uninitialized_vec::(new_pairs) }; + const CHUNK_U: usize = 1 << 12; + parallel::par_chunks_mut(&mut out, CHUNK_U, |chunk_idx, chunk| { + for (off, slot) in chunk.iter_mut().enumerate() { + let new_j = chunk_idx * CHUNK_U + off; + let i_hi = new_j >> fold_bit; + let i_lo = new_j & lo_mask; + let i0 = (i_hi << (fold_bit + 1)) | i_lo; + let i1 = i0 | stride; + let t0 = table[i0]; + let t1 = if i1 < table.len() { table[i1] } else { pad_value }; + let mut v = weights[0] * t0 + weights[1] * t1; + for (zi, w) in weights[2..].iter().enumerate() { + v += *w * cache[zi][new_j]; + } + *slot = v; + } + }); + out +} + #[instrument(skip_all)] pub fn prove_batched_air_sumcheck<'a, EF: ExtensionField>>( prover_state: &mut impl FSProver, @@ -443,6 +858,7 @@ pub fn prove_batched_air_sumcheck<'a, EF: ExtensionField>>( let mut k: Vec = vec![EF::ONE; sessions.len()]; for round in 0..n_rounds { + let round_span = info_span!("air_round", round).entered(); let mut combined_coeffs = EF::zero_vec(max_full_degree + 1); let mut bare_polys: Vec>> = vec![None; sessions.len()]; @@ -451,7 +867,7 @@ pub fn prove_batched_air_sumcheck<'a, EF: ExtensionField>>( if round < join_round { combined_coeffs[1] += k[idx] * session.sum(); } else { - let bare_poly = session.compute_bare_round_poly(); + let bare_poly = info_span!("air_poly", session = idx).in_scope(|| session.compute_bare_round_poly()); let full_coeffs = expand_bare_to_full(&bare_poly.coeffs, session.eq_alpha()); for (i, &c) in full_coeffs.iter().enumerate() { combined_coeffs[i] += k[idx] * c; @@ -469,9 +885,10 @@ pub fn prove_batched_air_sumcheck<'a, EF: ExtensionField>>( if round < join_round { k[idx] *= challenge; } else if let Some(bare_poly) = &bare_polys[idx] { - session.process_challenge(challenge, bare_poly); + info_span!("air_fold", session = idx).in_scope(|| session.process_challenge(challenge, bare_poly)); } } + drop(round_span); } MultilinearPoint(challenges) diff --git a/crates/sub_protocols/src/logup.rs b/crates/sub_protocols/src/logup.rs index f8ebda98e..5d58978a6 100644 --- a/crates/sub_protocols/src/logup.rs +++ b/crates/sub_protocols/src/logup.rs @@ -14,7 +14,8 @@ pub struct GenericLogupStatements { pub bytecode_and_acc_point: MultilinearPoint, pub value_bytecode_acc: EF, pub bus_numerators_values: BTreeMap, - pub bus_denominators_values: BTreeMap, + /// Index 0 = Column-bus denominator; indices 1.. = deferred-claim bus denominators. + pub bus_denominators_values: BTreeMap>, pub gkr_point: Vec, pub columns_values: BTreeMap>, // Used in recursion @@ -285,7 +286,20 @@ pub fn prove_generic_logup( let eval_on_data = c - finger_print(resolve_ef(bus.domainsep), &data_evals, alphas_eq_poly); prover_state.add_extension_scalar(eval_on_data); bus_numerators_values.insert(table, eval_on_multiplicity); - bus_denominators_values.insert(table, eval_on_data); + bus_denominators_values + .entry(table) + .or_insert_with(Vec::new) + .insert(0, eval_on_data); + } + // Deferred-claim bus: one composite denominator eval on the wire. + BusMultiplicity::One if bus.deferred_claim => { + let data_evals: Vec = bus.data.iter().map(|e| resolve_ef(*e)).collect(); + let eval_on_data = c - finger_print(resolve_ef(bus.domainsep), &data_evals, alphas_eq_poly); + prover_state.add_extension_scalar(eval_on_data); + bus_denominators_values + .entry(table) + .or_insert_with(Vec::new) + .push(eval_on_data); } BusMultiplicity::One => { // Skip columns already in table_values: memory-lookup groups share @@ -430,7 +444,20 @@ pub fn verify_generic_logup( retrieved_numerators_value += pref * eval_on_multiplicity; retrieved_denominators_value += pref * eval_on_data; bus_numerators_values.insert(table, eval_on_multiplicity); - bus_denominators_values.insert(table, eval_on_data); + bus_denominators_values + .entry(table) + .or_insert_with(Vec::new) + .insert(0, eval_on_data); + } + // Deferred-claim bus: read composite denominator eval from transcript. + BusMultiplicity::One if bus.deferred_claim => { + let eval_on_data = verifier_state.next_extension_scalar()?; + retrieved_numerators_value += pref * bus.direction.to_field_flag(); + retrieved_denominators_value += pref * eval_on_data; + bus_denominators_values + .entry(table) + .or_insert_with(Vec::new) + .push(eval_on_data); } BusMultiplicity::One => { let n_col_entries = bus diff --git a/crates/sub_protocols/src/stacked_pcs.rs b/crates/sub_protocols/src/stacked_pcs.rs index 07fcabca6..682966407 100644 --- a/crates/sub_protocols/src/stacked_pcs.rs +++ b/crates/sub_protocols/src/stacked_pcs.rs @@ -209,7 +209,12 @@ pub fn total_whir_statements() -> usize { .iter() .map(|table| { let mut seen_cols = std::collections::HashSet::::new(); - for bus in table.bus_interactions().iter().filter(|b| b.is_memory_lookup()) { + // Skip deferred-claim buses (no per-column PCS statements). + for bus in table + .bus_interactions() + .iter() + .filter(|b| b.is_memory_lookup() && !b.deferred_claim) + { for entry in &bus.data { if let Some(col) = entry.column() { seen_cols.insert(col); diff --git a/crates/sub_protocols/tests/c2_bus_seed.rs b/crates/sub_protocols/tests/c2_bus_seed.rs new file mode 100644 index 000000000..57bc039a8 --- /dev/null +++ b/crates/sub_protocols/tests/c2_bus_seed.rs @@ -0,0 +1,80 @@ +//! C2 bus-only seeding test: on every valid row, the bus-only accumulator +//! must equal the full constraint accumulator (non-bus gates vanish row-wise). + +use backend::*; +use lean_vm::{ + EF, ExtraDataForBuses, F, HALF_DIGEST_LEN, POSEIDON_8_COL_ADDR_LEFT_HI, POSEIDON_8_COL_ADDR_LEFT_LO, + POSEIDON_8_COL_FLAG_OUT4, POSEIDON_8_COL_INPUT_START, POSEIDON_8_COL_MULTIPLICITY, POSEIDON_8_COL_OUT_LO, + POSEIDON_8_COL_ROUND_START, Poseidon8Precompile, compute_poseidon8_witness, fill_trace_poseidon_8, + num_cols_poseidon_8, +}; +use rand::{RngExt, SeedableRng, rngs::StdRng}; + +#[test] +fn busonly_equals_full_on_valid_rows() { + let log_n_rows = 10usize; + let n_rows = 1usize << log_n_rows; + let non_padded = 900usize; + let mut rng = StdRng::seed_from_u64(99); + + // Witness-consistent trace, production padding (rows >= non_padded repeat + // the last real row). + let n_cols = num_cols_poseidon_8(); + let mut trace: Vec> = (0..n_cols).map(|_| ArenaVec::filled(F::ZERO, n_rows)).collect(); + for t in trace.iter_mut().skip(POSEIDON_8_COL_INPUT_START).take(WIDTH) { + *t = ArenaVec::from_iter((0..n_rows).map(|_| rng.random())); + } + trace[POSEIDON_8_COL_MULTIPLICITY] = ArenaVec::filled(F::ONE, n_rows); + trace[POSEIDON_8_COL_FLAG_OUT4] = ArenaVec::filled(F::ONE, n_rows); + trace[POSEIDON_8_COL_ADDR_LEFT_LO] = ArenaVec::filled(F::ZERO, n_rows); + trace[POSEIDON_8_COL_ADDR_LEFT_HI] = ArenaVec::filled(F::from_usize(HALF_DIGEST_LEN), n_rows); + #[allow(clippy::needless_range_loop)] + for row in 0..non_padded { + let input: [F; WIDTH] = std::array::from_fn(|i| trace[POSEIDON_8_COL_INPUT_START + i][row]); + let (aux, perm_state) = compute_poseidon8_witness(input); + for i in 0..WIDTH / 2 { + trace[POSEIDON_8_COL_OUT_LO + i][row] = perm_state[i] + input[i]; + } + for (i, v) in aux.iter().enumerate() { + trace[POSEIDON_8_COL_ROUND_START + i][row] = *v; + } + } + fill_trace_poseidon_8(&mut trace); + let last = non_padded - 1; + for col in trace.iter_mut() { + let v = col[last]; + for row in non_padded..n_rows { + col[row] = v; + } + } + + // Production-shaped bus data: 2^LOG_MAX_BUS_WIDTH eq-table entries. + let air = Poseidon8Precompile::; + let alpha: EF = rng.random(); + let alpha_powers: Vec = alpha.powers().collect_n(air.n_constraints()); + let bus_point: Vec = (0..4).map(|_| rng.random()).collect(); + let logup_alphas_eq: Vec = eval_eq(&bus_point).to_vec(); + let extra = ExtraDataForBuses::new(&logup_alphas_eq, alpha_powers); + + let mut mismatches = 0usize; + for row in 0..n_rows { + let point: Vec = trace.iter().map(|c| EF::from(c[row])).collect(); + let full = { + let mut folder = ConstraintFolder::new(&point[..air.n_columns()], &point[air.n_columns()..], &extra); + Air::eval(&air, &mut folder, &extra); + folder.accumulator + }; + let bus_only = { + let mut folder = ConstraintFolder::new(&point[..air.n_columns()], &point[air.n_columns()..], &extra); + Air::eval_bus_only(&air, &mut folder, &extra); + folder.accumulator + }; + if full != bus_only { + mismatches += 1; + if mismatches <= 3 { + eprintln!("row {row}: full = {full:?}, bus_only = {bus_only:?}"); + } + } + } + assert_eq!(mismatches, 0, "bus-only != full on {mismatches} of {n_rows} rows"); +} diff --git a/crates/sub_protocols/tests/c2_table.rs b/crates/sub_protocols/tests/c2_table.rs new file mode 100644 index 000000000..8b969d53d --- /dev/null +++ b/crates/sub_protocols/tests/c2_table.rs @@ -0,0 +1,118 @@ +//! C2 per-pair table equality test: verifies the C2 path produces identical +//! round polynomials vs the fresh-eval path, including padding rows and the +//! phase-1 → phase-2 transition. The bare round polynomials must be identical +//! elements every round (this is what makes C2 transcript-invariant), and the +//! final column evaluations must agree. +//! +//! Run in the dev profile to also exercise the in-session dual-compute +//! debug_assert (C2 accumulators vs fresh accumulators, every round). + +use backend::*; +use lean_vm::{ + EF, ExtraDataForBuses, F, HALF_DIGEST_LEN, POSEIDON_8_COL_ADDR_LEFT_HI, POSEIDON_8_COL_ADDR_LEFT_LO, + POSEIDON_8_COL_FLAG_OUT4, POSEIDON_8_COL_INPUT_START, POSEIDON_8_COL_MULTIPLICITY, POSEIDON_8_COL_OUT_LO, + POSEIDON_8_COL_ROUND_START, Poseidon8Precompile, compute_poseidon8_witness, fill_trace_poseidon_8, + num_cols_poseidon_8, +}; +use rand::{RngExt, SeedableRng, rngs::StdRng}; +use sub_protocols::{AirSumcheckSession, OuterSumcheckSession}; + +/// Valid rows for `[0, non_padded)`; rows `[non_padded, 2^log_n_rows)` are +/// copies of row `non_padded - 1` (the production padding convention the C2 +/// implicit tail relies on). +fn build_padded_poseidon_trace(log_n_rows: usize, non_padded: usize, rng: &mut StdRng) -> Vec> { + let n_rows = 1 << log_n_rows; + assert!(non_padded >= 1 && non_padded <= n_rows); + let n_cols = num_cols_poseidon_8(); + let mut trace: Vec> = (0..n_cols).map(|_| ArenaVec::filled(F::ZERO, n_rows)).collect(); + for t in trace.iter_mut().skip(POSEIDON_8_COL_INPUT_START).take(WIDTH) { + *t = ArenaVec::from_iter((0..n_rows).map(|_| rng.random())); + } + trace[POSEIDON_8_COL_MULTIPLICITY] = ArenaVec::filled(F::ONE, n_rows); + trace[POSEIDON_8_COL_FLAG_OUT4] = ArenaVec::filled(F::ONE, n_rows); + trace[POSEIDON_8_COL_ADDR_LEFT_LO] = ArenaVec::filled(F::ZERO, n_rows); + trace[POSEIDON_8_COL_ADDR_LEFT_HI] = ArenaVec::filled(F::from_usize(HALF_DIGEST_LEN), n_rows); + #[allow(clippy::needless_range_loop)] + for row in 0..non_padded { + let input: [F; WIDTH] = std::array::from_fn(|i| trace[POSEIDON_8_COL_INPUT_START + i][row]); + let (aux, perm_state) = compute_poseidon8_witness(input); + for i in 0..WIDTH / 2 { + trace[POSEIDON_8_COL_OUT_LO + i][row] = perm_state[i] + input[i]; + } + for (i, v) in aux.iter().enumerate() { + trace[POSEIDON_8_COL_ROUND_START + i][row] = *v; + } + } + fill_trace_poseidon_8(&mut trace); + // Production padding: every row past the last real one repeats it. + let last = non_padded - 1; + for col in trace.iter_mut() { + let v = col[last]; + for row in non_padded..n_rows { + col[row] = v; + } + } + trace +} + +fn run_equality_case(log_n_rows: usize, non_padded: usize, seed: u64) { + let mut rng = StdRng::seed_from_u64(seed); + let trace = build_padded_poseidon_trace(log_n_rows, non_padded, &mut rng); + + let air = Poseidon8Precompile::; + let alpha: EF = rng.random(); + let alpha_powers: Vec = alpha.powers().collect_n(air.n_constraints()); + + let eq_factor: Vec = (0..log_n_rows).map(|_| rng.random()).collect(); + let sum: EF = rng.random(); // shared by both sessions; bare polys stay comparable + + let make_session = |c2: bool| { + let cols: Vec<&[F]> = trace.iter().map(|c| c.as_slice()).collect(); + let packed = MleGroupRef::::Base(cols).pack(); + let extra = ExtraDataForBuses::new(&[], alpha_powers.clone()); + let mut s = AirSumcheckSession::new(packed, eq_factor.clone(), sum, air, extra, non_padded); + s.set_c2_enabled_for_tests(c2); + s + }; + let mut s_c2 = make_session(true); + let mut s_ref = make_session(false); + + for round in 0..log_n_rows { + let p_c2 = s_c2.compute_bare_round_poly(); + let p_ref = s_ref.compute_bare_round_poly(); + assert_eq!( + p_c2.coeffs, p_ref.coeffs, + "bare round-poly mismatch at round {round} (n={log_n_rows}, non_padded={non_padded})" + ); + let challenge: EF = rng.random(); + s_c2.process_challenge(challenge, &p_c2); + s_ref.process_challenge(challenge, &p_ref); + } + assert_eq!( + s_c2.final_column_evals(), + s_ref.final_column_evals(), + "final column evals mismatch (n={log_n_rows}, non_padded={non_padded})" + ); +} + +/// Padding + straddles: 9000 of 2^14 rows -> padded to 12288 = 3 * 2^12; the +/// unpacked rounds walk 24 -> 12 -> 6 -> 3 -> ceil(3/2)=2 -> 1, exercising the +/// odd straddle pair (i1 lands on the implicit padding tail). +#[test] +fn c2_equality_padded_straddle() { + run_equality_case(14, 9000, 7); +} + +/// Half-full table: 4000 of 2^13 -> padded to 4096 = one chunk; analytic +/// padding contribution active from round 0. +#[test] +fn c2_equality_half_table() { + run_equality_case(13, 4000, 11); +} + +/// Full table (no padding at all): the active region equals the iteration +/// domain in every round. +#[test] +fn c2_equality_full_table() { + run_equality_case(13, 1 << 13, 23); +} diff --git a/crates/sub_protocols/tests/degree_census.rs b/crates/sub_protocols/tests/degree_census.rs new file mode 100644 index 000000000..ada2e8075 --- /dev/null +++ b/crates/sub_protocols/tests/degree_census.rs @@ -0,0 +1,106 @@ +//! Degree census: verifies that `degree_z() = 7` is correct for Poseidon8 +//! (the constraint composition has true degree ≤ 7, so z=8 eval is redundant). + +use backend::*; +use lean_vm::{ + EF, ExtraDataForBuses, F, HALF_DIGEST_LEN, POSEIDON_8_COL_ADDR_LEFT_HI, POSEIDON_8_COL_ADDR_LEFT_LO, + POSEIDON_8_COL_FLAG_OUT4, POSEIDON_8_COL_INPUT_START, POSEIDON_8_COL_MULTIPLICITY, POSEIDON_8_COL_OUT_LO, + POSEIDON_8_COL_ROUND_START, Poseidon8Precompile, compute_poseidon8_witness, fill_trace_poseidon_8, + num_cols_poseidon_8, +}; +use rand::{RngExt, SeedableRng, rngs::StdRng}; + +fn build_real_poseidon_trace(log_n_rows: usize, rng: &mut StdRng) -> Vec> { + let n_rows = 1 << log_n_rows; + let n_cols = num_cols_poseidon_8(); + let mut trace: Vec> = (0..n_cols).map(|_| ArenaVec::filled(F::ZERO, n_rows)).collect(); + for t in trace.iter_mut().skip(POSEIDON_8_COL_INPUT_START).take(WIDTH) { + *t = ArenaVec::from_iter((0..n_rows).map(|_| rng.random())); + } + trace[POSEIDON_8_COL_MULTIPLICITY] = ArenaVec::filled(F::ONE, n_rows); + trace[POSEIDON_8_COL_FLAG_OUT4] = ArenaVec::filled(F::ONE, n_rows); + trace[POSEIDON_8_COL_ADDR_LEFT_LO] = ArenaVec::filled(F::ZERO, n_rows); + trace[POSEIDON_8_COL_ADDR_LEFT_HI] = ArenaVec::filled(F::from_usize(HALF_DIGEST_LEN), n_rows); + #[allow(clippy::needless_range_loop)] + for row in 0..n_rows { + let input: [F; WIDTH] = std::array::from_fn(|i| trace[POSEIDON_8_COL_INPUT_START + i][row]); + let (aux, perm_state) = compute_poseidon8_witness(input); + for i in 0..WIDTH / 2 { + trace[POSEIDON_8_COL_OUT_LO + i][row] = perm_state[i] + input[i]; + } + for (i, v) in aux.iter().enumerate() { + trace[POSEIDON_8_COL_ROUND_START + i][row] = *v; + } + } + fill_trace_poseidon_8(&mut trace); + trace +} + +#[test] +fn poseidon_fold_line_degree_census() { + let log_n_rows = 8usize; + let n_rows = 1usize << log_n_rows; + let mut rng = StdRng::seed_from_u64(42); + let trace = build_real_poseidon_trace(log_n_rows, &mut rng); + + let air = Poseidon8Precompile::; + assert_eq!(air.degree_air(), 8); + assert_eq!(air.degree_z(), 7); + + let alpha: EF = rng.random(); + let alpha_powers: Vec = alpha.powers().collect_n(air.n_constraints()); + let extra_data = ExtraDataForBuses::new(&[], alpha_powers); + + for trial in 0..32 { + // Random fold pair (i0, i1) — the fold line is `point + z*diff`. + let i0 = (rng.random::() as usize) % n_rows; + let mut i1 = (rng.random::() as usize) % n_rows; + if i1 == i0 { + i1 = (i1 + 1) % n_rows; + } + let lo: Vec = trace.iter().map(|c| EF::from(c[i0])).collect(); + let diff: Vec = trace.iter().map(|c| EF::from(c[i1]) - EF::from(c[i0])).collect(); + + // Evaluate s(z) = C(lo + z*diff) at z = 0..=8 (the OLD point set). + let evals9: Vec<(F, EF)> = (0..=8u64) + .map(|z| { + let zf = EF::from_u64(z); + let point: Vec = lo.iter().zip(&diff).map(|(l, d)| *l + zf * *d).collect(); + ( + F::from_u64(z), + as SumcheckComputation>::eval_extension(&air, &point, &extra_data), + ) + }) + .collect(); + + // (a) Census: the 9-point interpolant's top coefficient is EXACTLY zero. + let poly9 = DensePolynomial::lagrange_interpolation(&evals9).unwrap(); + assert!(poly9.coeffs.len() <= 9, "trial {trial}: unexpected degree blowup"); + if poly9.coeffs.len() == 9 { + assert_eq!( + poly9.coeffs[8], + EF::ZERO, + "trial {trial}: poseidon fold-line composition has a non-zero z^8 coefficient — census violated" + ); + } + + // (b) Old/new bare-poly equality: the degree-≤7 interpolant through the + // NEW point set z = 0..=7 is the SAME polynomial (so the prover's + // message coefficients are bit-identical). + let poly8 = DensePolynomial::lagrange_interpolation(&evals9[..8]).unwrap(); + for k in 0..8 { + let a = poly8.coeffs.get(k).copied().unwrap_or(EF::ZERO); + let b = poly9.coeffs.get(k).copied().unwrap_or(EF::ZERO); + assert_eq!( + a, b, + "trial {trial}: coefficient {k} differs between 8- and 9-point interpolants" + ); + } + // ...and it still passes through the omitted 9th point. + assert_eq!( + poly8.evaluate(EF::from_u64(8)), + evals9[8].1, + "trial {trial}: degree-7 interpolant fails at z=8" + ); + } +}