Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Faster CKKS multiply #346

Merged
merged 4 commits into from
Jun 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ The optional dependencies and their tested versions (other versions may work as

#### Intel HEXL

Intel HEXL is a library providing efficient implementations of cryptographic primitives common in homomorphic encryption. The acceleration is particularly evident on Intel processors with the Intel AVX512-IMA52 instruction set.
Intel HEXL is a library providing efficient implementations of cryptographic primitives common in homomorphic encryption. The acceleration is particularly evident on Intel processors with the Intel AVX512-IFMA52 instruction set.

#### Microsoft GSL

Expand Down
90 changes: 80 additions & 10 deletions native/src/seal/evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ namespace seal
#endif
}

void Evaluator::bfv_multiply(Ciphertext &encrypted1, const Ciphertext &encrypted2, MemoryPoolHandle pool) const
void Evaluator::bfv_multiply(Ciphertext &encrypted1, const Ciphertext &encrypted2, MemoryPoolHandle pool) const
{
if (encrypted1.is_ntt_form() || encrypted2.is_ntt_form())
{
Expand Down Expand Up @@ -506,6 +506,74 @@ namespace seal
// Prepare destination
encrypted1.resize(context_, context_data.parms_id(), dest_size);

if (dest_size == 3)
{
// We want to keep six polynomials in the L1 cache: x[0], x[1], x[2], y[0], y[1], temp.
// For a 32KiB cache, which can store 32768 / 8 = 4096 coefficients, = 682.67 coefficients per polynomial,
// we should keep the tile size at 682 or below. The tile size must divide coeff_count, i.e. be a power of
// two. Some testing shows similar performance with tile size 256 and 512, and worse performance on smaller
// tiles. We pick the smaller of the two to prevent L1 cache misses on processors with < 32 KiB L1 cache.
size_t tile_size = min<size_t>(coeff_count, size_t(256));
size_t num_tiles = coeff_count / tile_size;
#ifdef SEAL_DEBUG
if (coeff_count % tile_size != 0)
{
throw invalid_argument("tile_size does not divide coeff_count");
}
#endif

// Set up iterators for input ciphertexts
PolyIter encrypted1_iter = iter(encrypted1);
ConstPolyIter encrypted2_iter = iter(encrypted2);

// Semantic misuse of RNSIter; each is really pointing to the data for each RNS factor in sequence
ConstRNSIter encrypted2_0_iter(*encrypted2_iter[0], tile_size);
ConstRNSIter encrypted2_1_iter(*encrypted2_iter[1], tile_size);
RNSIter encrypted1_0_iter(*encrypted1_iter[0], tile_size);
RNSIter encrypted1_1_iter(*encrypted1_iter[1], tile_size);
RNSIter encrypted1_2_iter(*encrypted1_iter[2], tile_size);

// Temporary buffer to store intermediate results
SEAL_ALLOCATE_GET_COEFF_ITER(temp, tile_size, pool);

// Computes the output tile_size coefficients at a time
// Given input tuples of polynomials x = (x[0], x[1], x[2]), y = (y[0], y[1]), computes
// x = (x[0] * y[0], x[0] * y[1] + x[1] * y[0], x[1] * y[1])
// with appropriate modular reduction
SEAL_ITERATE(coeff_modulus, coeff_modulus_size, [&](auto I) {
SEAL_ITERATE(iter(size_t(0)), num_tiles, [&](auto J) {
// Compute third output polynomial, overwriting input
// x[2] = x[1] * y[1]
dyadic_product_coeffmod(
encrypted1_1_iter[0], encrypted2_1_iter[0], tile_size, I, encrypted1_2_iter[0]);

// Compute second output polynomial, overwriting input
// temp = x[1] * y[0]
dyadic_product_coeffmod(encrypted1_1_iter[0], encrypted2_0_iter[0], tile_size, I, temp);
// x[1] = x[0] * y[1]
dyadic_product_coeffmod(
encrypted1_0_iter[0], encrypted2_1_iter[0], tile_size, I, encrypted1_1_iter[0]);
// x[1] += temp
add_poly_coeffmod(encrypted1_1_iter[0], temp, tile_size, I, encrypted1_1_iter[0]);

// Compute first output polynomial, overwriting input
// x[0] = x[0] * y[0]
dyadic_product_coeffmod(
encrypted1_0_iter[0], encrypted2_0_iter[0], tile_size, I, encrypted1_0_iter[0]);

// Manually increment iterators
++encrypted1_0_iter;
++encrypted1_1_iter;
++encrypted1_2_iter;
++encrypted2_0_iter;
++encrypted2_1_iter;
});
});

encrypted1.scale() = new_scale;
return;
}

// Set up iterators for input ciphertexts
auto encrypted1_iter = iter(encrypted1);
auto encrypted2_iter = iter(encrypted2);
Expand Down Expand Up @@ -921,7 +989,8 @@ namespace seal
}
}

