Skip to content

Commit

Permalink
Default impl of permute rather than permute_mut (Plonky3#129)
Browse files Browse the repository at this point in the history
To encourage direct impls of the latter
  • Loading branch information
dlubarov committed Sep 27, 2023
1 parent 81bb4bc commit 3422c89
Show file tree
Hide file tree
Showing 11 changed files with 132 additions and 72 deletions.
9 changes: 6 additions & 3 deletions keccak/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@ use tiny_keccak::{keccakf, Hasher, Keccak};
pub struct KeccakF;

impl Permutation<[u64; 25]> for KeccakF {
fn permute(&self, mut input: [u64; 25]) -> [u64; 25] {
keccakf(&mut input);
input
fn permute_mut(&self, input: &mut [u64; 25]) {
keccakf(input);
}
}

Expand All @@ -36,6 +35,10 @@ impl Permutation<[u8; 200]> for KeccakF {
u64_limb.to_le_bytes()[i % 8]
})
}

fn permute_mut(&self, input: &mut [u8; 200]) {
*input = self.permute(*input);
}
}

impl CryptographicPermutation<[u8; 200]> for KeccakF {}
Expand Down
24 changes: 24 additions & 0 deletions mds/src/babybear.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,21 @@ impl Permutation<[BabyBear; 8]> for MdsMatrixBabyBear {
fn permute(&self, input: [BabyBear; 8]) -> [BabyBear; 8] {
apply_circulant_8_sml(input)
}

fn permute_mut(&self, input: &mut [BabyBear; 8]) {
*input = self.permute(*input);
}
}
impl MdsPermutation<BabyBear, 8> for MdsMatrixBabyBear {}

impl Permutation<[BabyBear; 12]> for MdsMatrixBabyBear {
fn permute(&self, input: [BabyBear; 12]) -> [BabyBear; 12] {
apply_circulant_12_sml(input)
}

fn permute_mut(&self, input: &mut [BabyBear; 12]) {
*input = self.permute(*input);
}
}
impl MdsPermutation<BabyBear, 12> for MdsMatrixBabyBear {}

Expand All @@ -46,6 +54,10 @@ impl Permutation<[BabyBear; 16]> for MdsMatrixBabyBear {
const ENTRIES: [u64; 16] = first_row_to_first_col(&MATRIX_CIRC_MDS_16_BABYBEAR);
apply_circulant_fft(FFT_ALGO, ENTRIES, &input)
}

fn permute_mut(&self, input: &mut [BabyBear; 16]) {
*input = self.permute(*input);
}
}
impl MdsPermutation<BabyBear, 16> for MdsMatrixBabyBear {}

Expand All @@ -63,6 +75,10 @@ impl Permutation<[BabyBear; 24]> for MdsMatrixBabyBear {
fn permute(&self, input: [BabyBear; 24]) -> [BabyBear; 24] {
apply_circulant(&MATRIX_CIRC_MDS_24_BABYBEAR, input)
}

fn permute_mut(&self, input: &mut [BabyBear; 24]) {
*input = self.permute(*input);
}
}
impl MdsPermutation<BabyBear, 24> for MdsMatrixBabyBear {}

Expand All @@ -83,6 +99,10 @@ impl Permutation<[BabyBear; 32]> for MdsMatrixBabyBear {
const ENTRIES: [u64; 32] = first_row_to_first_col(&MATRIX_CIRC_MDS_32_BABYBEAR);
apply_circulant_fft(FFT_ALGO, ENTRIES, &input)
}

fn permute_mut(&self, input: &mut [BabyBear; 32]) {
*input = self.permute(*input);
}
}
impl MdsPermutation<BabyBear, 32> for MdsMatrixBabyBear {}

Expand Down Expand Up @@ -111,6 +131,10 @@ impl Permutation<[BabyBear; 64]> for MdsMatrixBabyBear {
const ENTRIES: [u64; 64] = first_row_to_first_col(&MATRIX_CIRC_MDS_64_BABYBEAR);
apply_circulant_fft(FFT_ALGO, ENTRIES, &input)
}

fn permute_mut(&self, input: &mut [BabyBear; 64]) {
*input = self.permute(*input);
}
}
impl MdsPermutation<BabyBear, 64> for MdsMatrixBabyBear {}

Expand Down
28 changes: 28 additions & 0 deletions mds/src/goldilocks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,21 @@ impl Permutation<[Goldilocks; 8]> for MdsMatrixGoldilocks {
fn permute(&self, input: [Goldilocks; 8]) -> [Goldilocks; 8] {
apply_circulant_8_sml(input)
}

fn permute_mut(&self, input: &mut [Goldilocks; 8]) {
*input = self.permute(*input);
}
}
impl MdsPermutation<Goldilocks, 8> for MdsMatrixGoldilocks {}

