Skip to content

Commit

Permalink
feat: make reduce_u32 + split_u32 circuit friendly (Plonky3#296)
Browse files Browse the repository at this point in the history
* playground

* test

* feat: multi field challenger adjustments

* cleanup

* more cleanup

* update comment

* clean diff

* smaller changes

* clean up

* cleanup

* fix clippy

* clippy

* sample

* clippy

* clippy

* cleanup

* mfc

* import order

* fix daniel's comments

* horner style decomp

* feat: make sampling safer

* remove trailing spaces

* boom
  • Loading branch information
jtguibas committed Apr 3, 2024
1 parent 017aee5 commit 7e4f7af
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 70 deletions.
8 changes: 4 additions & 4 deletions challenger/src/grinding_challenger.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use p3_field::{Field, PrimeField, PrimeField64};
use p3_field::{Field, PrimeField, PrimeField32, PrimeField64};
use p3_maybe_rayon::prelude::*;
use p3_symmetric::CryptographicPermutation;
use tracing::instrument;

use crate::{CanObserve, CanSampleBits, DuplexChallenger, MultiFieldChallenger};
use crate::{CanObserve, CanSampleBits, DuplexChallenger, MultiField32Challenger};

pub trait GrindingChallenger:
CanObserve<Self::Witness> + CanSampleBits<usize> + Sync + Clone
Expand Down Expand Up @@ -38,9 +38,9 @@ where
}
}

