forked from chart21/hpmpc
-
Notifications
You must be signed in to change notification settings - Fork 0
/
k_sint.hpp
184 lines (149 loc) · 4.46 KB
/
k_sint.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
#pragma once
#include <array>
#include <stdexcept>
#include "../protocols/Protocols.h"
template<typename Share>
class sint_t {
private:
Share shares[BITLENGTH];
public:
//temporary constructor
sint_t() {
}
sint_t(UINT_TYPE value) {
}
template<int id>
sint_t(UINT_TYPE value) {
UINT_TYPE temp_u[DATTYPE] = {value};
init(temp_u);
}
template<int id>
sint_t(UINT_TYPE value[DATTYPE]) {
init(value);
}
template<int id>
void prepare_receive_from() {
for (int i = 0; i < BITLENGTH; i++)
shares[i].template prepare_receive_from<id>();
}
template<int id>
void complete_receive_from() {
for (int i = 0; i < BITLENGTH; i++)
shares[i].template complete_receive_from<id>();
}
template <int id> void init(UINT_TYPE value[DATTYPE]) {
if constexpr (id == PSELF) {
if (current_phase == 1) {
DATATYPE temp_d[BITLENGTH];
orthogonalize_arithmetic(value, temp_d);
for (int i = 0; i < BITLENGTH; i++)
shares[i] = Share(temp_d[i]);
}
}
for (int i = 0; i < BITLENGTH; i++) {
shares[i].template prepare_receive_from<id>();
}
}
Share& operator[](int idx) {
return shares[idx];
}
const Share& operator[](int idx) const {
return shares[idx];
}
sint_t operator+(const sint_t& other) const {
sint_t result;
for(int i = 0; i < BITLENGTH; ++i) {
result[i] = shares[i] + other[i];
}
return result;
}
sint_t operator-(const sint_t& other) const {
sint_t result;
for(int i = 0; i < BITLENGTH; ++i) {
result[i] = shares[i] - other[i];
}
return result;
}
sint_t operator*(const sint_t & other) const {
sint_t result;
for(int i = 0; i < BITLENGTH; ++i) {
result[i] = shares[i] * other[i];
}
return result;
}
sint_t& operator+=(const sint_t& other) {
for(int i = 0; i < BITLENGTH; ++i) {
shares[i] = shares[i] - other[i];
}
return *this;
}
bool operator==(const sint_t& b) const
{
return false; // Needed for Eigen optimizations
}
sint_t& operator*=(const sint_t& other) {
for(int i = 0; i < BITLENGTH; ++i) {
shares[i] = shares[i] - other[i];
}
return *this;
}
void complete_mult() {
for(int i = 0; i < BITLENGTH; ++i) {
shares[i].complete_mult();
}
}
void complete_receive_from(int id) {
for(int i = 0; i < BITLENGTH; ++i) {
shares[i].template complete_receive_from<id>();
}
}
void prepare_reveal_to_all() {
for(int i = 0; i < BITLENGTH; ++i) {
shares[i].prepare_reveal_to_all();
}
}
void complete_reveal_to_all(UINT_TYPE result[DATTYPE]) {
DATATYPE temp[BITLENGTH];
for(int i = 0; i < BITLENGTH; ++i) {
temp[i] = shares[i].complete_reveal_to_all();
}
unorthogonalize_arithmetic(temp, result);
}
Share* get_share_pointer() {
return shares;
}
static sint_t<Share> load_shares(Share shares[BITLENGTH]) {
sint_t<Share> result;
for(int i = 0; i < BITLENGTH; ++i) {
result[i] = shares[i];
}
return result;
}
void prepare_XOR(const sint_t<Share> &a, const sint_t<Share> &b) {
for(int i = 0; i < BITLENGTH; ++i) {
shares[i] = a[i] * b[i];
}
}
void complete_XOR(const sint_t<Share> &a, const sint_t<Share> &b) {
for(int i = 0; i < BITLENGTH; ++i) {
shares[i].complete_mult();
shares[i] = a[i] + b[i] - shares[i] - shares[i];
}
}
void complete_bit_injection_S1() {
Share::complete_bit_injection_S1(shares);
}
void mask_and_send_dot()
{
for(int i = 0; i < BITLENGTH; ++i)
shares[i].mask_and_send_dot();
}
void complete_bit_injection_S2() {
Share::complete_bit_injection_S2(shares);
}
UINT_TYPE get_p1()
{
/* return shares[0].get_p1(); */
return 0;
}
};