impl Permutation<[Goldilocks; 12]> for MdsMatrixGoldilocks {
fn permute(&self, input: [Goldilocks; 12]) -> [Goldilocks; 12] {
apply_circulant_12_sml(input)
}

fn permute_mut(&self, input: &mut [Goldilocks; 12]) {
*input = self.permute(*input);
}
}
impl MdsPermutation<Goldilocks, 12> for MdsMatrixGoldilocks {}

Expand All @@ -46,6 +54,10 @@ impl Permutation<[Goldilocks; 16]> for MdsMatrixGoldilocks {
const ENTRIES: [u64; 16] = first_row_to_first_col(&MATRIX_CIRC_MDS_16_GOLDILOCKS);
apply_circulant_fft(FFT_ALGO, ENTRIES, &input)
}

fn permute_mut(&self, input: &mut [Goldilocks; 16]) {
*input = self.permute(*input);
}
}
impl MdsPermutation<Goldilocks, 16> for MdsMatrixGoldilocks {}

Expand All @@ -63,6 +75,10 @@ impl Permutation<[Goldilocks; 24]> for MdsMatrixGoldilocks {
fn permute(&self, input: [Goldilocks; 24]) -> [Goldilocks; 24] {
apply_circulant(&MATRIX_CIRC_MDS_24_GOLDILOCKS, input)
}

fn permute_mut(&self, input: &mut [Goldilocks; 24]) {
*input = self.permute(*input);
}
}
impl MdsPermutation<Goldilocks, 24> for MdsMatrixGoldilocks {}

Expand All @@ -83,6 +99,10 @@ impl Permutation<[Goldilocks; 32]> for MdsMatrixGoldilocks {
const ENTRIES: [u64; 32] = first_row_to_first_col(&MATRIX_CIRC_MDS_32_GOLDILOCKS);
apply_circulant_fft(FFT_ALGO, ENTRIES, &input)
}

fn permute_mut(&self, input: &mut [Goldilocks; 32]) {
*input = self.permute(*input);
}
}
impl MdsPermutation<Goldilocks, 32> for MdsMatrixGoldilocks {}

Expand Down Expand Up @@ -111,6 +131,10 @@ impl Permutation<[Goldilocks; 64]> for MdsMatrixGoldilocks {
const ENTRIES: [u64; 64] = first_row_to_first_col(&MATRIX_CIRC_MDS_64_GOLDILOCKS);
apply_circulant_fft(FFT_ALGO, ENTRIES, &input)
}

fn permute_mut(&self, input: &mut [Goldilocks; 64]) {
*input = self.permute(*input);
}
}
impl MdsPermutation<Goldilocks, 64> for MdsMatrixGoldilocks {}

Expand Down Expand Up @@ -139,6 +163,10 @@ impl Permutation<[Goldilocks; 68]> for MdsMatrixGoldilocks {
fn permute(&self, input: [Goldilocks; 68]) -> [Goldilocks; 68] {
apply_circulant(&MATRIX_CIRC_MDS_68_GOLDILOCKS, input)
}

fn permute_mut(&self, input: &mut [Goldilocks; 68]) {
*input = self.permute(*input);
}
}
impl MdsPermutation<Goldilocks, 68> for MdsMatrixGoldilocks {}

Expand Down
20 changes: 20 additions & 0 deletions mds/src/mersenne31.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,21 @@ impl Permutation<[Mersenne31; 8]> for MdsMatrixMersenne31 {
fn permute(&self, input: [Mersenne31; 8]) -> [Mersenne31; 8] {
apply_circulant_8_sml(input)
}

fn permute_mut(&self, input: &mut [Mersenne31; 8]) {
*input = self.permute(*input);
}
}
impl MdsPermutation<Mersenne31, 8> for MdsMatrixMersenne31 {}

impl Permutation<[Mersenne31; 12]> for MdsMatrixMersenne31 {
fn permute(&self, input: [Mersenne31; 12]) -> [Mersenne31; 12] {
apply_circulant_12_sml(input)
}

fn permute_mut(&self, input: &mut [Mersenne31; 12]) {
*input = self.permute(*input);
}
}
impl MdsPermutation<Mersenne31, 12> for MdsMatrixMersenne31 {}

