Skip to content

Commit

Permalink
update: task 1
Browse files Browse the repository at this point in the history
  • Loading branch information
jqxue1999 committed Feb 6, 2023
1 parent cbd36da commit 19e3c98
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 88 deletions.
138 changes: 64 additions & 74 deletions native/dev/BGV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,55 +10,47 @@
using namespace std;
using namespace seal;

class BGV
{
class BGV {
private:
int rol_a, col_a;
void print_matrix(vector<vector<uint64_t>> matrix, bool T = false)
{
if (T)
{
for (int i = 0; i < matrix[0].size(); i++)
{

void print_matrix(vector<vector<uint64_t>> matrix, bool T = false) {
if (T) {
for (int i = 0; i < matrix[0].size(); i++) {
for (int j = 0; j < matrix.size(); j++)
cout << matrix[j][i] << " ";
cout << endl;
}
cout << "==========================" << endl;
}
else
{
for (int i = 0; i < matrix.size(); i++)
{
} else {
for (int i = 0; i < matrix.size(); i++) {
for (int j = 0; j < matrix[i].size(); j++)
cout << matrix[i][j] << " ";
cout << endl;
}
cout << "==========================" << endl;
}
}
void print_vector(vector<uint64_t> vec)
{

void print_vector(vector<uint64_t> vec) {
for (int i = 0; i < vec.size(); i++)
cout << vec[i] << " ";
cout << endl << "==========================" << endl;
}
bool matrix_equal(vector<vector<uint64_t>> a, vector<vector<uint64_t>> b)
{

bool matrix_equal(vector<vector<uint64_t>> a, vector<vector<uint64_t>> b) {
if (a == b)
return true;
else
return false;
}
vector<vector<uint64_t>> matrix_mul(vector<vector<uint64_t>> a, vector<vector<uint64_t>> b)
{

vector<vector<uint64_t>> matrix_mul(vector<vector<uint64_t>> a, vector<vector<uint64_t>> b) {
int rol = a.size();
int col = b.size();
vector<vector<uint64_t>> res(rol, vector<uint64_t>(col));
for (int i = 0; i < rol; i++)
{
for (int j = 0; j < col; j++)
{
for (int i = 0; i < rol; i++) {
for (int j = 0; j < col; j++) {
int sum = 0;
for (int k = 0; k < a[0].size(); k++)
sum += a[i][k] * b[j][k];
Expand All @@ -67,23 +59,22 @@ class BGV
}
return res;
}
vector<vector<uint64_t>> init_message(int N, int M, int seed)
{

vector<vector<uint64_t>> init_message(int N, int M, int seed) {
srand(seed);
vector<vector<uint64_t>> res(N, vector<uint64_t>(M));
for (int i = 0; i < N; i++)
{
for (int i = 0; i < N; i++) {
for (int j = 0; j < M; j++)
res[i][j] = rand() % 10 + 1;
}
return res;
}
Ciphertext matrix_to_ciphertext(vector<vector<uint64_t>> message, BatchEncoder &batch_encoder, Encryptor &encryptor)
{

Ciphertext
matrix_to_ciphertext(vector<vector<uint64_t>> message, BatchEncoder &batch_encoder, Encryptor &encryptor) {
// flatten 2d message to 1d
vector<uint64_t> message_1d;
for (int i = 0; i < message.size(); i++)
{
for (int i = 0; i < message.size(); i++) {
for (int j = 0; j < message[i].size(); j++)
message_1d.push_back(message[i][j]);
}
Expand All @@ -98,14 +89,13 @@ class BGV

return ciphertext_res;
}

Ciphertext server_compute_col(
Ciphertext ciphertext_a, vector<uint64_t> message_b_i, int rol_a, BatchEncoder &batch_encoder,
Evaluator &evaluator, GaloisKeys galois_keys)
{
Ciphertext ciphertext_a, vector<uint64_t> message_b_i, int rol_a, BatchEncoder &batch_encoder,
Evaluator &evaluator, GaloisKeys galois_keys) {
Ciphertext ciphertext_v = ciphertext_a;
vector<uint64_t> message_b_i_extend(rol_a * message_b_i.size(), 0);
for (int i = 0; i < rol_a; i++)
{
for (int i = 0; i < rol_a; i++) {
for (int j = 0; j < message_b_i.size(); j++)
message_b_i_extend[i * message_b_i.size() + j] = message_b_i[j];
}
Expand All @@ -118,20 +108,16 @@ class BGV
Ciphertext one;
Ciphertext ciphertext_v1 = ciphertext_v;
bool flag = false;
while (length > 1)
{
while (length > 1) {
Ciphertext ciphertext_v2 = ciphertext_v1;
evaluator.rotate_rows_inplace(ciphertext_v2, length / 2, galois_keys);
evaluator.add(ciphertext_v2, ciphertext_v1, ciphertext_v1);
if (length % 2)
{
if (length % 2) {
evaluator.rotate_rows_inplace(ciphertext_v2, length / 2, galois_keys);
if (!flag)
{
if (!flag) {
one = ciphertext_v2;
flag = true;
}
else
} else
evaluator.add(one, ciphertext_v2, one);
}
length /= 2;
Expand All @@ -140,22 +126,21 @@ class BGV
evaluator.add(one, ciphertext_v1, ciphertext_v1);
return ciphertext_v1;
}

vector<Ciphertext> server_compute(
Ciphertext ciphertext_a, vector<vector<uint64_t>> message_b, int rol_a, BatchEncoder &batch_encoder,
Evaluator &evaluator, GaloisKeys galois_keys)
{
Ciphertext ciphertext_a, vector<vector<uint64_t>> message_b, int rol_a, BatchEncoder &batch_encoder,
Evaluator &evaluator, GaloisKeys galois_keys) {
vector<Ciphertext> cipher_matrix(message_b.size());
for (int i = 0; i < message_b.size(); i++)
cipher_matrix[i] =
server_compute_col(ciphertext_a, message_b[i], rol_a, batch_encoder, evaluator, galois_keys);
server_compute_col(ciphertext_a, message_b[i], rol_a, batch_encoder, evaluator, galois_keys);
return cipher_matrix;
}

vector<vector<uint64_t>> client_decrypt(
vector<Ciphertext> cipher_matrix, Decryptor &decryptor, BatchEncoder &batch_encoder, int rol, int col)
{
vector<Ciphertext> cipher_matrix, Decryptor &decryptor, BatchEncoder &batch_encoder, int rol, int col) {
vector<vector<uint64_t>> message_result(rol, vector<uint64_t>(col));
for (int i = 0; i < cipher_matrix.size(); i++)
{
for (int i = 0; i < cipher_matrix.size(); i++) {
Plaintext plain_col_i;
decryptor.decrypt(cipher_matrix[i], plain_col_i);
vector<uint64_t> message_col_i;
Expand All @@ -167,13 +152,12 @@ class BGV
}

public:
BGV(int rol_a, int col_a)
{
BGV(int rol_a, int col_a) {
this->rol_a = rol_a;
this->col_a = col_a;
}
void experiment(int flops, int log_epoch, int validation)
{

void experiment(int flops, int log_epoch, int validation) {
// log
string log_name = "bgv/" + to_string(rol_a) + 'x' + to_string(col_a) + ".log";
freopen(log_name.c_str(), "w", stdout);
Expand Down Expand Up @@ -203,55 +187,61 @@ class BGV
BatchEncoder batch_encoder(context);

// experiments
double client_time = 0;
double server_time = 0;
clock_t client_start, client_end, server_start, server_end;

for (int i = 0; i < flops; i++)
{
double client_enc_time = 0;
double client_dec_time = 0;
double cache = 0.0;
clock_t client_enc_start, client_enc_end;
clock_t client_dec_start, client_dec_end;
clock_t server_start, server_end;

for (int i = 0; i < flops; i++) {
// init message
srand(unsigned(time(nullptr)));
vector<vector<uint64_t>> message_a = init_message(rol_a, col_a, rand());
vector<vector<uint64_t>> message_b = init_message(col_a, col_a, rand());

// client a: convert message a to ciphertext
client_start = clock();
client_enc_start = clock();
Ciphertext ciphertext_a = matrix_to_ciphertext(message_a, batch_encoder, encryptor);
client_enc_end = clock();
client_enc_time += double(client_enc_end - client_enc_start) / CLOCKS_PER_SEC;
cache = sizeof(ciphertext_a) / (1024.0 * 1024.0);

// server: compute
server_start = clock();
vector<Ciphertext> cipher_res =
server_compute(ciphertext_a, message_b, rol_a, batch_encoder, evaluator, galois_keys);
server_compute(ciphertext_a, message_b, rol_a, batch_encoder, evaluator, galois_keys);
server_end = clock();
server_time += double(server_end - server_start) / CLOCKS_PER_SEC;

// client a: decrypt
cache += sizeof(cipher_res) / (1024.0 * 1024.0);
client_dec_start = clock();
vector<vector<uint64_t>> message_res = client_decrypt(cipher_res, decryptor, batch_encoder, rol_a, col_a);
client_end = clock();
client_time += double(client_end - client_start) / CLOCKS_PER_SEC;
client_dec_end = clock();
client_dec_time += double(client_dec_end - client_dec_start) / CLOCKS_PER_SEC;

if ((i + 1) % log_epoch == 0)
{
if ((i + 1) % log_epoch == 0) {
if (validation == 0)
cout << "[" << i + 1 << "|" << flops << "] [" << server_time << "|" << client_time << "]" << endl;
else
{
cout << "[" << i + 1 << "|" << flops << "] [" << client_enc_time << "|" << client_dec_time << "|"
<< server_time << "|" << cache << "]" << endl;
else {
vector<vector<uint64_t>> ans = matrix_mul(message_a, message_b);
string val;
if (matrix_equal(ans, message_res))
val = "correct";
else
val = "wrong";
cout << "[" << i + 1 << "|" << flops << "] [" << server_time << "|" << client_time << "]: " << val
<< endl;
cout << "[" << i + 1 << "|" << flops << "] [" << client_enc_time << "|" << client_dec_time << "|"
<< server_time << "|" << cache << "]: " << val << endl;
}
}
}
}
};

void experiment_bgv(int rol_a, int col_a, int flops, int log_epoch, int validation)
{
void experiment_bgv(int rol_a, int col_a, int flops, int log_epoch, int validation) {
BGV bgv(rol_a, col_a);
bgv.experiment(flops, log_epoch, validation);
}
29 changes: 15 additions & 14 deletions native/dev/examples.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,20 @@ int main() {
int log_epoch;
cin >> log_epoch;

for (int N = 1; N <= 64; N++)
experiment_ckks(N, N, flops, log_epoch, 0);
for (int N = 2; N <= 64; N++)
experiment_ckks(N, N - 1, flops, log_epoch, 0);

for (int N = 1; N <= 64; N++)
experiment_bgv(N, N, flops, log_epoch, validation);
for (int N = 2; N <= 64; N++)
experiment_bgv(N, N - 1, flops, log_epoch, validation);

for (int N = 1; N <= 64; N++)
experiment_bfv(N, N, flops, log_epoch, validation);
for (int N = 2; N <= 64; N++)
experiment_bfv(N, N - 1, flops, log_epoch, validation);
experiment_bgv(64, 64, flops, log_epoch, validation);
// for (int N = 1; N <= 64; N++)
// experiment_ckks(N, N, flops, log_epoch, 0);
// for (int N = 2; N <= 64; N++)
// experiment_ckks(N, N - 1, flops, log_epoch, 0);
//
// for (int N = 1; N <= 64; N++)
// experiment_bgv(N, N, flops, log_epoch, validation);
// for (int N = 2; N <= 64; N++)
// experiment_bgv(N, N - 1, flops, log_epoch, validation);
//
// for (int N = 1; N <= 64; N++)
// experiment_bfv(N, N, flops, log_epoch, validation);
// for (int N = 2; N <= 64; N++)
// experiment_bfv(N, N - 1, flops, log_epoch, validation);
return 0;
}

0 comments on commit 19e3c98

Please sign in to comment.