Skip to content

Commit

Permalink
Some simplifications of vartime division (#661)
Browse files Browse the repository at this point in the history
Second attempt at #646
  • Loading branch information
fjarri authored Aug 21, 2024
1 parent 325e16b commit 52fee04
Show file tree
Hide file tree
Showing 3 changed files with 250 additions and 201 deletions.
123 changes: 78 additions & 45 deletions src/uint/boxed/div.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
//! [`BoxedUint`] division operations.

use crate::{
uint::{boxed, div_limb::div3by2},
uint::{
boxed,
div_limb::{div2by1, div3by2},
},
BoxedUint, CheckedDiv, ConstChoice, ConstantTimeSelect, DivRemLimb, Limb, NonZero, Reciprocal,
RemLimb, Word, Wrapping,
RemLimb, Wrapping,
};
use core::ops::{Div, DivAssign, Rem, RemAssign};
use subtle::CtOption;
Expand Down Expand Up @@ -166,7 +169,7 @@ impl BoxedUint {

while xi > 0 {
// Divide high dividend words by the high divisor word to estimate the quotient word
let (mut quo, _) = div3by2(x_hi.0, x_lo.0, x[xi - 1].0, &reciprocal, y[size - 2].0);
let mut quo = div3by2(x_hi.0, x_lo.0, x[xi - 1].0, &reciprocal, y[size - 2].0);

// This loop is a no-op once xi is smaller than the number of words in the divisor
let done = ConstChoice::from_u32_lt(xi as u32, dwords - 1);
Expand Down Expand Up @@ -206,14 +209,20 @@ impl BoxedUint {
}

let limb_div = ConstChoice::from_u32_eq(1, dwords);

// Calculate quotient and remainder for the case where the divisor is a single word
let (quo2, rem2) = div3by2(x_hi.0, x_lo.0, 0, &reciprocal, 0);
// Note that `div2by1()` will panic if `x_hi >= reciprocal.divisor_normalized`,
// but this can only be the case if `limb_div` is falsy,
// in which case we discard the result anyway,
// so we conditionally set `x_hi` to zero for this branch.
let x_hi_adjusted = Limb::select(Limb::ZERO, x_hi, limb_div);
let (quo2, rem2) = div2by1(x_hi_adjusted.0, x_lo.0, &reciprocal);

// Adjust the quotient for single limb division
x[0] = Limb::select(x[0], Limb(quo2), limb_div);

// Copy out the remainder
y[0] = Limb::select(x[0], Limb(rem2 as Word), limb_div);
y[0] = Limb::select(x[0], Limb(rem2), limb_div);
i = 1;
while i < size {
y[i] = Limb::select(Limb::ZERO, x[i], ConstChoice::from_u32_lt(i as u32, dwords));
Expand Down Expand Up @@ -382,6 +391,42 @@ impl RemLimb for BoxedUint {
}
}

/// Computes `limbs << shift` inplace, where `0 <= shift < Limb::BITS`, returning the carry.
fn shl_limb_vartime(limbs: &mut [Limb], shift: u32) -> Limb {
if shift == 0 {
return Limb::ZERO;
}

let lshift = shift;
let rshift = Limb::BITS - shift;
let limbs_num = limbs.len();

let carry = limbs[limbs_num - 1] >> rshift;
for i in (1..limbs_num).rev() {
limbs[i] = (limbs[i] << lshift) | (limbs[i - 1] >> rshift);
}
limbs[0] <<= lshift;

carry
}

/// Computes `limbs >> shift` inplace, where `0 <= shift < Limb::BITS`.
fn shr_limb_vartime(limbs: &mut [Limb], shift: u32) {
if shift == 0 {
return;
}

let lshift = Limb::BITS - shift;
let rshift = shift;

let limbs_num = limbs.len();

for i in 0..limbs_num - 1 {
limbs[i] = (limbs[i] >> rshift) | (limbs[i + 1] << lshift);
}
limbs[limbs_num - 1] >>= rshift;
}

/// Computes `x` / `y`, returning the quotient in `x` and the remainder in `y`.
///
/// This function operates in variable-time. It will panic if the divisor is zero
Expand All @@ -408,51 +453,44 @@ pub(crate) fn div_rem_vartime_in_place(x: &mut [Limb], y: &mut [Limb]) {
}

let lshift = y[yc - 1].leading_zeros();
let rshift = if lshift == 0 { 0 } else { Limb::BITS - lshift };
let mut x_hi = Limb::ZERO;
let mut carry;

if lshift != 0 {
// Shift divisor such that it has no leading zeros
// This means that div2by1 requires no extra shifts, and ensures that the high word >= b/2
carry = Limb::ZERO;
for i in 0..yc {
(y[i], carry) = (Limb((y[i].0 << lshift) | carry.0), Limb(y[i].0 >> rshift));
}

// Shift the dividend to match
carry = Limb::ZERO;
for i in 0..xc {
(x[i], carry) = (Limb((x[i].0 << lshift) | carry.0), Limb(x[i].0 >> rshift));
}
x_hi = carry;
}
// Shift divisor such that it has no leading zeros
// This means that div2by1 requires no extra shifts, and ensures that the high word >= b/2
shl_limb_vartime(y, lshift);

// Shift the dividend to match
let mut x_hi = shl_limb_vartime(x, lshift);

let reciprocal = Reciprocal::new(y[yc - 1].to_nz().expect("zero divisor"));

for xi in (yc - 1..xc).rev() {
// Divide high dividend words by the high divisor word to estimate the quotient word
let (mut quo, _) = div3by2(x_hi.0, x[xi].0, x[xi - 1].0, &reciprocal, y[yc - 2].0);
let mut quo = div3by2(x_hi.0, x[xi].0, x[xi - 1].0, &reciprocal, y[yc - 2].0);

// Subtract q*divisor from the dividend
carry = Limb::ZERO;
let mut borrow = Limb::ZERO;
let mut tmp;
for i in 0..yc {
(tmp, carry) = Limb::ZERO.mac(y[i], Limb(quo), carry);
(x[xi + i + 1 - yc], borrow) = x[xi + i + 1 - yc].sbb(tmp, borrow);
}
(_, borrow) = x_hi.sbb(carry, borrow);
let borrow = {
let mut carry = Limb::ZERO;
let mut borrow = Limb::ZERO;
let mut tmp;
for i in 0..yc {
(tmp, carry) = Limb::ZERO.mac(y[i], Limb(quo), carry);
(x[xi + i + 1 - yc], borrow) = x[xi + i + 1 - yc].sbb(tmp, borrow);
}
(_, borrow) = x_hi.sbb(carry, borrow);
borrow
};

// If the subtraction borrowed, then decrement q and add back the divisor
// The probability of this being needed is very low, about 2/(Limb::MAX+1)
let ct_borrow = ConstChoice::from_word_mask(borrow.0);
carry = Limb::ZERO;
for i in 0..yc {
(x[xi + i + 1 - yc], carry) =
x[xi + i + 1 - yc].adc(Limb::select(Limb::ZERO, y[i], ct_borrow), carry);
}
quo = ct_borrow.select_word(quo, quo.saturating_sub(1));
quo = {
let ct_borrow = ConstChoice::from_word_mask(borrow.0);
let mut carry = Limb::ZERO;
for i in 0..yc {
(x[xi + i + 1 - yc], carry) =
x[xi + i + 1 - yc].adc(Limb::select(Limb::ZERO, y[i], ct_borrow), carry);
}
ct_borrow.select_word(quo, quo.wrapping_sub(1))
};

// Store the quotient within dividend and set x_hi to the current highest word
x_hi = x[xi];
Expand All @@ -464,12 +502,7 @@ pub(crate) fn div_rem_vartime_in_place(x: &mut [Limb], y: &mut [Limb]) {
y[yc - 1] = x_hi;

// Unshift the remainder from the earlier adjustment
if lshift != 0 {
carry = Limb::ZERO;
for i in (0..yc).rev() {
(y[i], carry) = (Limb((y[i].0 >> lshift) | carry.0), Limb(y[i].0 << rshift));
}
}
shr_limb_vartime(y, lshift);

// Shift the quotient to the low limbs within dividend
// let x_size = xc - yc + 1;
Expand Down
Loading

0 comments on commit 52fee04

Please sign in to comment.