From 7e4f7afedeffbc49c9c123e21f29e0eb1a9b2227 Mon Sep 17 00:00:00 2001 From: John Guibas Date: Wed, 3 Apr 2024 11:50:55 -0700 Subject: [PATCH] feat: make reduce_u32 + split_u32 circuit friendly (#296) * playground * test * feat: multi field challenger adjustments * cleanup * more cleanup * update comment * clean diff * smaller changes * clean up * cleanup * fix clippy * clippy * sample * clippy * clippy * cleanup * mfc * import order * fix daniel's comments * horner style decomp * feat: make sampling safer * remove trailing spaces * boom --- challenger/src/grinding_challenger.rs | 8 +-- challenger/src/multi_field_challenger.rs | 64 ++++++++++++------------ field/src/helpers.rs | 50 +++++++++--------- fri/src/verifier.rs | 2 +- symmetric/src/sponge.rs | 18 +++---- 5 files changed, 72 insertions(+), 70 deletions(-) diff --git a/challenger/src/grinding_challenger.rs b/challenger/src/grinding_challenger.rs index 5425b44eb..8ac9735ad 100644 --- a/challenger/src/grinding_challenger.rs +++ b/challenger/src/grinding_challenger.rs @@ -1,9 +1,9 @@ -use p3_field::{Field, PrimeField, PrimeField64}; +use p3_field::{Field, PrimeField, PrimeField32, PrimeField64}; use p3_maybe_rayon::prelude::*; use p3_symmetric::CryptographicPermutation; use tracing::instrument; -use crate::{CanObserve, CanSampleBits, DuplexChallenger, MultiFieldChallenger}; +use crate::{CanObserve, CanSampleBits, DuplexChallenger, MultiField32Challenger}; pub trait GrindingChallenger: CanObserve + CanSampleBits + Sync + Clone @@ -38,9 +38,9 @@ where } } -impl GrindingChallenger for MultiFieldChallenger +impl GrindingChallenger for MultiField32Challenger where - F: PrimeField64, + F: PrimeField32, PF: PrimeField, P: CryptographicPermutation<[PF; WIDTH]>, { diff --git a/challenger/src/multi_field_challenger.rs b/challenger/src/multi_field_challenger.rs index 1b8b72738..9197b75c4 100644 --- a/challenger/src/multi_field_challenger.rs +++ b/challenger/src/multi_field_challenger.rs @@ -2,18 +2,21 @@ use alloc::string::String; use alloc::vec; use alloc::vec::Vec; -use p3_field::{reduce_64, split_64, ExtensionField, Field, PrimeField, PrimeField64}; +use p3_field::{reduce_32, split_32, ExtensionField, Field, PrimeField, PrimeField32}; use p3_symmetric::{CryptographicPermutation, Hash}; use crate::{CanObserve, CanSample, CanSampleBits, FieldChallenger}; -/// Given a cryptographic permutation that operates on `[Field; WIDTH]`, produces a challenger -/// that can observe and sample `PrimeField64` elements. Can also observe values with -/// `Hash` type. +/// A challenger that operates natively on PF but produces challenges of F: PrimeField32. +/// +/// Used for optimizing the cost of recursive proof verification of STARKs in SNARKs. +/// +/// SAFETY: There are some bias complications with using this challenger. In particular, +/// samples are actually random in [0, 2^64) and then reduced to be in F. #[derive(Clone, Debug)] -pub struct MultiFieldChallenger +pub struct MultiField32Challenger where - F: PrimeField64, + F: PrimeField32, PF: Field, P: CryptographicPermutation<[PF; WIDTH]>, { @@ -24,9 +27,9 @@ where num_f_elms: usize, } -impl MultiFieldChallenger +impl MultiField32Challenger where - F: PrimeField64, + F: PrimeField32, PF: Field, P: CryptographicPermutation<[PF; WIDTH]>, { @@ -34,7 +37,7 @@ where if F::order() >= PF::order() { return Err(String::from("F::order() must be less than PF::order()")); } - let num_f_elms = PF::bits() / F::bits(); + let num_f_elms = PF::bits() / 64; Ok(Self { sponge_state: [PF::default(); WIDTH], input_buffer: vec![], @@ -45,9 +48,9 @@ where } } -impl MultiFieldChallenger +impl MultiField32Challenger where - F: PrimeField64, + F: PrimeField32, PF: PrimeField, P: CryptographicPermutation<[PF; WIDTH]>, { @@ -55,7 +58,7 @@ where assert!(self.input_buffer.len() <= self.num_f_elms * WIDTH); for (i, f_chunk) in self.input_buffer.chunks(self.num_f_elms).enumerate() { - self.sponge_state[i] = reduce_64(f_chunk); + self.sponge_state[i] = reduce_32(f_chunk); } self.input_buffer.clear(); @@ -64,9 +67,7 @@ where self.output_buffer.clear(); for &pf_val in self.sponge_state.iter() { - let mut f_vals = split_64(pf_val); - f_vals.resize(self.num_f_elms, F::zero()); - + let f_vals = split_32(pf_val, self.num_f_elms); for f_val in f_vals { self.output_buffer.push(f_val); } @@ -74,17 +75,17 @@ where } } -impl FieldChallenger for MultiFieldChallenger +impl FieldChallenger for MultiField32Challenger where - F: PrimeField64, + F: PrimeField32, PF: PrimeField, P: CryptographicPermutation<[PF; WIDTH]>, { } -impl CanObserve for MultiFieldChallenger +impl CanObserve for MultiField32Challenger where - F: PrimeField64, + F: PrimeField32, PF: PrimeField, P: CryptographicPermutation<[PF; WIDTH]>, { @@ -101,9 +102,9 @@ where } impl CanObserve<[F; N]> - for MultiFieldChallenger + for MultiField32Challenger where - F: PrimeField64, + F: PrimeField32, PF: PrimeField, P: CryptographicPermutation<[PF; WIDTH]>, { @@ -115,17 +116,15 @@ where } impl CanObserve> - for MultiFieldChallenger + for MultiField32Challenger where - F: PrimeField64, + F: PrimeField32, PF: PrimeField, P: CryptographicPermutation<[PF; WIDTH]>, { fn observe(&mut self, values: Hash) { for pf_val in values { - let mut f_vals = split_64(pf_val); - f_vals.resize(self.num_f_elms, F::zero()); - + let f_vals: Vec = split_32(pf_val, self.num_f_elms); for f_val in f_vals { self.observe(f_val); } @@ -134,9 +133,10 @@ where } // for TrivialPcs -impl CanObserve>> for MultiFieldChallenger +impl CanObserve>> + for MultiField32Challenger where - F: PrimeField64, + F: PrimeField32, PF: PrimeField, P: CryptographicPermutation<[PF; WIDTH]>, { @@ -149,9 +149,9 @@ where } } -impl CanSample for MultiFieldChallenger +impl CanSample for MultiField32Challenger where - F: PrimeField64, + F: PrimeField32, EF: ExtensionField, PF: PrimeField, P: CryptographicPermutation<[PF; WIDTH]>, @@ -171,9 +171,9 @@ where } } -impl CanSampleBits for MultiFieldChallenger +impl CanSampleBits for MultiField32Challenger where - F: PrimeField64, + F: PrimeField32, PF: PrimeField, P: CryptographicPermutation<[PF; WIDTH]>, { diff --git a/field/src/helpers.rs b/field/src/helpers.rs index 9443821d1..2ad2d902c 100644 --- a/field/src/helpers.rs +++ b/field/src/helpers.rs @@ -2,10 +2,10 @@ use alloc::vec; use alloc::vec::Vec; use core::array; -use num_traits::identities::Zero; +use num_bigint::BigUint; use crate::field::Field; -use crate::{AbstractField, PrimeField, PrimeField64, TwoAdicField}; +use crate::{AbstractField, PrimeField, PrimeField32, TwoAdicField}; /// Computes `Z_H(x)`, where `Z_H` is the zerofier of a multiplicative subgroup of order `2^log_n`. pub fn two_adic_subgroup_zerofier(log_n: usize, x: F) -> F { @@ -130,32 +130,34 @@ pub fn halve_u64(input: u64) -> u64 { } } -/// Given a slice of SF elements, reduce them to a TF element. -pub fn reduce_64(vals: &[SF]) -> TF { - let alpha = TF::from_canonical_u64(SF::ORDER_U64); - - let mut res = TF::zero(); +/// Given a slice of SF elements, reduce them to a TF element using a 2^32-base decomposition. +pub fn reduce_32(vals: &[SF]) -> TF { + let po2 = TF::from_canonical_u64(1u64 << 32); + let mut result = TF::zero(); for val in vals.iter().rev() { - res = res * alpha + TF::from_canonical_u64(val.as_canonical_u64()); + result = result * po2 + TF::from_canonical_u32(val.as_canonical_u32()); } - - res + result } -/// Given a SF elements, split them to a vec of TF elements. -pub fn split_64(val: SF) -> Vec { - let alpha = &SF::from_canonical_u64(TF::ORDER_U64).as_canonical_biguint(); - - let mut res = Vec::new(); +/// Given an SF element, split it to a vector of TF elements using a 2^64-base decomposition. +/// +/// We use a 2^64-base decomposition for a field of size ~2^32 because then the bias will be +/// at most ~1/2^32 for each element after the reduction. +pub fn split_32(val: SF, n: usize) -> Vec { + let po2 = BigUint::from(1u128 << 64); let mut val = val.as_canonical_biguint(); - - while !val.is_zero() { - let rem = &val % alpha; - val /= alpha; - - // Can assume there is one u64 digit since SF is PrimeField64. - res.push(TF::from_canonical_u64(rem.to_u64_digits()[0])); + let mut result = Vec::new(); + for _ in 0..n { + let mask: BigUint = po2.clone() - BigUint::from(1u128); + let digit: BigUint = val.clone() & mask; + let digit_u64s = digit.to_u64_digits(); + if !digit_u64s.is_empty() { + result.push(TF::from_wrapped_u64(digit_u64s[0])); + } else { + result.push(TF::zero()) + } + val /= po2.clone(); } - - res + result } diff --git a/fri/src/verifier.rs b/fri/src/verifier.rs index 3aee9044b..afc770e39 100644 --- a/fri/src/verifier.rs +++ b/fri/src/verifier.rs @@ -21,7 +21,7 @@ pub enum FriError { #[derive(Debug)] pub struct FriChallenges { pub query_indices: Vec, - betas: Vec, + pub betas: Vec, } pub fn verify_shape_and_sample_challenges( diff --git a/symmetric/src/sponge.rs b/symmetric/src/sponge.rs index 161c0b9e8..42af48494 100644 --- a/symmetric/src/sponge.rs +++ b/symmetric/src/sponge.rs @@ -2,7 +2,7 @@ use alloc::string::String; use core::marker::PhantomData; use itertools::Itertools; -use p3_field::{reduce_64, Field, PrimeField, PrimeField64}; +use p3_field::{reduce_32, Field, PrimeField, PrimeField32}; use crate::hasher::CryptographicHasher; use crate::permutation::CryptographicPermutation; @@ -43,12 +43,12 @@ where } } -/// A padding-free, overwrite-mode sponge function. Accepts `PrimeField64` elements and has a permutation -/// using a different `Field` type. +/// A padding-free, overwrite-mode sponge function that operates natively over PF but accepts elements +/// of F: PrimeField32. /// /// `WIDTH` is the sponge's rate plus the sponge's capacity. #[derive(Clone, Debug)] -pub struct PaddingFreeSpongeMultiField< +pub struct MultiField32PaddingFreeSponge< F, PF, P, @@ -62,9 +62,9 @@ pub struct PaddingFreeSpongeMultiField< } impl - PaddingFreeSpongeMultiField + MultiField32PaddingFreeSponge where - F: PrimeField64, + F: PrimeField32, PF: Field, { pub fn new(permutation: P) -> Result { @@ -82,9 +82,9 @@ where } impl - CryptographicHasher for PaddingFreeSpongeMultiField + CryptographicHasher for MultiField32PaddingFreeSponge where - F: PrimeField64, + F: PrimeField32, PF: PrimeField + Default + Copy, P: CryptographicPermutation<[PF; WIDTH]>, { @@ -98,7 +98,7 @@ where .into_iter() .enumerate() { - state[chunk_id] = reduce_64(&chunk.collect_vec()); + state[chunk_id] = reduce_32(&chunk.collect_vec()); } state = self.permutation.permute(state); }