Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[EC] Unify scalar multiplication for P-256/384/521 #1693

Merged
merged 18 commits into from
Jul 17, 2024
Merged
166 changes: 155 additions & 11 deletions crypto/fipsmodule/ec/ec_nistp.c
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
// | 1. | x | x | x* |
// | 2. | x | x | x* |
// | 3. | | | |
// | 4. | | | |
// | 4. | x | x | x* |
// | 5. | | | |
// * For P-256, only the Fiat-crypto implementation in p256.c is replaced.

Expand All @@ -30,11 +30,11 @@
// for the moment, this will be fixed when we migrate the whole P-521
// implementation to ec_nistp.c.
#if defined(EC_NISTP_USE_64BIT_LIMB)
#define NISTP_FELEM_MAX_NUM_OF_LIMBS (9)
#define FELEM_MAX_NUM_OF_LIMBS (9)
#else
#define NISTP_FELEM_MAX_NUM_OF_LIMBS (19)
#define FELEM_MAX_NUM_OF_LIMBS (19)
#endif
typedef ec_nistp_felem_limb ec_nistp_felem[NISTP_FELEM_MAX_NUM_OF_LIMBS];
typedef ec_nistp_felem_limb ec_nistp_felem[FELEM_MAX_NUM_OF_LIMBS];