void Evaluator::mod_switch_drop_to_next(const Ciphertext &encrypted, Ciphertext &destination, MemoryPoolHandle pool) const
void Evaluator::mod_switch_drop_to_next(
const Ciphertext &encrypted, Ciphertext &destination, MemoryPoolHandle pool) const
{
// Assuming at this point encrypted is already validated.
auto context_data_ptr = context_.get_context_data(encrypted.parms_id());
Expand Down Expand Up @@ -1020,7 +1089,8 @@ namespace seal
plain.parms_id() = next_context_data.parms_id();
}

void Evaluator::mod_switch_to_next(const Ciphertext &encrypted, Ciphertext &destination, MemoryPoolHandle pool) const
void Evaluator::mod_switch_to_next(
const Ciphertext &encrypted, Ciphertext &destination, MemoryPoolHandle pool) const
{
// Verify parameters.
if (!is_metadata_valid_for(encrypted, context_) || !is_buffer_valid(encrypted))
Expand Down Expand Up @@ -1627,7 +1697,7 @@ namespace seal
encrypted.scale() = new_scale;
}

void Evaluator::multiply_plain_ntt(Ciphertext &encrypted_ntt, const Plaintext &plain_ntt) const
void Evaluator::multiply_plain_ntt(Ciphertext &encrypted_ntt, const Plaintext &plain_ntt) const
{
// Verify parameters.
if (!plain_ntt.is_ntt_form())
Expand Down Expand Up @@ -1668,7 +1738,7 @@ namespace seal
encrypted_ntt.scale() = new_scale;
}

void Evaluator::transform_to_ntt_inplace(Plaintext &plain, parms_id_type parms_id, MemoryPoolHandle pool) const
void Evaluator::transform_to_ntt_inplace(Plaintext &plain, parms_id_type parms_id, MemoryPoolHandle pool) const
{
// Verify parameters.
if (!is_valid_for(plain, context_))
Expand Down Expand Up @@ -1761,7 +1831,7 @@ namespace seal
plain.parms_id() = parms_id;
}

void Evaluator::transform_to_ntt_inplace(Ciphertext &encrypted) const
void Evaluator::transform_to_ntt_inplace(Ciphertext &encrypted) const
{
// Verify parameters.
if (!is_metadata_valid_for(encrypted, context_) || !is_buffer_valid(encrypted))
Expand Down Expand Up @@ -1809,7 +1879,7 @@ namespace seal
#endif
}

void Evaluator::transform_from_ntt_inplace(Ciphertext &encrypted_ntt) const
void Evaluator::transform_from_ntt_inplace(Ciphertext &encrypted_ntt) const
{
// Verify parameters.
if (!is_metadata_valid_for(encrypted_ntt, context_) || !is_buffer_valid(encrypted_ntt))
Expand Down Expand Up @@ -1857,7 +1927,7 @@ namespace seal
}

void Evaluator::apply_galois_inplace(
Ciphertext &encrypted, uint32_t galois_elt, const GaloisKeys &galois_keys, MemoryPoolHandle pool) const
Ciphertext &encrypted, uint32_t galois_elt, const GaloisKeys &galois_keys, MemoryPoolHandle pool) const
{
// Verify parameters.
if (!is_metadata_valid_for(encrypted, context_) || !is_buffer_valid(encrypted))
Expand Down Expand Up @@ -1961,7 +2031,7 @@ namespace seal
}

void Evaluator::rotate_internal(
Ciphertext &encrypted, int steps, const GaloisKeys &galois_keys, MemoryPoolHandle pool) const
Ciphertext &encrypted, int steps, const GaloisKeys &galois_keys, MemoryPoolHandle pool) const
{
auto context_data_ptr = context_.get_context_data(encrypted.parms_id());
if (!context_data_ptr)
Expand Down Expand Up @@ -2019,7 +2089,7 @@ namespace seal

void Evaluator::switch_key_inplace(
Ciphertext &encrypted, ConstRNSIter target_iter, const KSwitchKeys &kswitch_keys, size_t kswitch_keys_index,
MemoryPoolHandle pool) const
MemoryPoolHandle pool) const
{
auto parms_id = encrypted.parms_id();
auto &context_data = *context_.get_context_data(parms_id);
Expand Down