Expand All @@ -39,6 +47,10 @@ impl Permutation<[Mersenne31; 16]> for MdsMatrixMersenne31 {
fn permute(&self, input: [Mersenne31; 16]) -> [Mersenne31; 16] {
apply_circulant(&MATRIX_CIRC_MDS_16_MERSENNE31, input)
}

fn permute_mut(&self, input: &mut [Mersenne31; 16]) {
*input = self.permute(*input);
}
}
impl MdsPermutation<Mersenne31, 16> for MdsMatrixMersenne31 {}

Expand All @@ -58,6 +70,10 @@ impl Permutation<[Mersenne31; 32]> for MdsMatrixMersenne31 {
fn permute(&self, input: [Mersenne31; 32]) -> [Mersenne31; 32] {
apply_circulant(&MATRIX_CIRC_MDS_32_MERSENNE31, input)
}

fn permute_mut(&self, input: &mut [Mersenne31; 32]) {
*input = self.permute(*input);
}
}
impl MdsPermutation<Mersenne31, 32> for MdsMatrixMersenne31 {}

Expand Down Expand Up @@ -85,6 +101,10 @@ impl Permutation<[Mersenne31; 64]> for MdsMatrixMersenne31 {
fn permute(&self, input: [Mersenne31; 64]) -> [Mersenne31; 64] {
apply_circulant(&MATRIX_CIRC_MDS_64_MERSENNE31, input)
}

fn permute_mut(&self, input: &mut [Mersenne31; 64]) {
*input = self.permute(*input);
}
}
impl MdsPermutation<Mersenne31, 64> for MdsMatrixMersenne31 {}

Expand Down
4 changes: 4 additions & 0 deletions monolith/src/monolith_mds.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ impl<const WIDTH: usize, const NUM_ROUNDS: usize> Permutation<[Mersenne31; WIDTH
apply_cauchy_mds_matrix(&mut shake_finalized, input)
}
}

fn permute_mut(&self, input: &mut [Mersenne31; WIDTH]) {
*input = self.permute(*input);
}
}