// Conditional copy in constant-time (out = t == 0 ? z : nz).
static void cmovznz(ec_nistp_felem_limb *out,
Expand Down Expand Up @@ -280,8 +280,8 @@ static int16_t get_bit(const EC_SCALAR *in, size_t i) {
// It forces an odd scalar and outputs digits in
// {\pm 1, \pm 3, \pm 5, \pm 7, \pm 9, ...}
// i.e. signed odd digits with _no zeroes_ -- that makes it "regular".
void scalar_rwnaf(int16_t *out, size_t window_size,
const EC_SCALAR *scalar, size_t scalar_bit_size) {
static void scalar_rwnaf(int16_t *out, size_t window_size,
const EC_SCALAR *scalar, size_t scalar_bit_size) {
assert(window_size < 14);

// The assert above ensures this works correctly.
Expand All @@ -304,13 +304,30 @@ void scalar_rwnaf(int16_t *out, size_t window_size,
out[num_windows - 1] = window;
}

// The window size for scalar multiplication is hard coded for now.
#define SCALAR_MUL_WINDOW_SIZE (5)
// Number of points stored in the scalar multiplication table: 2^(w - 1) for
// window size w, i.e. 16 entries for w = 5. Only the odd multiples
// [2i + 1]P are stored.
#define SCALAR_MUL_TABLE_NUM_POINTS (1 << (SCALAR_MUL_WINDOW_SIZE - 1))

// To avoid dynamic allocation and freeing of memory in functions below
// we define maximum values of certain variables.
//
// The maximum number of limbs the table in |ec_nistp_scalar_mul| can have.
// Each point in the table has 3 coordinates that are field elements,
// and each field element has a defined maximum number of limbs.
// Note that this counts limbs, not bytes.
#define SCALAR_MUL_TABLE_MAX_NUM_FELEM_LIMBS \
(SCALAR_MUL_TABLE_NUM_POINTS * 3 * FELEM_MAX_NUM_OF_LIMBS)

// Maximum number of windows (digits) for a scalar encoding which is
// determined by the maximum scalar bit size -- 521 bits in our case.
// (That is ceil(521 / 5) = 105 digits, assuming |DIV_AND_CEIL| rounds up as
// its name suggests -- it is defined elsewhere.)
#define SCALAR_MUL_MAX_NUM_WINDOWS DIV_AND_CEIL(521, SCALAR_MUL_WINDOW_SIZE)

// Generate table of multiples of the input point P = (x_in, y_in, z_in):
// table <-- [2i + 1]P for i in [0, SCALAR_MUL_TABLE_NUM_POINTS - 1].
void generate_table(const ec_nistp_meth *ctx,
ec_nistp_felem_limb *table,
ec_nistp_felem_limb *x_in,
ec_nistp_felem_limb *y_in,
ec_nistp_felem_limb *z_in)
static void generate_table(const ec_nistp_meth *ctx,
ec_nistp_felem_limb *table,
const ec_nistp_felem_limb *x_in,
const ec_nistp_felem_limb *y_in,
const ec_nistp_felem_limb *z_in)
{
const size_t felem_num_limbs = ctx->felem_num_limbs;
const size_t felem_num_bytes = felem_num_limbs * sizeof(ec_nistp_felem_limb);
Expand Down Expand Up @@ -343,3 +360,130 @@ void generate_table(const ec_nistp_meth *ctx,
}
}

// Copies the |idx|-th entry of |table| into |xyz_out| in constant-time.
// An entry is one whole point, i.e. three consecutive field elements of
// |ctx->felem_num_limbs| limbs each; |table| is treated as holding
// SCALAR_MUL_TABLE_NUM_POINTS such entries.
static void select_point_from_table(const ec_nistp_meth *ctx,
                                    ec_nistp_felem_limb *xyz_out,
                                    const ec_nistp_felem_limb *table,
                                    const size_t idx) {
  // Size of a single point (three coordinates), first in limbs, then in bytes.
  const size_t limbs_per_point = 3 * ctx->felem_num_limbs;
  const size_t bytes_per_point = limbs_per_point * sizeof(ec_nistp_felem_limb);

  // The selection helper operates on raw bytes.
  uint8_t *dst = (uint8_t*)xyz_out;
  uint8_t *src = (uint8_t*)table;

  constant_time_select_entry_from_table_8(
      dst, src, idx, SCALAR_MUL_TABLE_NUM_POINTS, bytes_per_point);
}

// Multiplication of an arbitrary point by a scalar, r = [scalar]P.
// The product is computed with the use of a small table generated on-the-fly
dkostic marked this conversation as resolved.
Show resolved Hide resolved
// and the scalar recoded in the regular-wNAF representation.
//
// The precomputed (on-the-fly) table |table| holds odd multiples of P:
// [2i + 1]P for i in [0, SCALAR_MUL_TABLE_NUM_POINTS - 1].
// Computing the negation of a point P = (x, y, z) is relatively easy:
// -P = (x, -y, z),
// so we may assume that for each point we have its negative as well.
//
// The scalar is recoded (regular-wNAF encoding) into signed digits as explained
// in |scalar_rwnaf| function. Namely, for a window size |w| we have:
// scalar' = s_0 + s_1*2^w + s_2*2^(2*w) + ... + s_{m-1}*2^((m-1)*w),
// where digits s_i are in [\pm 1, \pm 3, ..., \pm (2^w-1)] and
// m = ceil(scalar_bit_size / w). Note that for an odd scalar we have that
// scalar = scalar', while in the case of an even scalar we have that
// scalar = scalar' - 1.
//
// The required product, [scalar]P, is computed by the following algorithm.
// 1. Initialize the accumulator with the point from |table|
// corresponding to the most significant digit s_{m-1} of the scalar.
// 2. For digits s_i starting from s_{m-2} down to s_0:
// 3. Double the accumulator w times. (note that doubling a point [a]P
// w times results in [2^w*a]P).
// 4. Read from |table| the point corresponding to abs(s_i),
// negate it if s_i is negative, and add it to the accumulator.
// 5. Subtract P from the result if the scalar is even.
//
// Note: this function is constant-time.
void ec_nistp_scalar_mul(const ec_nistp_meth *ctx,
ec_nistp_felem_limb *x_out,
ec_nistp_felem_limb *y_out,
ec_nistp_felem_limb *z_out,
const ec_nistp_felem_limb *x_in,
const ec_nistp_felem_limb *y_in,
const ec_nistp_felem_limb *z_in,
const EC_SCALAR *scalar) {
// Make sure that the max table size is large enough.
assert(SCALAR_MUL_TABLE_MAX_NUM_FELEM_LIMBS >=
SCALAR_MUL_TABLE_NUM_POINTS * ctx->felem_num_limbs * 3);

// Table of multiples of P = (x_in, y_in, z_in).
ec_nistp_felem_limb table[SCALAR_MUL_TABLE_MAX_NUM_FELEM_LIMBS];
dkostic marked this conversation as resolved.
Show resolved Hide resolved
generate_table(ctx, table, x_in, y_in, z_in);

// Regular-wNAF encoding of the scalar.
int16_t rwnaf[SCALAR_MUL_MAX_NUM_WINDOWS];
scalar_rwnaf(rwnaf, SCALAR_MUL_WINDOW_SIZE, scalar, ctx->felem_num_bits);

// We need two point accumulators, so we define them of maximum size
// to avoid allocation, and just take pointers to individual coordinates.
// (This cruft will dissapear when we refactor point_add/dbl to work with
// whole points instead of individual coordinates).
ec_nistp_felem_limb res[3 * FELEM_MAX_NUM_OF_LIMBS];
ec_nistp_felem_limb tmp[3 * FELEM_MAX_NUM_OF_LIMBS];
ec_nistp_felem_limb *x_res = &res[0];
ec_nistp_felem_limb *y_res = &res[ctx->felem_num_limbs];
ec_nistp_felem_limb *z_res = &res[ctx->felem_num_limbs * 2];
ec_nistp_felem_limb *x_tmp = &tmp[0];
ec_nistp_felem_limb *y_tmp = &tmp[ctx->felem_num_limbs];
ec_nistp_felem_limb *z_tmp = &tmp[ctx->felem_num_limbs * 2];

// The actual number of windows (digits) of the scalar (denoted by m in the
// description above the function).
const size_t num_windows = DIV_AND_CEIL(ctx->felem_num_bits, SCALAR_MUL_WINDOW_SIZE);

// Step 1. Initialize the accmulator (res) with the input point multiplied by
// the most significant digit of the scalar s_{m-1} (note that this digit
// can't be negative).
int16_t idx = rwnaf[num_windows - 1];
idx >>= 1;
select_point_from_table(ctx, res, table, idx);

// Step 2. Process the remaining digits of the scalar (s_{m-2} to s_0).
for (int i = num_windows - 2; i >= 0; i--) {
// Step 3. Double the accumulator w times.
for (size_t j = 0; j < SCALAR_MUL_WINDOW_SIZE; j++) {
ctx->point_dbl(x_res, y_res, z_res, x_res, y_res, z_res);
}

// Step 4a. Compute abs(s_i).
int16_t d = rwnaf[i];
int16_t is_neg = (d >> 15) & 1; // is_neg = (d < 0) ? 1 : 0
d = (d ^ -is_neg) + is_neg; // d = abs(d)

// Step 4b. Select from table the point corresponding to abs(s_i).
idx = d >> 1;
select_point_from_table(ctx, tmp, table, idx);

// Step 4c. Negate the point if s_i < 0.
ec_nistp_felem ftmp;
ctx->felem_neg(ftmp, y_tmp);

cmovznz(y_tmp, ctx->felem_num_limbs, is_neg, y_tmp, ftmp);

// Step 4d. Add the point to the accumulator.
ctx->point_add(x_res, y_res, z_res, x_res, y_res, z_res, 0, x_tmp, y_tmp, z_tmp);
}

// Step 5a. Negate the input point P (we negate it in-place since we already
// have it stored as the first entry in the table).
ec_nistp_felem_limb *x_mp = &table[0];
ec_nistp_felem_limb *y_mp = &table[ctx->felem_num_limbs];
ec_nistp_felem_limb *z_mp = &table[ctx->felem_num_limbs * 2];
ctx->felem_neg(y_mp, y_mp);

// Step 5b. Subtract P from the accumulator.
ctx->point_add(x_tmp, y_tmp, z_tmp, x_res, y_res, z_res, 0, x_mp, y_mp, z_mp);

// Step 5c. Select |res| or |res - P| based on parity of the scalar.
ec_nistp_felem_limb t = scalar->words[0] & 1;
cmovznz(x_out, ctx->felem_num_limbs, t, x_tmp, x_res);
cmovznz(y_out, ctx->felem_num_limbs, t, y_tmp, y_res);
cmovznz(z_out, ctx->felem_num_limbs, t, z_tmp, z_res);
}
25 changes: 10 additions & 15 deletions crypto/fipsmodule/ec/ec_nistp.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,12 @@ typedef uint32_t ec_nistp_felem_limb;
// providing an appropriate methods object.
typedef struct {
size_t felem_num_limbs;
size_t felem_num_bits;
void (*felem_add)(ec_nistp_felem_limb *c, const ec_nistp_felem_limb *a, const ec_nistp_felem_limb *b);
void (*felem_sub)(ec_nistp_felem_limb *c, const ec_nistp_felem_limb *a, const ec_nistp_felem_limb *b);
void (*felem_mul)(ec_nistp_felem_limb *c, const ec_nistp_felem_limb *a, const ec_nistp_felem_limb *b);
void (*felem_sqr)(ec_nistp_felem_limb *c, const ec_nistp_felem_limb *a);
void (*felem_neg)(ec_nistp_felem_limb *c, const ec_nistp_felem_limb *a);
ec_nistp_felem_limb (*felem_nz)(const ec_nistp_felem_limb *a);

void (*point_dbl)(ec_nistp_felem_limb *x_out,
Expand Down Expand Up @@ -96,20 +98,13 @@ void ec_nistp_point_add(const ec_nistp_meth *ctx,
const ec_nistp_felem_limb *y2,
const ec_nistp_felem_limb *z2);

// These two functions and two macros are temporarily defined here.
// They will be moved to ec_nistp.c as static functions
// once all the scalar multiplications are implemented.
void scalar_rwnaf(int16_t *out, size_t window_size,
const EC_SCALAR *scalar, size_t scalar_bit_size);
void generate_table(const ec_nistp_meth *ctx,
ec_nistp_felem_limb *table,
ec_nistp_felem_limb *x_in,
ec_nistp_felem_limb *y_in,
ec_nistp_felem_limb *z_in);

// The window size for scalar multiplication is hard coded for now.
#define SCALAR_MUL_WINDOW_SIZE (5)
#define SCALAR_MUL_TABLE_NUM_POINTS (1 << (SCALAR_MUL_WINDOW_SIZE - 1))

// ec_nistp_scalar_mul computes the product of a point and a scalar,
//     (x_out, y_out, z_out) = [scalar](x_in, y_in, z_in),
// using the curve-specific methods provided in |ctx|. Each coordinate is a
// field element given as an array of |ec_nistp_felem_limb| limbs.
void ec_nistp_scalar_mul(const ec_nistp_meth *ctx,
ec_nistp_felem_limb *x_out,
ec_nistp_felem_limb *y_out,
ec_nistp_felem_limb *z_out,
const ec_nistp_felem_limb *x_in,
const ec_nistp_felem_limb *y_in,
const ec_nistp_felem_limb *z_in,
const EC_SCALAR *scalar);
#endif // EC_NISTP_H

84 changes: 10 additions & 74 deletions crypto/fipsmodule/ec/p256.c
Original file line number Diff line number Diff line change
Expand Up @@ -188,10 +188,12 @@ static void fiat_p256_point_add(fiat_p256_felem x3, fiat_p256_felem y3,

DEFINE_METHOD_FUNCTION(ec_nistp_meth, p256_methods) {
out->felem_num_limbs = FIAT_P256_NLIMBS;
out->felem_num_bits = 256;
out->felem_add = fiat_p256_add;
out->felem_sub = fiat_p256_sub;
out->felem_mul = fiat_p256_mul;
out->felem_sqr = fiat_p256_square;
out->felem_neg = fiat_p256_opp;
out->felem_nz = fiat_p256_nz;
out->point_dbl = fiat_p256_point_double;
out->point_add = fiat_p256_point_add;
Expand All @@ -214,20 +216,6 @@ static void fiat_p256_select_point_affine(
fiat_p256_cmovznz(out[2], idx, out[2], fiat_p256_one);
}

// fiat_p256_select_point copies the |idx|th point of the precomputation table
// |pre_comp| (which has |size| entries of three coordinates each) into |out|.
// Every table entry is visited regardless of |idx|; |out| starts out as all
// zeroes and stays that way if |idx| matches no entry.
static void fiat_p256_select_point(const fiat_p256_limb_t idx, size_t size,
                                   const fiat_p256_felem pre_comp[/*size*/][3],
                                   fiat_p256_felem out[3]) {
  OPENSSL_memset(out, 0, sizeof(fiat_p256_felem) * 3);

  for (size_t entry = 0; entry < size; entry++) {
    // Zero exactly when |entry| is the requested index.
    const fiat_p256_limb_t differs = entry ^ idx;

    // Conditionally move each of the three coordinates of this entry.
    for (size_t coord = 0; coord < 3; coord++) {
      fiat_p256_cmovznz(out[coord], differs, pre_comp[entry][coord],
                        out[coord]);
    }
  }
}

// fiat_p256_get_bit returns the |i|th bit in |in|.
static crypto_word_t fiat_p256_get_bit(const EC_SCALAR *in, int i) {
if (i < 0 || i >= 256) {
Expand Down Expand Up @@ -309,68 +297,16 @@ static void ec_GFp_nistp256_dbl(const EC_GROUP *group, EC_JACOBIAN *r,
static void ec_GFp_nistp256_point_mul(const EC_GROUP *group, EC_JACOBIAN *r,
const EC_JACOBIAN *p,
const EC_SCALAR *scalar) {
fiat_p256_felem p_pre_comp[17][3];
OPENSSL_memset(&p_pre_comp, 0, sizeof(p_pre_comp));
// Precompute multiples.
fiat_p256_from_generic(p_pre_comp[1][0], &p->X);
fiat_p256_from_generic(p_pre_comp[1][1], &p->Y);
fiat_p256_from_generic(p_pre_comp[1][2], &p->Z);
for (size_t j = 2; j <= 16; ++j) {
if (j & 1) {
fiat_p256_point_add(p_pre_comp[j][0], p_pre_comp[j][1], p_pre_comp[j][2],
p_pre_comp[1][0], p_pre_comp[1][1], p_pre_comp[1][2],
0, p_pre_comp[j - 1][0], p_pre_comp[j - 1][1],
p_pre_comp[j - 1][2]);
} else {
fiat_p256_point_double(p_pre_comp[j][0], p_pre_comp[j][1],
p_pre_comp[j][2], p_pre_comp[j / 2][0],
p_pre_comp[j / 2][1], p_pre_comp[j / 2][2]);
}
}
fiat_p256_felem res[3], tmp[3];
justsmth marked this conversation as resolved.
Show resolved Hide resolved
fiat_p256_from_generic(tmp[0], &p->X);
fiat_p256_from_generic(tmp[1], &p->Y);
fiat_p256_from_generic(tmp[2], &p->Z);

// Set nq to the point at infinity.
fiat_p256_felem nq[3] = {{0}, {0}, {0}}, ftmp, tmp[3];
ec_nistp_scalar_mul(p256_methods(), res[0], res[1], res[2], tmp[0], tmp[1], tmp[2], scalar);

// Loop over |scalar| msb-to-lsb, incorporating |p_pre_comp| every 5th round.
int skip = 1; // Save two point operations in the first round.
for (size_t i = 255; i < 256; i--) {
// double
if (!skip) {
fiat_p256_point_double(nq[0], nq[1], nq[2], nq[0], nq[1], nq[2]);
}

// do other additions every 5 doublings
if (i % 5 == 0) {
crypto_word_t bits = fiat_p256_get_bit(scalar, i + 4) << 5;
bits |= fiat_p256_get_bit(scalar, i + 3) << 4;
bits |= fiat_p256_get_bit(scalar, i + 2) << 3;
bits |= fiat_p256_get_bit(scalar, i + 1) << 2;
bits |= fiat_p256_get_bit(scalar, i) << 1;
bits |= fiat_p256_get_bit(scalar, i - 1);
crypto_word_t sign, digit;
ec_GFp_nistp_recode_scalar_bits(&sign, &digit, bits);

// select the point to add or subtract, in constant time.
fiat_p256_select_point((fiat_p256_limb_t)digit, 17,
(const fiat_p256_felem(*)[3])p_pre_comp, tmp);
fiat_p256_opp(ftmp, tmp[1]); // (X, -Y, Z) is the negative point.
fiat_p256_cmovznz(tmp[1], (fiat_p256_limb_t)sign, tmp[1], ftmp);

if (!skip) {
fiat_p256_point_add(nq[0], nq[1], nq[2], nq[0], nq[1], nq[2],
0 /* mixed */, tmp[0], tmp[1], tmp[2]);
} else {
fiat_p256_copy(nq[0], tmp[0]);
fiat_p256_copy(nq[1], tmp[1]);
fiat_p256_copy(nq[2], tmp[2]);
skip = 0;
}
}
}

fiat_p256_to_generic(&r->X, nq[0]);
fiat_p256_to_generic(&r->Y, nq[1]);
fiat_p256_to_generic(&r->Z, nq[2]);
fiat_p256_to_generic(&r->X, res[0]);
fiat_p256_to_generic(&r->Y, res[1]);
fiat_p256_to_generic(&r->Z, res[2]);
}

static void ec_GFp_nistp256_point_mul_base(const EC_GROUP *group,
Expand Down
Loading
Loading