Skip to content

Commit

Permalink
[EC] Unify scalar multiplication for P-256/384/521 (#1693)
Browse files Browse the repository at this point in the history
Added unified scalar multiplication for curves implemented in ec_nistp.
This is exactly the same algorithm that was previously implemented
separately in p384.c and p521.c (p256.c implemented a different
algorithm previously).
  • Loading branch information
dkostic authored Jul 17, 2024
1 parent fce2b0c commit 9431f99
Show file tree
Hide file tree
Showing 6 changed files with 206 additions and 316 deletions.
166 changes: 155 additions & 11 deletions crypto/fipsmodule/ec/ec_nistp.c
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
// | 1. | x | x | x* |
// | 2. | x | x | x* |
// | 3. | | | |
// | 4. | | | |
// | 4. | x | x | x* |
// | 5. | | | |
// * For P-256, only the Fiat-crypto implementation in p256.c is replaced.

Expand All @@ -30,11 +30,11 @@
// for the moment, this will be fixed when we migrate the whole P-521
// implementation to ec_nistp.c.
#if defined(EC_NISTP_USE_64BIT_LIMB)
#define NISTP_FELEM_MAX_NUM_OF_LIMBS (9)
#define FELEM_MAX_NUM_OF_LIMBS (9)
#else
#define NISTP_FELEM_MAX_NUM_OF_LIMBS (19)
#define FELEM_MAX_NUM_OF_LIMBS (19)
#endif
typedef ec_nistp_felem_limb ec_nistp_felem[NISTP_FELEM_MAX_NUM_OF_LIMBS];
typedef ec_nistp_felem_limb ec_nistp_felem[FELEM_MAX_NUM_OF_LIMBS];

// Conditional copy in constant-time (out = t == 0 ? z : nz).
static void cmovznz(ec_nistp_felem_limb *out,
Expand Down Expand Up @@ -280,8 +280,8 @@ static int16_t get_bit(const EC_SCALAR *in, size_t i) {
// It forces an odd scalar and outputs digits in
// {\pm 1, \pm 3, \pm 5, \pm 7, \pm 9, ...}
// i.e. signed odd digits with _no zeroes_ -- that makes it "regular".
void scalar_rwnaf(int16_t *out, size_t window_size,
const EC_SCALAR *scalar, size_t scalar_bit_size) {
static void scalar_rwnaf(int16_t *out, size_t window_size,
const EC_SCALAR *scalar, size_t scalar_bit_size) {
assert(window_size < 14);

// The assert above ensures this works correctly.
Expand All @@ -304,13 +304,30 @@ void scalar_rwnaf(int16_t *out, size_t window_size,
out[num_windows - 1] = window;
}

// The window size for scalar multiplication is hard coded for now.
#define SCALAR_MUL_WINDOW_SIZE (5)
// Number of points kept in the scalar multiplication table: 2^(w - 1) for
// window size w, i.e. 16 points for w = 5. Only the odd multiples
// [1]P, [3]P, ..., [2^w - 1]P are stored.
#define SCALAR_MUL_TABLE_NUM_POINTS (1 << (SCALAR_MUL_WINDOW_SIZE - 1))

// To avoid dynamic allocation and freeing of memory in functions below
// we define maximum values of certain variables.
//
// The maximum number of limbs the table in |ec_nistp_scalar_mul| can have.
// Each point in the table has 3 coordinates that are field elements,
// and each field element has a defined maximum number of limbs.
#define SCALAR_MUL_TABLE_MAX_NUM_FELEM_LIMBS \
(SCALAR_MUL_TABLE_NUM_POINTS * 3 * FELEM_MAX_NUM_OF_LIMBS)

// Maximum number of windows (digits) for a scalar encoding which is
// determined by the maximum scalar bit size -- 521 bits in our case.
// NOTE(review): with a window size of 5 this should evaluate to 105, assuming
// DIV_AND_CEIL is ceiling division as its name suggests -- confirm.
#define SCALAR_MUL_MAX_NUM_WINDOWS DIV_AND_CEIL(521, SCALAR_MUL_WINDOW_SIZE)

// Generate table of multiples of the input point P = (x_in, y_in, z_in):
// table <-- [2i + 1]P for i in [0, SCALAR_MUL_TABLE_NUM_POINTS - 1].
void generate_table(const ec_nistp_meth *ctx,
ec_nistp_felem_limb *table,
ec_nistp_felem_limb *x_in,
ec_nistp_felem_limb *y_in,
ec_nistp_felem_limb *z_in)
static void generate_table(const ec_nistp_meth *ctx,
ec_nistp_felem_limb *table,
const ec_nistp_felem_limb *x_in,
const ec_nistp_felem_limb *y_in,
const ec_nistp_felem_limb *z_in)
{
const size_t felem_num_limbs = ctx->felem_num_limbs;
const size_t felem_num_bytes = felem_num_limbs * sizeof(ec_nistp_felem_limb);
Expand Down Expand Up @@ -343,3 +360,130 @@ void generate_table(const ec_nistp_meth *ctx,
}
}

// Copies the |idx|-th point of |table| into |xyz_out| in constant-time.
//
// One table entry is a whole point, i.e. three consecutive field elements
// (x, y, z) of |ctx->felem_num_limbs| limbs each, so |xyz_out| must have room
// for 3 * felem_num_limbs limbs. The table is assumed to hold
// SCALAR_MUL_TABLE_NUM_POINTS entries.
static void select_point_from_table(const ec_nistp_meth *ctx,
                                    ec_nistp_felem_limb *xyz_out,
                                    const ec_nistp_felem_limb *table,
                                    const size_t idx) {
  // Size of a single table entry, first in limbs and then in bytes.
  const size_t limbs_per_point = 3 * ctx->felem_num_limbs;
  const size_t bytes_per_point = limbs_per_point * sizeof(ec_nistp_felem_limb);

  // The byte-oriented selection helper works on raw byte pointers.
  uint8_t *dst_bytes = (uint8_t*)xyz_out;
  uint8_t *table_bytes = (uint8_t*)table;

  constant_time_select_entry_from_table_8(dst_bytes, table_bytes, idx,
                                          SCALAR_MUL_TABLE_NUM_POINTS,
                                          bytes_per_point);
}

// Multiplication of an arbitrary point by a scalar, r = [scalar]P.
// The product is computed with the use of a small table generated on-the-fly
// and the scalar recoded in the regular-wNAF representation.
//
// Parameters:
//   ctx                  - methods object of the curve; this function reads
//                          |felem_num_limbs| and |felem_num_bits| and calls
//                          |felem_neg|, |point_dbl| and |point_add| from it.
//   x_out, y_out, z_out  - coordinates of the result r; they are written only
//                          at the very end of the function (Step 5c).
//   x_in, y_in, z_in     - coordinates of the input point P.
//   scalar               - the scalar to multiply P by.
//
// The precomputed (on-the-fly) table |table| holds odd multiples of P:
// [2i + 1]P for i in [0, SCALAR_MUL_TABLE_NUM_POINTS - 1].
// Computing the negation of a point P = (x, y, z) is relatively easy:
// -P = (x, -y, z),
// so we may assume that for each point we have its negative as well.
//
// The scalar is recoded (regular-wNAF encoding) into signed digits as explained
// in |scalar_rwnaf| function. Namely, for a window size |w| we have:
// scalar' = s_0 + s_1*2^w + s_2*2^(2*w) + ... + s_{m-1}*2^((m-1)*w),
// where digits s_i are in [\pm 1, \pm 3, ..., \pm (2^w-1)] and
// m = ceil(scalar_bit_size / w). Note that for an odd scalar we have that
// scalar = scalar', while in the case of an even scalar we have that
// scalar = scalar' - 1.
//
// The required product, [scalar]P, is computed by the following algorithm.
// 1. Initialize the accumulator with the point from |table|
// corresponding to the most significant digit s_{m-1} of the scalar.
// 2. For digits s_i starting from s_{m-2} down to s_0:
// 3. Double the accumulator w times. (note that doubling a point [a]P
// w times results in [2^w*a]P).
// 4. Read from |table| the point corresponding to abs(s_i),
// negate it if s_i is negative, and add it to the accumulator.
// 5. Subtract P from the result if the scalar is even.
//
// Note: this function is constant-time.
void ec_nistp_scalar_mul(const ec_nistp_meth *ctx,
ec_nistp_felem_limb *x_out,
ec_nistp_felem_limb *y_out,
ec_nistp_felem_limb *z_out,
const ec_nistp_felem_limb *x_in,
const ec_nistp_felem_limb *y_in,
const ec_nistp_felem_limb *z_in,
const EC_SCALAR *scalar) {
// Make sure that the max table size is large enough.
assert(SCALAR_MUL_TABLE_MAX_NUM_FELEM_LIMBS >=
SCALAR_MUL_TABLE_NUM_POINTS * ctx->felem_num_limbs * 3);

// Table of multiples of P = (x_in, y_in, z_in).
// It is allocated with the maximum size; only the first
// SCALAR_MUL_TABLE_NUM_POINTS * 3 * ctx->felem_num_limbs limbs are needed
// for the given curve (see the assert above).
ec_nistp_felem_limb table[SCALAR_MUL_TABLE_MAX_NUM_FELEM_LIMBS];
generate_table(ctx, table, x_in, y_in, z_in);

// Regular-wNAF encoding of the scalar.
// The array is sized for the largest supported scalar; only the entries with
// index below |num_windows| (computed below) are read by this function.
int16_t rwnaf[SCALAR_MUL_MAX_NUM_WINDOWS];
scalar_rwnaf(rwnaf, SCALAR_MUL_WINDOW_SIZE, scalar, ctx->felem_num_bits);

// We need two point accumulators, so we define them of maximum size
// to avoid allocation, and just take pointers to individual coordinates.
// (This cruft will disappear when we refactor point_add/dbl to work with
// whole points instead of individual coordinates).
// The coordinates of each point are stored back to back (x, then y, then z),
// each occupying ctx->felem_num_limbs limbs, which matches the layout of a
// table entry copied by |select_point_from_table|.
ec_nistp_felem_limb res[3 * FELEM_MAX_NUM_OF_LIMBS];
ec_nistp_felem_limb tmp[3 * FELEM_MAX_NUM_OF_LIMBS];
ec_nistp_felem_limb *x_res = &res[0];
ec_nistp_felem_limb *y_res = &res[ctx->felem_num_limbs];
ec_nistp_felem_limb *z_res = &res[ctx->felem_num_limbs * 2];
ec_nistp_felem_limb *x_tmp = &tmp[0];
ec_nistp_felem_limb *y_tmp = &tmp[ctx->felem_num_limbs];
ec_nistp_felem_limb *z_tmp = &tmp[ctx->felem_num_limbs * 2];

// The actual number of windows (digits) of the scalar (denoted by m in the
// description above the function).
const size_t num_windows = DIV_AND_CEIL(ctx->felem_num_bits, SCALAR_MUL_WINDOW_SIZE);

// Step 1. Initialize the accumulator (res) with the input point multiplied by
// the most significant digit of the scalar s_{m-1} (note that this digit
// can't be negative).
// An odd digit d = 2i + 1 corresponds to table index i = d >> 1.
int16_t idx = rwnaf[num_windows - 1];
idx >>= 1;
select_point_from_table(ctx, res, table, idx);

// Step 2. Process the remaining digits of the scalar (s_{m-2} to s_0).
// The loop index is signed so that the condition i >= 0 can become false.
for (int i = num_windows - 2; i >= 0; i--) {
// Step 3. Double the accumulator w times.
for (size_t j = 0; j < SCALAR_MUL_WINDOW_SIZE; j++) {
ctx->point_dbl(x_res, y_res, z_res, x_res, y_res, z_res);
}

// Step 4a. Compute abs(s_i).
// This is done without branching on the (secret) digit: the sign bit is
// extracted and used for a conditional two's complement negation.
int16_t d = rwnaf[i];
int16_t is_neg = (d >> 15) & 1; // is_neg = (d < 0) ? 1 : 0
d = (d ^ -is_neg) + is_neg; // d = abs(d)

// Step 4b. Select from table the point corresponding to abs(s_i).
idx = d >> 1;
select_point_from_table(ctx, tmp, table, idx);

// Step 4c. Negate the point if s_i < 0.
// The negated y-coordinate is always computed, and then either the original
// (is_neg == 0) or the negated (is_neg != 0) value is kept in y_tmp.
ec_nistp_felem ftmp;
ctx->felem_neg(ftmp, y_tmp);

cmovznz(y_tmp, ctx->felem_num_limbs, is_neg, y_tmp, ftmp);

// Step 4d. Add the point to the accumulator.
// NOTE(review): the literal 0 is presumably the "mixed addition" flag of
// point_add (p256.c passes `0 /* mixed */` in the same position) -- confirm.
ctx->point_add(x_res, y_res, z_res, x_res, y_res, z_res, 0, x_tmp, y_tmp, z_tmp);
}

// Step 5a. Negate the input point P (we negate it in-place since we already
// have it stored as the first entry in the table).
ec_nistp_felem_limb *x_mp = &table[0];
ec_nistp_felem_limb *y_mp = &table[ctx->felem_num_limbs];
ec_nistp_felem_limb *z_mp = &table[ctx->felem_num_limbs * 2];
ctx->felem_neg(y_mp, y_mp);

// Step 5b. Subtract P from the accumulator.
// The difference is always computed (into tmp), regardless of the parity of
// the scalar; the selection of the right value happens in Step 5c.
ctx->point_add(x_tmp, y_tmp, z_tmp, x_res, y_res, z_res, 0, x_mp, y_mp, z_mp);

// Step 5c. Select |res| or |res - P| based on parity of the scalar.
// t == 0 (even scalar) selects |res - P| held in tmp, otherwise |res|.
ec_nistp_felem_limb t = scalar->words[0] & 1;
cmovznz(x_out, ctx->felem_num_limbs, t, x_tmp, x_res);
cmovznz(y_out, ctx->felem_num_limbs, t, y_tmp, y_res);
cmovznz(z_out, ctx->felem_num_limbs, t, z_tmp, z_res);
}
25 changes: 10 additions & 15 deletions crypto/fipsmodule/ec/ec_nistp.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,12 @@ typedef uint32_t ec_nistp_felem_limb;
// providing an appropriate methods object.
typedef struct {
size_t felem_num_limbs;
size_t felem_num_bits;
void (*felem_add)(ec_nistp_felem_limb *c, const ec_nistp_felem_limb *a, const ec_nistp_felem_limb *b);
void (*felem_sub)(ec_nistp_felem_limb *c, const ec_nistp_felem_limb *a, const ec_nistp_felem_limb *b);
void (*felem_mul)(ec_nistp_felem_limb *c, const ec_nistp_felem_limb *a, const ec_nistp_felem_limb *b);
void (*felem_sqr)(ec_nistp_felem_limb *c, const ec_nistp_felem_limb *a);
void (*felem_neg)(ec_nistp_felem_limb *c, const ec_nistp_felem_limb *a);
ec_nistp_felem_limb (*felem_nz)(const ec_nistp_felem_limb *a);

void (*point_dbl)(ec_nistp_felem_limb *x_out,
Expand Down Expand Up @@ -96,20 +98,13 @@ void ec_nistp_point_add(const ec_nistp_meth *ctx,
const ec_nistp_felem_limb *y2,
const ec_nistp_felem_limb *z2);

// These two functions and two macros are temporarily defined here.
// They will be moved to ec_nistp.c as static functions
// once all the scalar multiplications are implemented.
void scalar_rwnaf(int16_t *out, size_t window_size,
const EC_SCALAR *scalar, size_t scalar_bit_size);
void generate_table(const ec_nistp_meth *ctx,
ec_nistp_felem_limb *table,
ec_nistp_felem_limb *x_in,
ec_nistp_felem_limb *y_in,
ec_nistp_felem_limb *z_in);

// The window size for scalar multiplication is hard coded for now.
#define SCALAR_MUL_WINDOW_SIZE (5)
#define SCALAR_MUL_TABLE_NUM_POINTS (1 << (SCALAR_MUL_WINDOW_SIZE - 1))

// ec_nistp_scalar_mul computes r = [scalar]P for an arbitrary point
// P = (x_in, y_in, z_in) and writes the result to (x_out, y_out, z_out).
// The field and point operations of the particular curve, as well as the
// number of limbs and bits of a field element, are taken from |ctx|.
// The algorithm is described above the definition in ec_nistp.c.
void ec_nistp_scalar_mul(const ec_nistp_meth *ctx,
ec_nistp_felem_limb *x_out,
ec_nistp_felem_limb *y_out,
ec_nistp_felem_limb *z_out,
const ec_nistp_felem_limb *x_in,
const ec_nistp_felem_limb *y_in,
const ec_nistp_felem_limb *z_in,
const EC_SCALAR *scalar);
#endif // EC_NISTP_H

84 changes: 10 additions & 74 deletions crypto/fipsmodule/ec/p256.c
Original file line number Diff line number Diff line change
Expand Up @@ -188,10 +188,12 @@ static void fiat_p256_point_add(fiat_p256_felem x3, fiat_p256_felem y3,

DEFINE_METHOD_FUNCTION(ec_nistp_meth, p256_methods) {
out->felem_num_limbs = FIAT_P256_NLIMBS;
out->felem_num_bits = 256;
out->felem_add = fiat_p256_add;
out->felem_sub = fiat_p256_sub;
out->felem_mul = fiat_p256_mul;
out->felem_sqr = fiat_p256_square;
out->felem_neg = fiat_p256_opp;
out->felem_nz = fiat_p256_nz;
out->point_dbl = fiat_p256_point_double;
out->point_add = fiat_p256_point_add;
Expand All @@ -214,20 +216,6 @@ static void fiat_p256_select_point_affine(
fiat_p256_cmovznz(out[2], idx, out[2], fiat_p256_one);
}

// fiat_p256_select_point copies entry number |idx| of the precomputation
// table |pre_comp| (which has |size| entries of three field elements each)
// into |out|. Every entry of the table is visited regardless of |idx|. If no
// entry matches |idx|, |out| is left all-zero.
static void fiat_p256_select_point(const fiat_p256_limb_t idx, size_t size,
                                   const fiat_p256_felem pre_comp[/*size*/][3],
                                   fiat_p256_felem out[3]) {
  OPENSSL_memset(out, 0, sizeof(fiat_p256_felem) * 3);
  for (size_t entry = 0; entry < size; entry++) {
    // Zero exactly when |entry| is the requested index.
    const fiat_p256_limb_t differs = entry ^ idx;
    // NOTE(review): fiat_p256_cmovznz presumably keeps its third argument when
    // the condition is zero and its fourth otherwise -- confirm.
    for (size_t coord = 0; coord < 3; coord++) {
      fiat_p256_cmovznz(out[coord], differs, pre_comp[entry][coord],
                        out[coord]);
    }
  }
}

// fiat_p256_get_bit returns the |i|th bit in |in|.
static crypto_word_t fiat_p256_get_bit(const EC_SCALAR *in, int i) {
if (i < 0 || i >= 256) {
Expand Down Expand Up @@ -309,68 +297,16 @@ static void ec_GFp_nistp256_dbl(const EC_GROUP *group, EC_JACOBIAN *r,
static void ec_GFp_nistp256_point_mul(const EC_GROUP *group, EC_JACOBIAN *r,
const EC_JACOBIAN *p,
const EC_SCALAR *scalar) {
fiat_p256_felem p_pre_comp[17][3];
OPENSSL_memset(&p_pre_comp, 0, sizeof(p_pre_comp));
// Precompute multiples.
fiat_p256_from_generic(p_pre_comp[1][0], &p->X);
fiat_p256_from_generic(p_pre_comp[1][1], &p->Y);
fiat_p256_from_generic(p_pre_comp[1][2], &p->Z);
for (size_t j = 2; j <= 16; ++j) {
if (j & 1) {
fiat_p256_point_add(p_pre_comp[j][0], p_pre_comp[j][1], p_pre_comp[j][2],
p_pre_comp[1][0], p_pre_comp[1][1], p_pre_comp[1][2],
0, p_pre_comp[j - 1][0], p_pre_comp[j - 1][1],
p_pre_comp[j - 1][2]);
} else {
fiat_p256_point_double(p_pre_comp[j][0], p_pre_comp[j][1],
p_pre_comp[j][2], p_pre_comp[j / 2][0],
p_pre_comp[j / 2][1], p_pre_comp[j / 2][2]);
}
}
fiat_p256_felem res[3], tmp[3];
fiat_p256_from_generic(tmp[0], &p->X);
fiat_p256_from_generic(tmp[1], &p->Y);
fiat_p256_from_generic(tmp[2], &p->Z);

// Set nq to the point at infinity.
fiat_p256_felem nq[3] = {{0}, {0}, {0}}, ftmp, tmp[3];
ec_nistp_scalar_mul(p256_methods(), res[0], res[1], res[2], tmp[0], tmp[1], tmp[2], scalar);

// Loop over |scalar| msb-to-lsb, incorporating |p_pre_comp| every 5th round.
int skip = 1; // Save two point operations in the first round.
for (size_t i = 255; i < 256; i--) {
// double
if (!skip) {
fiat_p256_point_double(nq[0], nq[1], nq[2], nq[0], nq[1], nq[2]);
}

// do other additions every 5 doublings
if (i % 5 == 0) {
crypto_word_t bits = fiat_p256_get_bit(scalar, i + 4) << 5;
bits |= fiat_p256_get_bit(scalar, i + 3) << 4;
bits |= fiat_p256_get_bit(scalar, i + 2) << 3;
bits |= fiat_p256_get_bit(scalar, i + 1) << 2;
bits |= fiat_p256_get_bit(scalar, i) << 1;
bits |= fiat_p256_get_bit(scalar, i - 1);
crypto_word_t sign, digit;
ec_GFp_nistp_recode_scalar_bits(&sign, &digit, bits);

// select the point to add or subtract, in constant time.
fiat_p256_select_point((fiat_p256_limb_t)digit, 17,
(const fiat_p256_felem(*)[3])p_pre_comp, tmp);
fiat_p256_opp(ftmp, tmp[1]); // (X, -Y, Z) is the negative point.
fiat_p256_cmovznz(tmp[1], (fiat_p256_limb_t)sign, tmp[1], ftmp);

if (!skip) {
fiat_p256_point_add(nq[0], nq[1], nq[2], nq[0], nq[1], nq[2],
0 /* mixed */, tmp[0], tmp[1], tmp[2]);
} else {
fiat_p256_copy(nq[0], tmp[0]);
fiat_p256_copy(nq[1], tmp[1]);
fiat_p256_copy(nq[2], tmp[2]);
skip = 0;
}
}
}

fiat_p256_to_generic(&r->X, nq[0]);
fiat_p256_to_generic(&r->Y, nq[1]);
fiat_p256_to_generic(&r->Z, nq[2]);
fiat_p256_to_generic(&r->X, res[0]);
fiat_p256_to_generic(&r->Y, res[1]);
fiat_p256_to_generic(&r->Z, res[2]);
}

static void ec_GFp_nistp256_point_mul_base(const EC_GROUP *group,
Expand Down
Loading

0 comments on commit 9431f99

Please sign in to comment.