impl<F, PF, P, const WIDTH: usize> GrindingChallenger for MultiFieldChallenger<F, PF, P, WIDTH>
impl<F, PF, P, const WIDTH: usize> GrindingChallenger for MultiField32Challenger<F, PF, P, WIDTH>
where
F: PrimeField64,
F: PrimeField32,
PF: PrimeField,
P: CryptographicPermutation<[PF; WIDTH]>,
{
Expand Down
64 changes: 32 additions & 32 deletions challenger/src/multi_field_challenger.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,21 @@ use alloc::string::String;
use alloc::vec;
use alloc::vec::Vec;

use p3_field::{reduce_64, split_64, ExtensionField, Field, PrimeField, PrimeField64};
use p3_field::{reduce_32, split_32, ExtensionField, Field, PrimeField, PrimeField32};
use p3_symmetric::{CryptographicPermutation, Hash};

use crate::{CanObserve, CanSample, CanSampleBits, FieldChallenger};

/// Given a cryptographic permutation that operates on `[Field; WIDTH]`, produces a challenger
/// that can observe and sample `PrimeField64` elements. Can also observe values with
/// `Hash<F, PF, N>` type.
/// A challenger that operates natively on PF but produces challenges of F: PrimeField32.
///
/// Used for optimizing the cost of recursive proof verification of STARKs in SNARKs.
///
/// SAFETY: There are some bias complications with using this challenger. In particular,
/// samples are actually random in [0, 2^64) and then reduced to be in F.
#[derive(Clone, Debug)]
pub struct MultiFieldChallenger<F, PF, P, const WIDTH: usize>
pub struct MultiField32Challenger<F, PF, P, const WIDTH: usize>
where
F: PrimeField64,
F: PrimeField32,
PF: Field,
P: CryptographicPermutation<[PF; WIDTH]>,
{
Expand All @@ -24,17 +27,17 @@ where
num_f_elms: usize,
}

impl<F, PF, P, const WIDTH: usize> MultiFieldChallenger<F, PF, P, WIDTH>
impl<F, PF, P, const WIDTH: usize> MultiField32Challenger<F, PF, P, WIDTH>
where
F: PrimeField64,
F: PrimeField32,
PF: Field,
P: CryptographicPermutation<[PF; WIDTH]>,
{
pub fn new(permutation: P) -> Result<Self, String> {
if F::order() >= PF::order() {
return Err(String::from("F::order() must be less than PF::order()"));
}
let num_f_elms = PF::bits() / F::bits();
let num_f_elms = PF::bits() / 64;
Ok(Self {
sponge_state: [PF::default(); WIDTH],
input_buffer: vec![],
Expand All @@ -45,17 +48,17 @@ where
}
}

impl<F, PF, P, const WIDTH: usize> MultiFieldChallenger<F, PF, P, WIDTH>
impl<F, PF, P, const WIDTH: usize> MultiField32Challenger<F, PF, P, WIDTH>
where
F: PrimeField64,
F: PrimeField32,
PF: PrimeField,
P: CryptographicPermutation<[PF; WIDTH]>,
{
fn duplexing(&mut self) {
assert!(self.input_buffer.len() <= self.num_f_elms * WIDTH);

for (i, f_chunk) in self.input_buffer.chunks(self.num_f_elms).enumerate() {
self.sponge_state[i] = reduce_64(f_chunk);
self.sponge_state[i] = reduce_32(f_chunk);
}
self.input_buffer.clear();

Expand All @@ -64,27 +67,25 @@ where

self.output_buffer.clear();
for &pf_val in self.sponge_state.iter() {
let mut f_vals = split_64(pf_val);
f_vals.resize(self.num_f_elms, F::zero());

let f_vals = split_32(pf_val, self.num_f_elms);
for f_val in f_vals {
self.output_buffer.push(f_val);
}
}
}
}

impl<F, PF, P, const WIDTH: usize> FieldChallenger<F> for MultiFieldChallenger<F, PF, P, WIDTH>
impl<F, PF, P, const WIDTH: usize> FieldChallenger<F> for MultiField32Challenger<F, PF, P, WIDTH>
where
F: PrimeField64,
F: PrimeField32,
PF: PrimeField,
P: CryptographicPermutation<[PF; WIDTH]>,
{
}

impl<F, PF, P, const WIDTH: usize> CanObserve<F> for MultiFieldChallenger<F, PF, P, WIDTH>
impl<F, PF, P, const WIDTH: usize> CanObserve<F> for MultiField32Challenger<F, PF, P, WIDTH>
where
F: PrimeField64,
F: PrimeField32,
PF: PrimeField,
P: CryptographicPermutation<[PF; WIDTH]>,
{
Expand All @@ -101,9 +102,9 @@ where
}

impl<F, PF, const N: usize, P, const WIDTH: usize> CanObserve<[F; N]>
for MultiFieldChallenger<F, PF, P, WIDTH>
for MultiField32Challenger<F, PF, P, WIDTH>
where
F: PrimeField64,
F: PrimeField32,
PF: PrimeField,
P: CryptographicPermutation<[PF; WIDTH]>,
{
Expand All @@ -115,17 +116,15 @@ where
}

impl<F, PF, const N: usize, P, const WIDTH: usize> CanObserve<Hash<F, PF, N>>
for MultiFieldChallenger<F, PF, P, WIDTH>
for MultiField32Challenger<F, PF, P, WIDTH>
where
F: PrimeField64,
F: PrimeField32,
PF: PrimeField,
P: CryptographicPermutation<[PF; WIDTH]>,
{
fn observe(&mut self, values: Hash<F, PF, N>) {
for pf_val in values {
let mut f_vals = split_64(pf_val);
f_vals.resize(self.num_f_elms, F::zero());

let f_vals: Vec<F> = split_32(pf_val, self.num_f_elms);
for f_val in f_vals {
self.observe(f_val);
}
Expand All @@ -134,9 +133,10 @@ where
}

// for TrivialPcs
impl<F, PF, P, const WIDTH: usize> CanObserve<Vec<Vec<F>>> for MultiFieldChallenger<F, PF, P, WIDTH>
impl<F, PF, P, const WIDTH: usize> CanObserve<Vec<Vec<F>>>
for MultiField32Challenger<F, PF, P, WIDTH>
where
F: PrimeField64,
F: PrimeField32,
PF: PrimeField,
P: CryptographicPermutation<[PF; WIDTH]>,
{
Expand All @@ -149,9 +149,9 @@ where
}
}

impl<F, EF, PF, P, const WIDTH: usize> CanSample<EF> for MultiFieldChallenger<F, PF, P, WIDTH>
impl<F, EF, PF, P, const WIDTH: usize> CanSample<EF> for MultiField32Challenger<F, PF, P, WIDTH>
where
F: PrimeField64,
F: PrimeField32,
EF: ExtensionField<F>,
PF: PrimeField,
P: CryptographicPermutation<[PF; WIDTH]>,
Expand All @@ -171,9 +171,9 @@ where
}
}

impl<F, PF, P, const WIDTH: usize> CanSampleBits<usize> for MultiFieldChallenger<F, PF, P, WIDTH>
impl<F, PF, P, const WIDTH: usize> CanSampleBits<usize> for MultiField32Challenger<F, PF, P, WIDTH>
where
F: PrimeField64,
F: PrimeField32,
PF: PrimeField,
P: CryptographicPermutation<[PF; WIDTH]>,
{
Expand Down
50 changes: 26 additions & 24 deletions field/src/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ use alloc::vec;
use alloc::vec::Vec;
use core::array;

use num_traits::identities::Zero;
use num_bigint::BigUint;

use crate::field::Field;
use crate::{AbstractField, PrimeField, PrimeField64, TwoAdicField};
use crate::{AbstractField, PrimeField, PrimeField32, TwoAdicField};

/// Computes `Z_H(x)`, where `Z_H` is the zerofier of a multiplicative subgroup of order `2^log_n`.
pub fn two_adic_subgroup_zerofier<F: TwoAdicField>(log_n: usize, x: F) -> F {
Expand Down Expand Up @@ -130,32 +130,34 @@ pub fn halve_u64<const P: u64>(input: u64) -> u64 {
}
}

/// Given a slice of SF elements, reduce them to a TF element.
pub fn reduce_64<SF: PrimeField64, TF: PrimeField>(vals: &[SF]) -> TF {
let alpha = TF::from_canonical_u64(SF::ORDER_U64);

let mut res = TF::zero();
/// Given a slice of SF elements, reduce them to a TF element using a 2^32-base decomposition.
pub fn reduce_32<SF: PrimeField32, TF: PrimeField>(vals: &[SF]) -> TF {
let po2 = TF::from_canonical_u64(1u64 << 32);
let mut result = TF::zero();
for val in vals.iter().rev() {
res = res * alpha + TF::from_canonical_u64(val.as_canonical_u64());
result = result * po2 + TF::from_canonical_u32(val.as_canonical_u32());
}

res
result
}

/// Given a SF elements, split them to a vec of TF elements.
pub fn split_64<SF: PrimeField, TF: PrimeField64>(val: SF) -> Vec<TF> {
let alpha = &SF::from_canonical_u64(TF::ORDER_U64).as_canonical_biguint();

let mut res = Vec::new();
/// Given an SF element, split it to a vector of TF elements using a 2^64-base decomposition.
///
/// We use a 2^64-base decomposition for a field of size ~2^32 because then the bias will be
/// at most ~1/2^32 for each element after the reduction.
pub fn split_32<SF: PrimeField, TF: PrimeField32>(val: SF, n: usize) -> Vec<TF> {
let po2 = BigUint::from(1u128 << 64);
let mut val = val.as_canonical_biguint();

while !val.is_zero() {
let rem = &val % alpha;
val /= alpha;

// Can assume there is one u64 digit since SF is PrimeField64.
res.push(TF::from_canonical_u64(rem.to_u64_digits()[0]));
let mut result = Vec::new();
for _ in 0..n {
let mask: BigUint = po2.clone() - BigUint::from(1u128);
let digit: BigUint = val.clone() & mask;
let digit_u64s = digit.to_u64_digits();
if !digit_u64s.is_empty() {
result.push(TF::from_wrapped_u64(digit_u64s[0]));
} else {
result.push(TF::zero())
}
val /= po2.clone();
}

res
result
}
2 changes: 1 addition & 1 deletion fri/src/verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ pub enum FriError<CommitMmcsErr> {
#[derive(Debug)]
pub struct FriChallenges<F> {
pub query_indices: Vec<usize>,
betas: Vec<F>,
pub betas: Vec<F>,
}

pub fn verify_shape_and_sample_challenges<F, M, Challenger>(
Expand Down
18 changes: 9 additions & 9 deletions symmetric/src/sponge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use alloc::string::String;
use core::marker::PhantomData;

use itertools::Itertools;
use p3_field::{reduce_64, Field, PrimeField, PrimeField64};
use p3_field::{reduce_32, Field, PrimeField, PrimeField32};

use crate::hasher::CryptographicHasher;
use crate::permutation::CryptographicPermutation;
Expand Down Expand Up @@ -43,12 +43,12 @@ where
}
}

/// A padding-free, overwrite-mode sponge function. Accepts `PrimeField64` elements and has a permutation
/// using a different `Field` type.
/// A padding-free, overwrite-mode sponge function that operates natively over PF but accepts elements
/// of F: PrimeField32.
///
/// `WIDTH` is the sponge's rate plus the sponge's capacity.
#[derive(Clone, Debug)]
pub struct PaddingFreeSpongeMultiField<
pub struct MultiField32PaddingFreeSponge<
F,
PF,
P,
Expand All @@ -62,9 +62,9 @@ pub struct PaddingFreeSpongeMultiField<
}

impl<F, PF, P, const WIDTH: usize, const RATE: usize, const OUT: usize>
PaddingFreeSpongeMultiField<F, PF, P, WIDTH, RATE, OUT>
MultiField32PaddingFreeSponge<F, PF, P, WIDTH, RATE, OUT>
where
F: PrimeField64,
F: PrimeField32,
PF: Field,
{
pub fn new(permutation: P) -> Result<Self, String> {
Expand All @@ -82,9 +82,9 @@ where
}

impl<F, PF, P, const WIDTH: usize, const RATE: usize, const OUT: usize>
CryptographicHasher<F, [PF; OUT]> for PaddingFreeSpongeMultiField<F, PF, P, WIDTH, RATE, OUT>
CryptographicHasher<F, [PF; OUT]> for MultiField32PaddingFreeSponge<F, PF, P, WIDTH, RATE, OUT>
where
F: PrimeField64,
F: PrimeField32,
PF: PrimeField + Default + Copy,
P: CryptographicPermutation<[PF; WIDTH]>,
{
Expand All @@ -98,7 +98,7 @@ where
.into_iter()
.enumerate()
{
state[chunk_id] = reduce_64(&chunk.collect_vec());
state[chunk_id] = reduce_32(&chunk.collect_vec());
}
state = self.permutation.permute(state);
}
Expand Down

0 comments on commit 7e4f7af

Please sign in to comment.