impl<const WIDTH: usize, const NUM_ROUNDS: usize> MdsPermutation<Mersenne31, WIDTH>
Expand Down
9 changes: 4 additions & 5 deletions poseidon/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,12 +116,11 @@ where
F::F: PrimeField,
Mds: MdsPermutation<F, WIDTH>,
{
fn permute(&self, mut state: [F; WIDTH]) -> [F; WIDTH] {
fn permute_mut(&self, state: &mut [F; WIDTH]) {
let mut round_ctr = 0;
self.half_full_rounds(&mut state, &mut round_ctr);
self.partial_rounds(&mut state, &mut round_ctr);
self.half_full_rounds(&mut state, &mut round_ctr);
state
self.half_full_rounds(state, &mut round_ctr);
self.partial_rounds(state, &mut round_ctr);
self.half_full_rounds(state, &mut round_ctr);
}
}

Expand Down
14 changes: 6 additions & 8 deletions poseidon2/src/babybear.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,17 @@ pub const MATRIX_DIAG_24_BABYBEAR: [u64; 24] = [
pub struct DiffusionMatrixBabybear;

impl Permutation<[BabyBear; 16]> for DiffusionMatrixBabybear {
fn permute(&self, input: [BabyBear; 16]) -> [BabyBear; 16] {
let mut input = input;
matmul_internal::<BabyBear, 16>(&mut input, MATRIX_DIAG_16_BABYBEAR);
input
fn permute_mut(&self, state: &mut [BabyBear; 16]) {
matmul_internal::<BabyBear, 16>(state, MATRIX_DIAG_16_BABYBEAR);
}
}

impl DiffusionPermutation<BabyBear, 16> for DiffusionMatrixBabybear {}

impl Permutation<[BabyBear; 24]> for DiffusionMatrixBabybear {
fn permute(&self, input: [BabyBear; 24]) -> [BabyBear; 24] {
let mut input = input;
matmul_internal::<BabyBear, 24>(&mut input, MATRIX_DIAG_24_BABYBEAR);
input
fn permute_mut(&self, state: &mut [BabyBear; 24]) {
matmul_internal::<BabyBear, 24>(state, MATRIX_DIAG_24_BABYBEAR);
}
}

impl DiffusionPermutation<BabyBear, 24> for DiffusionMatrixBabybear {}
28 changes: 12 additions & 16 deletions poseidon2/src/goldilocks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,37 +80,33 @@ pub const MATRIX_DIAG_20_GOLDILOCKS: [u64; 20] = [
pub struct DiffusionMatrixGoldilocks;

impl Permutation<[Goldilocks; 8]> for DiffusionMatrixGoldilocks {
fn permute(&self, input: [Goldilocks; 8]) -> [Goldilocks; 8] {
let mut input = input;
matmul_internal::<Goldilocks, 8>(&mut input, MATRIX_DIAG_8_GOLDILOCKS);
input
fn permute_mut(&self, state: &mut [Goldilocks; 8]) {
matmul_internal::<Goldilocks, 8>(state, MATRIX_DIAG_8_GOLDILOCKS);
}
}

impl DiffusionPermutation<Goldilocks, 8> for DiffusionMatrixGoldilocks {}

impl Permutation<[Goldilocks; 12]> for DiffusionMatrixGoldilocks {
fn permute(&self, input: [Goldilocks; 12]) -> [Goldilocks; 12] {
let mut input = input;
matmul_internal::<Goldilocks, 12>(&mut input, MATRIX_DIAG_12_GOLDILOCKS);
input
fn permute_mut(&self, state: &mut [Goldilocks; 12]) {
matmul_internal::<Goldilocks, 12>(state, MATRIX_DIAG_12_GOLDILOCKS);
}
}

impl DiffusionPermutation<Goldilocks, 12> for DiffusionMatrixGoldilocks {}

impl Permutation<[Goldilocks; 16]> for DiffusionMatrixGoldilocks {
fn permute(&self, input: [Goldilocks; 16]) -> [Goldilocks; 16] {
let mut input = input;
matmul_internal::<Goldilocks, 16>(&mut input, MATRIX_DIAG_16_GOLDILOCKS);
input
fn permute_mut(&self, state: &mut [Goldilocks; 16]) {
matmul_internal::<Goldilocks, 16>(state, MATRIX_DIAG_16_GOLDILOCKS);
}
}

impl DiffusionPermutation<Goldilocks, 16> for DiffusionMatrixGoldilocks {}

impl Permutation<[Goldilocks; 20]> for DiffusionMatrixGoldilocks {
fn permute(&self, input: [Goldilocks; 20]) -> [Goldilocks; 20] {
let mut input = input;
matmul_internal::<Goldilocks, 20>(&mut input, MATRIX_DIAG_20_GOLDILOCKS);
input
fn permute_mut(&self, state: &mut [Goldilocks; 20]) {
matmul_internal::<Goldilocks, 20>(state, MATRIX_DIAG_20_GOLDILOCKS);
}
}

impl DiffusionPermutation<Goldilocks, 20> for DiffusionMatrixGoldilocks {}
27 changes: 10 additions & 17 deletions poseidon2/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,7 @@ where
let mut constants = Vec::new();
let rounds = rounds_f + rounds_p;
for _ in 0..rounds {
let mut round_constant = [F::F::ZERO; WIDTH];
#[allow(clippy::needless_range_loop)]
for j in 0..WIDTH {
round_constant[j] = rng.sample(Standard);
}
constants.push(round_constant);
constants.push(rng.gen::<[F::F; WIDTH]>());
}

Self {
Expand Down Expand Up @@ -127,35 +122,33 @@ where
Mds: MdsPermutation<F, WIDTH>,
Diffusion: DiffusionPermutation<F, WIDTH>,
{
fn permute(&self, mut state: [F; WIDTH]) -> [F; WIDTH] {
fn permute_mut(&self, state: &mut [F; WIDTH]) {
// The initial linear layer.
self.external_linear_layer.permute_mut(&mut state);
self.external_linear_layer.permute_mut(state);

// The first half of the external rounds.
let rounds = self.rounds_f + self.rounds_p;
let rounds_f_beggining = self.rounds_f / 2;
for r in 0..rounds_f_beggining {
self.add_rc(&mut state, &self.constants[r]);
self.sbox(&mut state);
self.external_linear_layer.permute_mut(&mut state);
self.add_rc(state, &self.constants[r]);
self.sbox(state);
self.external_linear_layer.permute_mut(state);
}

// The internal rounds.
let p_end = rounds_f_beggining + self.rounds_p;
for r in self.rounds_f..p_end {
state[0] += self.constants[r][0];
state[0] = self.sbox_p(&state[0]);
self.internal_linear_layer.permute_mut(&mut state);
self.internal_linear_layer.permute_mut(state);
}

// The second half of the external rounds.
for r in p_end..rounds {
self.add_rc(&mut state, &self.constants[r]);
self.sbox(&mut state);
self.external_linear_layer.permute_mut(&mut state);
self.add_rc(state, &self.constants[r]);
self.sbox(state);
self.external_linear_layer.permute_mut(state);
}

state
}
}

Expand Down
Loading

0 comments on commit 3422c89

Please sign in to comment.