From 9f8901bf89d8eb8781bac4002e47b7467ef7a1c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20=27birdy=27=20Danjou?= Date: Tue, 27 Feb 2024 03:56:01 +0100 Subject: [PATCH 1/9] feat: add FheBool and FheUint4 --- fhevm/contracts_test.go | 286 +++++++- fhevm/fhelib_run.go | 13 + fhevm/params.go | 25 + fhevm/tfhe_ciphertext.go | 1258 ++++++++++++++++++++++++++-------- fhevm/tfhe_key_management.go | 4 + fhevm/tfhe_test.go | 363 +++++++++- fhevm/tfhe_wrappers.c | 708 +++++++++++++++++++ fhevm/tfhe_wrappers.go | 12 + fhevm/tfhe_wrappers.h | 142 +++- kms/kms.pb.go | 27 +- tfhe-rs | 2 +- 11 files changed, 2524 insertions(+), 316 deletions(-) diff --git a/fhevm/contracts_test.go b/fhevm/contracts_test.go index 889b0fd..7cef262 100644 --- a/fhevm/contracts_test.go +++ b/fhevm/contracts_test.go @@ -171,8 +171,12 @@ func prepareInputForVerifyCiphertext(input []byte) []byte { func VerifyCiphertext(t *testing.T, fheUintType FheUintType) { var value uint64 switch fheUintType { + case FheBool: + value = 1 + case FheUint4: + value = 4 case FheUint8: - value = 2 + value = 234 case FheUint16: value = 4283 case FheUint32: @@ -207,6 +211,8 @@ func VerifyCiphertext(t *testing.T, fheUintType FheUintType) { func VerifyCiphertextBadType(t *testing.T, actualType FheUintType, metadataType FheUintType) { var value uint64 switch actualType { + case FheUint4: + value = 2 case FheUint8: value = 2 case FheUint16: @@ -235,6 +241,8 @@ func VerifyCiphertextBadType(t *testing.T, actualType FheUintType, metadataType func TrivialEncrypt(t *testing.T, fheUintType FheUintType) { var value big.Int switch fheUintType { + case FheUint4: + value = *big.NewInt(2) case FheUint8: value = *big.NewInt(2) case FheUint16: @@ -268,6 +276,9 @@ func TrivialEncrypt(t *testing.T, fheUintType FheUintType) { func FheLibAdd(t *testing.T, fheUintType FheUintType, scalar bool) { var lhs, rhs uint64 switch fheUintType { + case FheUint4: + lhs = 2 + rhs = 1 case FheUint8: lhs = 2 rhs = 1 @@ -314,6 +325,9 @@ func FheLibAdd(t *testing.T, fheUintType FheUintType, scalar bool) { func FheLibSub(t *testing.T, fheUintType FheUintType, scalar bool) { var lhs, rhs uint64 switch fheUintType { + case FheUint4: + lhs = 2 + rhs = 1 case FheUint8: lhs = 2 rhs = 1 @@ -359,6 +373,9 @@ func FheLibSub(t *testing.T, fheUintType FheUintType, scalar bool) { func FheLibMul(t *testing.T, fheUintType FheUintType, scalar bool) { var lhs, rhs uint64 switch fheUintType { + case FheUint4: + lhs = 3 + rhs = 2 case FheUint8: lhs = 3 rhs = 2 @@ -404,6 +421,9 @@ func FheLibMul(t *testing.T, fheUintType FheUintType, scalar bool) { func FheLibLe(t *testing.T, fheUintType FheUintType, scalar bool) { var lhs, rhs uint64 switch fheUintType { + case FheUint4: + lhs = 2 + rhs = 1 case FheUint8: lhs = 2 rhs = 1 @@ -469,6 +489,9 @@ func FheLibLe(t *testing.T, fheUintType FheUintType, scalar bool) { func FheLibLt(t *testing.T, fheUintType FheUintType, scalar bool) { var lhs, rhs uint64 switch fheUintType { + case FheUint4: + lhs = 2 + rhs = 1 case FheUint8: lhs = 2 rhs = 1 @@ -535,6 +558,9 @@ func FheLibLt(t *testing.T, fheUintType FheUintType, scalar bool) { func FheLibEq(t *testing.T, fheUintType FheUintType, scalar bool) { var lhs, rhs uint64 switch fheUintType { + case FheUint4: + lhs = 2 + rhs = 1 case FheUint8: lhs = 2 rhs = 1 @@ -580,6 +606,9 @@ func FheLibEq(t *testing.T, fheUintType FheUintType, scalar bool) { func FheLibGe(t *testing.T, fheUintType FheUintType, scalar bool) { var lhs, rhs uint64 switch fheUintType { + case FheUint4: + lhs = 2 + rhs = 1 case FheUint8: lhs = 2 rhs = 1 @@ -643,6 +672,9 @@ func FheLibGe(t *testing.T, fheUintType FheUintType, scalar bool) { func FheLibGt(t *testing.T, fheUintType FheUintType, scalar bool) { var lhs, rhs uint64 switch fheUintType { + case FheUint4: + lhs = 2 + rhs = 1 case FheUint8: lhs = 2 rhs = 1 @@ -708,6 +740,9 @@ func FheLibGt(t *testing.T, fheUintType FheUintType, scalar bool) { func FheLibShl(t *testing.T, fheUintType FheUintType, scalar bool) { var lhs, rhs uint64 switch fheUintType { + case FheUint4: + lhs = 2 + rhs = 1 case FheUint8: lhs = 2 rhs = 1 @@ -753,6 +788,9 @@ func FheLibShl(t *testing.T, fheUintType FheUintType, scalar bool) { func FheLibShr(t *testing.T, fheUintType FheUintType, scalar bool) { var lhs, rhs uint64 switch fheUintType { + case FheUint4: + lhs = 2 + rhs = 1 case FheUint8: lhs = 2 rhs = 1 @@ -798,6 +836,9 @@ func FheLibShr(t *testing.T, fheUintType FheUintType, scalar bool) { func FheLibNe(t *testing.T, fheUintType FheUintType, scalar bool) { var lhs, rhs uint64 switch fheUintType { + case FheUint4: + lhs = 2 + rhs = 1 case FheUint8: lhs = 2 rhs = 1 @@ -843,6 +884,9 @@ func FheLibNe(t *testing.T, fheUintType FheUintType, scalar bool) { func FheLibMin(t *testing.T, fheUintType FheUintType, scalar bool) { var lhs, rhs uint64 switch fheUintType { + case FheUint4: + lhs = 2 + rhs = 1 case FheUint8: lhs = 2 rhs = 1 @@ -907,6 +951,9 @@ func FheLibMin(t *testing.T, fheUintType FheUintType, scalar bool) { func FheLibMax(t *testing.T, fheUintType FheUintType, scalar bool) { var lhs, rhs uint64 switch fheUintType { + case FheUint4: + lhs = 2 + rhs = 1 case FheUint8: lhs = 2 rhs = 1 @@ -971,6 +1018,9 @@ func FheLibMax(t *testing.T, fheUintType FheUintType, scalar bool) { func FheLibNeg(t *testing.T, fheUintType FheUintType) { var pt, expected uint64 switch fheUintType { + case FheUint4: + pt = 7 + expected = uint64(16 - uint8(pt)) case FheUint8: pt = 2 expected = uint64(-uint8(pt)) @@ -1011,6 +1061,9 @@ func FheLibNeg(t *testing.T, fheUintType FheUintType) { func FheLibNot(t *testing.T, fheUintType FheUintType) { var pt, expected uint64 switch fheUintType { + case FheUint4: + pt = 5 + expected = uint64(15 - uint8(pt)) case FheUint8: pt = 2 expected = uint64(^uint8(pt)) @@ -1051,6 +1104,9 @@ func FheLibNot(t *testing.T, fheUintType FheUintType) { func FheLibDiv(t *testing.T, fheUintType FheUintType, scalar bool) { var lhs, rhs uint64 switch fheUintType { + case FheUint4: + lhs = 4 + rhs = 2 case FheUint8: lhs = 4 rhs = 2 @@ -1103,6 +1159,9 @@ func FheLibDiv(t *testing.T, fheUintType FheUintType, scalar bool) { func FheLibRem(t *testing.T, fheUintType FheUintType, scalar bool) { var lhs, rhs uint64 switch fheUintType { + case FheUint4: + lhs = 7 + rhs = 3 case FheUint8: lhs = 7 rhs = 3 @@ -1154,6 +1213,12 @@ func FheLibRem(t *testing.T, fheUintType FheUintType, scalar bool) { func FheLibBitAnd(t *testing.T, fheUintType FheUintType, scalar bool) { var lhs, rhs uint64 switch fheUintType { + case FheBool: + lhs = 1 + rhs = 0 + case FheUint4: + lhs = 2 + rhs = 1 case FheUint8: lhs = 2 rhs = 1 @@ -1205,6 +1270,12 @@ func FheLibBitAnd(t *testing.T, fheUintType FheUintType, scalar bool) { func FheLibBitOr(t *testing.T, fheUintType FheUintType, scalar bool) { var lhs, rhs uint64 switch fheUintType { + case FheBool: + lhs = 1 + rhs = 0 + case FheUint4: + lhs = 2 + rhs = 1 case FheUint8: lhs = 2 rhs = 1 @@ -1256,6 +1327,12 @@ func FheLibBitOr(t *testing.T, fheUintType FheUintType, scalar bool) { func FheLibBitXor(t *testing.T, fheUintType FheUintType, scalar bool) { var lhs, rhs uint64 switch fheUintType { + case FheBool: + lhs = 1 + rhs = 0 + case FheUint4: + lhs = 2 + rhs = 1 case FheUint8: lhs = 2 rhs = 1 @@ -1335,6 +1412,10 @@ func FheLibRand(t *testing.T, fheUintType FheUintType) { t.Fatalf("decrypted value is not 64 bit") } switch fheUintType { + case FheUint4: + if decrypted.Uint64() > 0xF { + t.Fatalf("random value is bigger than 0xFF (4 bits)") + } case FheUint8: if decrypted.Uint64() > 0xFF { t.Fatalf("random value is bigger than 0xFF (8 bits)") @@ -1394,6 +1475,9 @@ func FheLibRandBounded(t *testing.T, fheUintType FheUintType, upperBound64 uint6 func FheLibIfThenElse(t *testing.T, fheUintType FheUintType, condition uint64) { var second, third uint64 switch fheUintType { + case FheUint4: + second = 2 + third = 1 case FheUint8: second = 2 third = 1 @@ -1435,6 +1519,8 @@ func FheLibIfThenElse(t *testing.T, fheUintType FheUintType, condition uint64) { func LibTrivialEncrypt(t *testing.T, fheUintType FheUintType) { var value big.Int switch fheUintType { + case FheUint4: + value = *big.NewInt(2) case FheUint8: value = *big.NewInt(2) case FheUint16: @@ -1474,6 +1560,8 @@ func LibTrivialEncrypt(t *testing.T, fheUintType FheUintType) { func LibDecrypt(t *testing.T, fheUintType FheUintType) { var value uint64 switch fheUintType { + case FheUint4: + value = 2 case FheUint8: value = 2 case FheUint16: @@ -1584,6 +1672,9 @@ func TestLibCast(t *testing.T) { func FheAdd(t *testing.T, fheUintType FheUintType, scalar bool) { var lhs, rhs uint64 switch fheUintType { + case FheUint4: + lhs = 2 + rhs = 1 case FheUint8: lhs = 2 rhs = 1 @@ -1628,6 +1719,9 @@ func FheAdd(t *testing.T, fheUintType FheUintType, scalar bool) { func FheSub(t *testing.T, fheUintType FheUintType, scalar bool) { var lhs, rhs uint64 switch fheUintType { + case FheUint4: + lhs = 2 + rhs = 1 case FheUint8: lhs = 2 rhs = 1 @@ -1672,6 +1766,9 @@ func FheSub(t *testing.T, fheUintType FheUintType, scalar bool) { func FheMul(t *testing.T, fheUintType FheUintType, scalar bool) { var lhs, rhs uint64 switch fheUintType { + case FheUint4: + lhs = 2 + rhs = 3 case FheUint8: lhs = 2 rhs = 3 @@ -1716,6 +1813,9 @@ func FheMul(t *testing.T, fheUintType FheUintType, scalar bool) { func FheDiv(t *testing.T, fheUintType FheUintType, scalar bool) { var lhs, rhs uint64 switch fheUintType { + case FheUint4: + lhs = 6 + rhs = 7 case FheUint8: lhs = 6 rhs = 7 @@ -1766,6 +1866,9 @@ func FheDiv(t *testing.T, fheUintType FheUintType, scalar bool) { func FheRem(t *testing.T, fheUintType FheUintType, scalar bool) { var lhs, rhs uint64 switch fheUintType { + case FheUint4: + lhs = 9 + rhs = 5 case FheUint8: lhs = 9 rhs = 5 @@ -1816,6 +1919,9 @@ func FheRem(t *testing.T, fheUintType FheUintType, scalar bool) { func FheBitAnd(t *testing.T, fheUintType FheUintType, scalar bool) { var lhs, rhs uint64 switch fheUintType { + case FheUint4: + lhs = 2 + rhs = 1 case FheUint8: lhs = 2 rhs = 1 @@ -1866,6 +1972,9 @@ func FheBitAnd(t *testing.T, fheUintType FheUintType, scalar bool) { func FheBitOr(t *testing.T, fheUintType FheUintType, scalar bool) { var lhs, rhs uint64 switch fheUintType { + case FheUint4: + lhs = 2 + rhs = 1 case FheUint8: lhs = 2 rhs = 1 @@ -1916,6 +2025,9 @@ func FheBitOr(t *testing.T, fheUintType FheUintType, scalar bool) { func FheBitXor(t *testing.T, fheUintType FheUintType, scalar bool) { var lhs, rhs uint64 switch fheUintType { + case FheUint4: + lhs = 2 + rhs = 1 case FheUint8: lhs = 2 rhs = 1 @@ -1966,6 +2078,9 @@ func FheBitXor(t *testing.T, fheUintType FheUintType, scalar bool) { func FheShl(t *testing.T, fheUintType FheUintType, scalar bool) { var lhs, rhs uint64 switch fheUintType { + case FheUint4: + lhs = 2 + rhs = 1 case FheUint8: lhs = 2 rhs = 1 @@ -2010,6 +2125,9 @@ func FheShl(t *testing.T, fheUintType FheUintType, scalar bool) { func FheShr(t *testing.T, fheUintType FheUintType, scalar bool) { var lhs, rhs uint64 switch fheUintType { + case FheUint4: + lhs = 2 + rhs = 1 case FheUint8: lhs = 2 rhs = 1 @@ -2054,6 +2172,9 @@ func FheShr(t *testing.T, fheUintType FheUintType, scalar bool) { func FheEq(t *testing.T, fheUintType FheUintType, scalar bool) { var lhs, rhs uint64 switch fheUintType { + case FheUint4: + lhs = 2 + rhs = 1 case FheUint8: lhs = 2 rhs = 1 @@ -2098,6 +2219,9 @@ func FheEq(t *testing.T, fheUintType FheUintType, scalar bool) { func FheNe(t *testing.T, fheUintType FheUintType, scalar bool) { var lhs, rhs uint64 switch fheUintType { + case FheUint4: + lhs = 2 + rhs = 1 case FheUint8: lhs = 2 rhs = 1 @@ -2142,6 +2266,9 @@ func FheNe(t *testing.T, fheUintType FheUintType, scalar bool) { func FheGe(t *testing.T, fheUintType FheUintType, scalar bool) { var lhs, rhs uint64 switch fheUintType { + case FheUint4: + lhs = 2 + rhs = 1 case FheUint8: lhs = 2 rhs = 1 @@ -2204,6 +2331,9 @@ func FheGe(t *testing.T, fheUintType FheUintType, scalar bool) { func FheGt(t *testing.T, fheUintType FheUintType, scalar bool) { var lhs, rhs uint64 switch fheUintType { + case FheUint4: + lhs = 2 + rhs = 1 case FheUint8: lhs = 2 rhs = 1 @@ -2268,6 +2398,9 @@ func FheGt(t *testing.T, fheUintType FheUintType, scalar bool) { func FheLe(t *testing.T, fheUintType FheUintType, scalar bool) { var lhs, rhs uint64 switch fheUintType { + case FheUint4: + lhs = 2 + rhs = 1 case FheUint8: lhs = 2 rhs = 1 @@ -2332,6 +2465,9 @@ func FheLe(t *testing.T, fheUintType FheUintType, scalar bool) { func FheLt(t *testing.T, fheUintType FheUintType, scalar bool) { var lhs, rhs uint64 switch fheUintType { + case FheUint4: + lhs = 2 + rhs = 1 case FheUint8: lhs = 2 rhs = 1 @@ -2397,6 +2533,9 @@ func FheLt(t *testing.T, fheUintType FheUintType, scalar bool) { func FheMin(t *testing.T, fheUintType FheUintType, scalar bool) { var lhs, rhs uint64 switch fheUintType { + case FheUint4: + lhs = 2 + rhs = 1 case FheUint8: lhs = 2 rhs = 1 @@ -2460,6 +2599,9 @@ func FheMin(t *testing.T, fheUintType FheUintType, scalar bool) { func FheMax(t *testing.T, fheUintType FheUintType, scalar bool) { var lhs, rhs uint64 switch fheUintType { + case FheUint4: + lhs = 2 + rhs = 1 case FheUint8: lhs = 2 rhs = 1 @@ -2523,6 +2665,9 @@ func FheMax(t *testing.T, fheUintType FheUintType, scalar bool) { func FheNeg(t *testing.T, fheUintType FheUintType, scalar bool) { var pt, expected uint64 switch fheUintType { + case FheUint4: + pt = 2 + expected = uint64(-uint8(pt)) case FheUint8: pt = 2 expected = uint64(-uint8(pt)) @@ -2563,6 +2708,9 @@ func FheNeg(t *testing.T, fheUintType FheUintType, scalar bool) { func FheNot(t *testing.T, fheUintType FheUintType, scalar bool) { var pt, expected uint64 switch fheUintType { + case FheUint4: + pt = 2 + expected = uint64(^uint8(pt)) case FheUint8: pt = 2 expected = uint64(^uint8(pt)) @@ -2603,6 +2751,9 @@ func FheNot(t *testing.T, fheUintType FheUintType, scalar bool) { func FheIfThenElse(t *testing.T, fheUintType FheUintType, condition uint64) { var lhs, rhs uint64 switch fheUintType { + case FheUint4: + lhs = 2 + rhs = 1 case FheUint8: lhs = 2 rhs = 1 @@ -2643,6 +2794,8 @@ func FheIfThenElse(t *testing.T, fheUintType FheUintType, condition uint64) { func Decrypt(t *testing.T, fheUintType FheUintType) { var value uint64 switch fheUintType { + case FheUint4: + value = 2 case FheUint8: value = 2 case FheUint16: @@ -2759,6 +2912,10 @@ func TestVerifyCiphertextInvalidSize(t *testing.T) { } } +func TestVerifyCiphertext4(t *testing.T) { + VerifyCiphertext(t, FheUint4) +} + func TestVerifyCiphertext8(t *testing.T) { VerifyCiphertext(t, FheUint8) } @@ -2775,6 +2932,10 @@ func TestVerifyCiphertext64(t *testing.T) { VerifyCiphertext(t, FheUint64) } +func TestTrivialEncrypt4(t *testing.T) { + TrivialEncrypt(t, FheUint4) +} + func TestTrivialEncrypt8(t *testing.T) { TrivialEncrypt(t, FheUint8) } @@ -2791,24 +2952,39 @@ func TestTrivialEncrypt64(t *testing.T) { TrivialEncrypt(t, FheUint64) } +func TestVerifyCiphertext4BadType(t *testing.T) { + VerifyCiphertextBadType(t, FheUint4, FheUint8) + VerifyCiphertextBadType(t, FheUint4, FheUint16) + VerifyCiphertextBadType(t, FheUint4, FheUint32) + VerifyCiphertextBadType(t, FheUint4, FheUint64) +} + func TestVerifyCiphertext8BadType(t *testing.T) { + VerifyCiphertextBadType(t, FheUint8, FheUint4) VerifyCiphertextBadType(t, FheUint8, FheUint16) VerifyCiphertextBadType(t, FheUint8, FheUint32) + VerifyCiphertextBadType(t, FheUint8, FheUint64) } func TestVerifyCiphertext16BadType(t *testing.T) { + VerifyCiphertextBadType(t, FheUint16, FheUint4) VerifyCiphertextBadType(t, FheUint16, FheUint8) VerifyCiphertextBadType(t, FheUint16, FheUint32) + VerifyCiphertextBadType(t, FheUint16, FheUint64) } func TestVerifyCiphertext32BadType(t *testing.T) { + VerifyCiphertextBadType(t, FheUint32, FheUint4) VerifyCiphertextBadType(t, FheUint32, FheUint8) VerifyCiphertextBadType(t, FheUint32, FheUint16) + VerifyCiphertextBadType(t, FheUint32, FheUint64) } func TestVerifyCiphertext64BadType(t *testing.T) { + VerifyCiphertextBadType(t, FheUint64, FheUint4) VerifyCiphertextBadType(t, FheUint64, FheUint8) VerifyCiphertextBadType(t, FheUint64, FheUint16) + VerifyCiphertextBadType(t, FheUint64, FheUint32) } func TestVerifyCiphertextBadCiphertext(t *testing.T) { @@ -2827,6 +3003,102 @@ func TestVerifyCiphertextBadCiphertext(t *testing.T) { } } +func TestFheLibBitAndBool(t *testing.T) { + FheLibBitAnd(t, FheBool, false) +} + +func TestFheLibBitOrBool(t *testing.T) { + FheLibBitOr(t, FheBool, false) +} + +func TestFheLibBitXorBool(t *testing.T) { + FheLibBitXor(t, FheBool, false) +} + +func TestFheLibAdd4(t *testing.T) { + FheLibAdd(t, FheUint4, false) +} + +func TestFheLibSub4(t *testing.T) { + FheLibSub(t, FheUint4, false) +} + +func TestFheLibMul4(t *testing.T) { + FheLibMul(t, FheUint4, false) +} + +func TestFheLibLe4(t *testing.T) { + FheLibLe(t, FheUint4, false) +} + +func TestFheLibLt4(t *testing.T) { + FheLibLt(t, FheUint4, false) +} + +func TestFheLibEq4(t *testing.T) { + FheLibEq(t, FheUint4, false) +} + +func TestFheLibGe4(t *testing.T) { + FheLibGe(t, FheUint4, false) +} + +func TestFheLibGt4(t *testing.T) { + FheLibGt(t, FheUint4, false) +} + +func TestFheLibShl4(t *testing.T) { + FheLibShl(t, FheUint4, false) +} + +func TestFheLibShr4(t *testing.T) { + FheLibShr(t, FheUint4, false) +} + +func TestFheLibNe4(t *testing.T) { + FheLibNe(t, FheUint4, false) +} + +func TestFheLibMin4(t *testing.T) { + FheLibMin(t, FheUint4, false) +} + +func TestFheLibMax4(t *testing.T) { + FheLibMax(t, FheUint4, false) +} + +func TestFheLibNeg4(t *testing.T) { + FheLibNeg(t, FheUint4) +} + +func TestFheLibNot4(t *testing.T) { + FheLibNot(t, FheUint4) +} + +func TestFheLibDiv4(t *testing.T) { + FheLibDiv(t, FheUint4, true) +} + +func TestFheLibRem4(t *testing.T) { + FheLibRem(t, FheUint4, true) +} + +func TestFheLibBitAnd4(t *testing.T) { + FheLibBitAnd(t, FheUint4, false) +} + +func TestFheLibBitOr4(t *testing.T) { + FheLibBitOr(t, FheUint4, false) +} + +func TestFheLibBitXor4(t *testing.T) { + FheLibBitXor(t, FheUint4, false) +} + +func TestFheLibRand4(t *testing.T) { + FheLibRand(t, FheUint4) +} + func TestFheLibAdd8(t *testing.T) { FheLibAdd(t, FheUint8, false) } @@ -3224,6 +3496,10 @@ func TestFheScalarBitXor64(t *testing.T) { FheBitXor(t, FheUint64, true) } +func TestFheShl4(t *testing.T) { + FheShl(t, FheUint4, false) +} + func TestFheShl8(t *testing.T) { FheShl(t, FheUint8, false) } @@ -3512,6 +3788,10 @@ func TestFheScalarMin64(t *testing.T) { FheMin(t, FheUint64, true) } +func TestFheMax4(t *testing.T) { + FheMax(t, FheUint4, false) +} + func TestFheMax8(t *testing.T) { FheMax(t, FheUint8, false) } @@ -3580,6 +3860,10 @@ func TestFheIfThenElse64(t *testing.T) { FheIfThenElse(t, FheUint64, 0) } +func TestFheScalarMax4(t *testing.T) { + FheMax(t, FheUint4, true) +} + func TestFheScalarMax8(t *testing.T) { FheMax(t, FheUint8, true) } diff --git a/fhevm/fhelib_run.go b/fhevm/fhelib_run.go index b885f91..99d2575 100644 --- a/fhevm/fhelib_run.go +++ b/fhevm/fhelib_run.go @@ -1257,6 +1257,11 @@ func generateRandom(environment EVMEnvironment, caller common.Address, resultTyp // Apply upperBound, if set. var randUint uint64 switch resultType { + case FheUint4: + randBytes := make([]byte, 1) + cipher.XORKeyStream(randBytes, randBytes) + randUint = uint64(randBytes[0]) + randUint = uint64(applyUpperBound(randUint, 4, upperBound)) case FheUint8: randBytes := make([]byte, 1) cipher.XORKeyStream(randBytes, randBytes) @@ -1462,6 +1467,10 @@ func reencryptRun(environment EVMEnvironment, caller common.Address, addr common var fheType kms.FheType switch ct.fheUintType() { + case FheBool: + fheType = kms.FheType_Bool + case FheUint4: + fheType = kms.FheType_Euint4 case FheUint8: fheType = kms.FheType_Euint8 case FheUint16: @@ -1632,6 +1641,10 @@ func decryptValue(environment EVMEnvironment, ct *TfheCiphertext) (uint64, error logger := environment.GetLogger() var fheType kms.FheType switch ct.Type() { + case FheBool: + fheType = kms.FheType_Bool + case FheUint4: + fheType = kms.FheType_Euint4 case FheUint8: fheType = kms.FheType_Euint8 case FheUint16: diff --git a/fhevm/params.go b/fhevm/params.go index 83990ed..47e9e70 100644 --- a/fhevm/params.go +++ b/fhevm/params.go @@ -67,84 +67,99 @@ type GasCosts struct { func DefaultGasCosts() GasCosts { return GasCosts{ FheAddSub: map[FheUintType]uint64{ + FheUint4: 84000 + AdjustFHEGas, FheUint8: 84000 + AdjustFHEGas, FheUint16: 123000 + AdjustFHEGas, FheUint32: 152000 + AdjustFHEGas, FheUint64: 178000 + AdjustFHEGas, }, FheDecrypt: map[FheUintType]uint64{ + FheUint4: 500000, FheUint8: 500000, FheUint16: 500000, FheUint32: 500000, FheUint64: 500000, }, FheBitwiseOp: map[FheUintType]uint64{ + FheBool: 23000 + AdjustFHEGas, + FheUint4: 24000 + AdjustFHEGas, FheUint8: 24000 + AdjustFHEGas, FheUint16: 24000 + AdjustFHEGas, FheUint32: 25000 + AdjustFHEGas, FheUint64: 28000 + AdjustFHEGas, }, FheMul: map[FheUintType]uint64{ + FheUint4: 187000 + AdjustFHEGas, FheUint8: 187000 + AdjustFHEGas, FheUint16: 252000 + AdjustFHEGas, FheUint32: 349000 + AdjustFHEGas, FheUint64: 631000 + AdjustFHEGas, }, FheScalarMul: map[FheUintType]uint64{ + FheUint4: 149000 + AdjustFHEGas, FheUint8: 149000 + AdjustFHEGas, FheUint16: 198000 + AdjustFHEGas, FheUint32: 254000 + AdjustFHEGas, FheUint64: 346000 + AdjustFHEGas, }, FheScalarDiv: map[FheUintType]uint64{ + FheUint4: 228000 + AdjustFHEGas, FheUint8: 228000 + AdjustFHEGas, FheUint16: 304000 + AdjustFHEGas, FheUint32: 388000 + AdjustFHEGas, FheUint64: 574000 + AdjustFHEGas, }, FheScalarRem: map[FheUintType]uint64{ + FheUint4: 450000 + AdjustFHEGas, FheUint8: 450000 + AdjustFHEGas, FheUint16: 612000 + AdjustFHEGas, FheUint32: 795000 + AdjustFHEGas, FheUint64: 1095000 + AdjustFHEGas, }, FheShift: map[FheUintType]uint64{ + FheUint4: 123000 + AdjustFHEGas, FheUint8: 123000 + AdjustFHEGas, FheUint16: 143000 + AdjustFHEGas, FheUint32: 173000 + AdjustFHEGas, FheUint64: 217000 + AdjustFHEGas, }, FheScalarShift: map[FheUintType]uint64{ + FheUint4: 25000 + AdjustFHEGas, FheUint8: 25000 + AdjustFHEGas, FheUint16: 25000 + AdjustFHEGas, FheUint32: 25000 + AdjustFHEGas, FheUint64: 28000 + AdjustFHEGas, }, FheLe: map[FheUintType]uint64{ + FheUint4: 46000 + AdjustFHEGas, FheUint8: 46000 + AdjustFHEGas, FheUint16: 46000 + AdjustFHEGas, FheUint32: 72000 + AdjustFHEGas, FheUint64: 76000 + AdjustFHEGas, }, FheMinMax: map[FheUintType]uint64{ + FheUint4: 94000 + AdjustFHEGas, FheUint8: 94000 + AdjustFHEGas, FheUint16: 120000 + AdjustFHEGas, FheUint32: 148000 + AdjustFHEGas, FheUint64: 189000 + AdjustFHEGas, }, FheScalarMinMax: map[FheUintType]uint64{ + FheUint4: 114000 + AdjustFHEGas, FheUint8: 114000 + AdjustFHEGas, FheUint16: 140000 + AdjustFHEGas, FheUint32: 154000 + AdjustFHEGas, FheUint64: 182000 + AdjustFHEGas, }, FheNot: map[FheUintType]uint64{ + FheUint4: 25000 + AdjustFHEGas, FheUint8: 25000 + AdjustFHEGas, FheUint16: 25000 + AdjustFHEGas, FheUint32: 26000 + AdjustFHEGas, FheUint64: 27000 + AdjustFHEGas, }, FheNeg: map[FheUintType]uint64{ + FheUint4: 79000 + AdjustFHEGas, FheUint8: 79000 + AdjustFHEGas, FheUint16: 114000 + AdjustFHEGas, FheUint32: 150000 + AdjustFHEGas, @@ -152,18 +167,24 @@ func DefaultGasCosts() GasCosts { }, // TODO: Costs will depend on the complexity of doing reencryption/decryption by the oracle. FheReencrypt: map[FheUintType]uint64{ + FheBool: 1000, + FheUint4: 1000, FheUint8: 1000, FheUint16: 1100, FheUint32: 1200, }, // As of now, verification costs only cover ciphertext deserialization and assume there is no ZKPoK to verify. FheVerify: map[FheUintType]uint64{ + FheBool: 200, + FheUint4: 200, FheUint8: 200, FheUint16: 300, FheUint32: 400, FheUint64: 800, }, FheTrivialEncrypt: map[FheUintType]uint64{ + FheBool: 100, + FheUint4: 100, FheUint8: 100, FheUint16: 200, FheUint32: 300, @@ -171,12 +192,14 @@ func DefaultGasCosts() GasCosts { }, // TODO: These will change once we have an FHE-based random generaration. FheRand: map[FheUintType]uint64{ + FheUint4: EvmNetSstoreInitGas + 100000, FheUint8: EvmNetSstoreInitGas + 100000, FheUint16: EvmNetSstoreInitGas + 100000, FheUint32: EvmNetSstoreInitGas + 100000, FheUint64: EvmNetSstoreInitGas + 100000, }, FheIfThenElse: map[FheUintType]uint64{ + FheUint4: 37000 + AdjustFHEGas, FheUint8: 37000 + AdjustFHEGas, FheUint16: 37000 + AdjustFHEGas, FheUint32: 40000 + AdjustFHEGas, @@ -188,11 +211,13 @@ func DefaultGasCosts() GasCosts { // For every subsequent optimistic require, we need to bitand it with the current require value - that // works, because we assume requires have a value of 0 or 1. FheOptRequire: map[FheUintType]uint64{ + FheUint4: 170000, FheUint8: 170000, FheUint16: 180000, FheUint32: 190000, }, FheOptRequireBitAnd: map[FheUintType]uint64{ + FheUint4: 20000, FheUint8: 20000, FheUint16: 20000, FheUint32: 20000, diff --git a/fhevm/tfhe_ciphertext.go b/fhevm/tfhe_ciphertext.go index 0839079..e841f5b 100644 --- a/fhevm/tfhe_ciphertext.go +++ b/fhevm/tfhe_ciphertext.go @@ -17,10 +17,12 @@ import ( type FheUintType uint8 const ( - FheUint8 FheUintType = 0 - FheUint16 FheUintType = 1 - FheUint32 FheUintType = 2 - FheUint64 FheUintType = 3 + FheBool FheUintType = 0 + FheUint4 FheUintType = 1 + FheUint8 FheUintType = 2 + FheUint16 FheUintType = 3 + FheUint32 FheUintType = 4 + FheUint64 FheUintType = 5 ) func (t FheUintType) String() string { @@ -39,7 +41,7 @@ func (t FheUintType) String() string { } func isValidFheType(t byte) bool { - if uint8(t) < uint8(FheUint8) || uint8(t) > uint8(FheUint64) { + if uint8(t) < uint8(FheBool) || uint8(t) > uint8(FheUint64) { return false } return true @@ -55,10 +57,33 @@ type TfheCiphertext struct { func (ct *TfheCiphertext) Type() FheUintType { return ct.fheUintType } +func boolBinaryNotSupportedOp(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return nil, errors.New("Bool is not supported") +} + +func boolBinaryScalarNotSupportedOp(lhs unsafe.Pointer, rhs C.bool) (unsafe.Pointer, error) { + return nil, errors.New("Bool is not supported") +} + +func boolUnaryNotSupportedOp(lhs unsafe.Pointer) (unsafe.Pointer, error) { + return nil, errors.New("Bool is not supported") +} // Deserializes a TFHE ciphertext. func (ct *TfheCiphertext) Deserialize(in []byte, t FheUintType) error { switch t { + case FheBool: + ptr := C.deserialize_fhe_bool(toDynamicBufferView((in))) + if ptr == nil { + return errors.New("FheBool ciphertext deserialization failed") + } + C.destroy_fhe_bool(ptr) + case FheUint4: + ptr := C.deserialize_fhe_uint4(toDynamicBufferView((in))) + if ptr == nil { + return errors.New("FheUint4 ciphertext deserialization failed") + } + C.destroy_fhe_uint4(ptr) case FheUint8: ptr := C.deserialize_fhe_uint8(toDynamicBufferView((in))) if ptr == nil { @@ -97,6 +122,28 @@ func (ct *TfheCiphertext) Deserialize(in []byte, t FheUintType) error { // will produce non-compact ciphertext serialziations. func (ct *TfheCiphertext) DeserializeCompact(in []byte, t FheUintType) error { switch t { + case FheBool: + ptr := C.deserialize_compact_fhe_bool(toDynamicBufferView((in))) + if ptr == nil { + return errors.New("compact FheBool ciphertext deserialization failed") + } + var err error + ct.serialization, err = serialize(ptr, t) + C.destroy_fhe_bool(ptr) + if err != nil { + return err + } + case FheUint4: + ptr := C.deserialize_compact_fhe_uint4(toDynamicBufferView((in))) + if ptr == nil { + return errors.New("compact FheUint4 ciphertext deserialization failed") + } + var err error + ct.serialization, err = serialize(ptr, t) + C.destroy_fhe_uint4(ptr) + if err != nil { + return err + } case FheUint8: ptr := C.deserialize_compact_fhe_uint8(toDynamicBufferView((in))) if ptr == nil { @@ -155,6 +202,24 @@ func (ct *TfheCiphertext) Encrypt(value big.Int, t FheUintType) *TfheCiphertext var ptr unsafe.Pointer var err error switch t { + case FheBool: + val := false + if value.Uint64() > 0 { + val = true + } + ptr = C.public_key_encrypt_fhe_bool(pks, C.bool(val)) + ct.serialization, err = serialize(ptr, t) + C.destroy_fhe_bool(ptr) + if err != nil { + panic(err) + } + case FheUint4: + ptr = C.public_key_encrypt_fhe_uint4(pks, C.uint8_t(value.Uint64())) + ct.serialization, err = serialize(ptr, t) + C.destroy_fhe_uint4(ptr) + if err != nil { + panic(err) + } case FheUint8: ptr = C.public_key_encrypt_fhe_uint8(pks, C.uint8_t(value.Uint64())) ct.serialization, err = serialize(ptr, t) @@ -195,6 +260,24 @@ func (ct *TfheCiphertext) TrivialEncrypt(value big.Int, t FheUintType) *TfheCiph var ptr unsafe.Pointer var err error switch t { + case FheBool: + val := false + if value.Uint64() > 0 { + val = true + } + ptr = C.trivial_encrypt_fhe_bool(sks, C.bool(val)) + ct.serialization, err = serialize(ptr, t) + C.destroy_fhe_bool(ptr) + if err != nil { + panic(err) + } + case FheUint4: + ptr = C.trivial_encrypt_fhe_uint4(sks, C.uint8_t(value.Uint64())) + ct.serialization, err = serialize(ptr, t) + C.destroy_fhe_uint4(ptr) + if err != nil { + panic(err) + } case FheUint8: ptr = C.trivial_encrypt_fhe_uint8(sks, C.uint8_t(value.Uint64())) ct.serialization, err = serialize(ptr, t) @@ -236,21 +319,66 @@ func (ct *TfheCiphertext) Serialize() []byte { } func (ct *TfheCiphertext) executeUnaryCiphertextOperation(rhs *TfheCiphertext, - op8 func(ct unsafe.Pointer) unsafe.Pointer, - op16 func(ct unsafe.Pointer) unsafe.Pointer, - op32 func(ct unsafe.Pointer) unsafe.Pointer, - op64 func(ct unsafe.Pointer) unsafe.Pointer) (*TfheCiphertext, error) { + opBool func(ct unsafe.Pointer) (unsafe.Pointer, error), + op4 func(ct unsafe.Pointer) (unsafe.Pointer, error), + op8 func(ct unsafe.Pointer) (unsafe.Pointer, error), + op16 func(ct unsafe.Pointer) (unsafe.Pointer, error), + op32 func(ct unsafe.Pointer) (unsafe.Pointer, error), + op64 func(ct unsafe.Pointer) (unsafe.Pointer, error)) (*TfheCiphertext, error) { res := new(TfheCiphertext) res.fheUintType = ct.fheUintType res_ser := &C.DynamicBuffer{} switch ct.fheUintType { + case FheBool: + ct_ptr := C.deserialize_fhe_bool(toDynamicBufferView((ct.serialization))) + if ct_ptr == nil { + return nil, errors.New("Bool unary op deserialization failed") + } + res_ptr, err := opBool(ct_ptr) + if (err != nil) { + return nil, err; + } + C.destroy_fhe_bool(ct_ptr) + if res_ptr == nil { + return nil, errors.New("Bool unary op failed") + } + ret := C.serialize_fhe_bool(res_ptr, res_ser) + C.destroy_fhe_bool(res_ptr) + if ret != 0 { + return nil, errors.New("Bool unary op serialization failed") + } + res.serialization = C.GoBytes(unsafe.Pointer(res_ser.pointer), C.int(res_ser.length)) + C.destroy_dynamic_buffer(res_ser) + case FheUint4: + ct_ptr := C.deserialize_fhe_uint4(toDynamicBufferView((ct.serialization))) + if ct_ptr == nil { + return nil, errors.New("8 bit unary op deserialization failed") + } + res_ptr, err := op4(ct_ptr) + if (err != nil) { + return nil, err; + } + C.destroy_fhe_uint4(ct_ptr) + if res_ptr == nil { + return nil, errors.New("8 bit unary op failed") + } + ret := C.serialize_fhe_uint4(res_ptr, res_ser) + C.destroy_fhe_uint4(res_ptr) + if ret != 0 { + return nil, errors.New("8 bit unary op serialization failed") + } + res.serialization = C.GoBytes(unsafe.Pointer(res_ser.pointer), C.int(res_ser.length)) + C.destroy_dynamic_buffer(res_ser) case FheUint8: ct_ptr := C.deserialize_fhe_uint8(toDynamicBufferView((ct.serialization))) if ct_ptr == nil { return nil, errors.New("8 bit unary op deserialization failed") } - res_ptr := op8(ct_ptr) + res_ptr, err := op8(ct_ptr) + if (err != nil) { + return nil, err; + } C.destroy_fhe_uint8(ct_ptr) if res_ptr == nil { return nil, errors.New("8 bit unary op failed") @@ -267,7 +395,10 @@ func (ct *TfheCiphertext) executeUnaryCiphertextOperation(rhs *TfheCiphertext, if ct_ptr == nil { return nil, errors.New("16 bit unary op deserialization failed") } - res_ptr := op16(ct_ptr) + res_ptr, err := op16(ct_ptr) + if (err != nil) { + return nil, err; + } C.destroy_fhe_uint16(ct_ptr) if res_ptr == nil { return nil, errors.New("16 bit op failed") @@ -284,7 +415,10 @@ func (ct *TfheCiphertext) executeUnaryCiphertextOperation(rhs *TfheCiphertext, if ct_ptr == nil { return nil, errors.New("32 bit unary op deserialization failed") } - res_ptr := op32(ct_ptr) + res_ptr, err := op16(ct_ptr) + if (err != nil) { + return nil, err; + } C.destroy_fhe_uint32(ct_ptr) if res_ptr == nil { return nil, errors.New("32 bit op failed") @@ -301,7 +435,10 @@ func (ct *TfheCiphertext) executeUnaryCiphertextOperation(rhs *TfheCiphertext, if ct_ptr == nil { return nil, errors.New("64 bit unary op deserialization failed") } - res_ptr := op64(ct_ptr) + res_ptr, err := op64(ct_ptr) + if (err != nil) { + return nil, err; + } C.destroy_fhe_uint64(ct_ptr) if res_ptr == nil { return nil, errors.New("64 bit op failed") @@ -321,10 +458,12 @@ func (ct *TfheCiphertext) executeUnaryCiphertextOperation(rhs *TfheCiphertext, } func (lhs *TfheCiphertext) executeBinaryCiphertextOperation(rhs *TfheCiphertext, - op8 func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer, - op16 func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer, - op32 func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer, - op64 func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer) (*TfheCiphertext, error) { + opBool func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error), + op4 func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error), + op8 func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error), + op16 func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error), + op32 func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error), + op64 func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error)) (*TfheCiphertext, error) { if lhs.fheUintType != rhs.fheUintType { return nil, errors.New("binary operations are only well-defined for identical types") } @@ -333,6 +472,58 @@ func (lhs *TfheCiphertext) executeBinaryCiphertextOperation(rhs *TfheCiphertext, res.fheUintType = lhs.fheUintType res_ser := &C.DynamicBuffer{} switch lhs.fheUintType { + case FheBool: + lhs_ptr := C.deserialize_fhe_bool(toDynamicBufferView((lhs.serialization))) + if lhs_ptr == nil { + return nil, errors.New("bool binary op deserialization failed") + } + rhs_ptr := C.deserialize_fhe_bool(toDynamicBufferView((rhs.serialization))) + if rhs_ptr == nil { + C.destroy_fhe_bool(lhs_ptr) + return nil, errors.New("bool binary op deserialization failed") + } + res_ptr, err := opBool(lhs_ptr, rhs_ptr) + if (err != nil) { + return nil, err; + } + C.destroy_fhe_bool(lhs_ptr) + C.destroy_fhe_bool(rhs_ptr) + if res_ptr == nil { + return nil, errors.New("bool binary op failed") + } + ret := C.serialize_fhe_bool(res_ptr, res_ser) + C.destroy_fhe_bool(res_ptr) + if ret != 0 { + return nil, errors.New("bool binary op serialization failed") + } + res.serialization = C.GoBytes(unsafe.Pointer(res_ser.pointer), C.int(res_ser.length)) + C.destroy_dynamic_buffer(res_ser) + case FheUint4: + lhs_ptr := C.deserialize_fhe_uint4(toDynamicBufferView((lhs.serialization))) + if lhs_ptr == nil { + return nil, errors.New("4 bit binary op deserialization failed") + } + rhs_ptr := C.deserialize_fhe_uint4(toDynamicBufferView((rhs.serialization))) + if rhs_ptr == nil { + C.destroy_fhe_uint4(lhs_ptr) + return nil, errors.New("4 bit binary op deserialization failed") + } + res_ptr, err := op4(lhs_ptr, rhs_ptr) + if (err != nil) { + return nil, err; + } + C.destroy_fhe_uint4(lhs_ptr) + C.destroy_fhe_uint4(rhs_ptr) + if res_ptr == nil { + return nil, errors.New("4 bit binary op failed") + } + ret := C.serialize_fhe_uint4(res_ptr, res_ser) + C.destroy_fhe_uint4(res_ptr) + if ret != 0 { + return nil, errors.New("4 bit binary op serialization failed") + } + res.serialization = C.GoBytes(unsafe.Pointer(res_ser.pointer), C.int(res_ser.length)) + C.destroy_dynamic_buffer(res_ser) case FheUint8: lhs_ptr := C.deserialize_fhe_uint8(toDynamicBufferView((lhs.serialization))) if lhs_ptr == nil { @@ -343,7 +534,10 @@ func (lhs *TfheCiphertext) executeBinaryCiphertextOperation(rhs *TfheCiphertext, C.destroy_fhe_uint8(lhs_ptr) return nil, errors.New("8 bit binary op deserialization failed") } - res_ptr := op8(lhs_ptr, rhs_ptr) + res_ptr, err := op8(lhs_ptr, rhs_ptr) + if (err != nil) { + return nil, err; + } C.destroy_fhe_uint8(lhs_ptr) C.destroy_fhe_uint8(rhs_ptr) if res_ptr == nil { @@ -366,7 +560,10 @@ func (lhs *TfheCiphertext) executeBinaryCiphertextOperation(rhs *TfheCiphertext, C.destroy_fhe_uint16(lhs_ptr) return nil, errors.New("16 bit binary op deserialization failed") } - res_ptr := op16(lhs_ptr, rhs_ptr) + res_ptr, err := op16(lhs_ptr, rhs_ptr) + if (err != nil) { + return nil, err; + } C.destroy_fhe_uint16(lhs_ptr) C.destroy_fhe_uint16(rhs_ptr) if res_ptr == nil { @@ -389,7 +586,10 @@ func (lhs *TfheCiphertext) executeBinaryCiphertextOperation(rhs *TfheCiphertext, C.destroy_fhe_uint32(lhs_ptr) return nil, errors.New("32 bit binary op deserialization failed") } - res_ptr := op32(lhs_ptr, rhs_ptr) + res_ptr, err := op32(lhs_ptr, rhs_ptr) + if (err != nil) { + return nil, err; + } C.destroy_fhe_uint32(lhs_ptr) C.destroy_fhe_uint32(rhs_ptr) if res_ptr == nil { @@ -412,7 +612,10 @@ func (lhs *TfheCiphertext) executeBinaryCiphertextOperation(rhs *TfheCiphertext, C.destroy_fhe_uint64(lhs_ptr) return nil, errors.New("64 bit binary op deserialization failed") } - res_ptr := op64(lhs_ptr, rhs_ptr) + res_ptr, err := op64(lhs_ptr, rhs_ptr) + if (err != nil) { + return nil, err; + } C.destroy_fhe_uint64(lhs_ptr) C.destroy_fhe_uint64(rhs_ptr) if res_ptr == nil { @@ -433,6 +636,7 @@ func (lhs *TfheCiphertext) executeBinaryCiphertextOperation(rhs *TfheCiphertext, } func (first *TfheCiphertext) executeTernaryCiphertextOperation(lhs *TfheCiphertext, rhs *TfheCiphertext, + op4 func(first unsafe.Pointer, lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer, op8 func(first unsafe.Pointer, lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer, op16 func(first unsafe.Pointer, lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer, op32 func(first unsafe.Pointer, lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer, @@ -445,6 +649,35 @@ func (first *TfheCiphertext) executeTernaryCiphertextOperation(lhs *TfheCipherte res.fheUintType = lhs.fheUintType res_ser := &C.DynamicBuffer{} switch lhs.fheUintType { + case FheUint4: + lhs_ptr := C.deserialize_fhe_uint4(toDynamicBufferView((lhs.serialization))) + if lhs_ptr == nil { + return nil, errors.New("4 bit binary op deserialization failed") + } + rhs_ptr := C.deserialize_fhe_uint4(toDynamicBufferView((rhs.serialization))) + if rhs_ptr == nil { + C.destroy_fhe_uint4(lhs_ptr) + return nil, errors.New("4 bit binary op deserialization failed") + } + first_ptr := C.deserialize_fhe_uint4(toDynamicBufferView((first.serialization))) + if first_ptr == nil { + C.destroy_fhe_uint4(lhs_ptr) + C.destroy_fhe_uint4(rhs_ptr) + return nil, errors.New("8 bit binary op deserialization failed") + } + res_ptr := op4(first_ptr, lhs_ptr, rhs_ptr) + C.destroy_fhe_uint4(lhs_ptr) + C.destroy_fhe_uint4(rhs_ptr) + if res_ptr == nil { + return nil, errors.New("4 bit binary op failed") + } + ret := C.serialize_fhe_uint4(res_ptr, res_ser) + C.destroy_fhe_uint4(res_ptr) + if ret != 0 { + return nil, errors.New("4 bit binary op serialization failed") + } + res.serialization = C.GoBytes(unsafe.Pointer(res_ser.pointer), C.int(res_ser.length)) + C.destroy_dynamic_buffer(res_ser) case FheUint8: lhs_ptr := C.deserialize_fhe_uint8(toDynamicBufferView((lhs.serialization))) if lhs_ptr == nil { @@ -562,28 +795,75 @@ func (first *TfheCiphertext) executeTernaryCiphertextOperation(lhs *TfheCipherte res.serialization = C.GoBytes(unsafe.Pointer(res_ser.pointer), C.int(res_ser.length)) C.destroy_dynamic_buffer(res_ser) default: - panic("binary op unexpected ciphertext type") + panic("ternary op unexpected ciphertext type") } res.computeHash() return res, nil } func (lhs *TfheCiphertext) executeBinaryScalarOperation(rhs uint64, - op8 func(lhs unsafe.Pointer, rhs C.uint8_t) unsafe.Pointer, - op16 func(lhs unsafe.Pointer, rhs C.uint16_t) unsafe.Pointer, - op32 func(lhs unsafe.Pointer, rhs C.uint32_t) unsafe.Pointer, - op64 func(lhs unsafe.Pointer, rhs C.uint64_t) unsafe.Pointer) (*TfheCiphertext, error) { + opBool func(lhs unsafe.Pointer, rhs C.bool) (unsafe.Pointer, error), + op4 func(lhs unsafe.Pointer, rhs C.uint8_t) (unsafe.Pointer, error), + op8 func(lhs unsafe.Pointer, rhs C.uint8_t) (unsafe.Pointer, error), + op16 func(lhs unsafe.Pointer, rhs C.uint16_t) (unsafe.Pointer, error), + op32 func(lhs unsafe.Pointer, rhs C.uint32_t) (unsafe.Pointer, error), + op64 func(lhs unsafe.Pointer, rhs C.uint64_t) (unsafe.Pointer, error)) (*TfheCiphertext, error) { res := new(TfheCiphertext) res.fheUintType = lhs.fheUintType res_ser := &C.DynamicBuffer{} switch lhs.fheUintType { + case FheBool: + lhs_ptr := C.deserialize_fhe_bool(toDynamicBufferView((lhs.serialization))) + if lhs_ptr == nil { + return nil, errors.New("Bool scalar op deserialization failed") + } + scalar := C.bool(rhs == 1) + res_ptr, err := opBool(lhs_ptr, scalar) + if (err != nil) { + return nil, err; + } + C.destroy_fhe_bool(lhs_ptr) + if res_ptr == nil { + return nil, errors.New("Bool scalar op failed") + } + ret := C.serialize_fhe_bool(res_ptr, res_ser) + C.destroy_fhe_bool(res_ptr) + if ret != 0 { + return nil, errors.New("8 bit scalar op serialization failed") + } + res.serialization = C.GoBytes(unsafe.Pointer(res_ser.pointer), C.int(res_ser.length)) + C.destroy_dynamic_buffer(res_ser) + case FheUint4: + lhs_ptr := C.deserialize_fhe_uint4(toDynamicBufferView((lhs.serialization))) + if lhs_ptr == nil { + return nil, errors.New("4 bit scalar op deserialization failed") + } + scalar := C.uint8_t(rhs) + res_ptr, err := op4(lhs_ptr, scalar) + if (err != nil) { + return nil, err; + } + C.destroy_fhe_uint4(lhs_ptr) + if res_ptr == nil { + return nil, errors.New("4 bit scalar op failed") + } + ret := C.serialize_fhe_uint8(res_ptr, res_ser) + C.destroy_fhe_uint4(res_ptr) + if ret != 0 { + return nil, errors.New("4 bit scalar op serialization failed") + } + res.serialization = C.GoBytes(unsafe.Pointer(res_ser.pointer), C.int(res_ser.length)) + C.destroy_dynamic_buffer(res_ser) case FheUint8: lhs_ptr := C.deserialize_fhe_uint8(toDynamicBufferView((lhs.serialization))) if lhs_ptr == nil { return nil, errors.New("8 bit scalar op deserialization failed") } scalar := C.uint8_t(rhs) - res_ptr := op8(lhs_ptr, scalar) + res_ptr, err := op8(lhs_ptr, scalar) + if (err != nil) { + return nil, err; + } C.destroy_fhe_uint8(lhs_ptr) if res_ptr == nil { return nil, errors.New("8 bit scalar op failed") @@ -601,7 +881,10 @@ func (lhs *TfheCiphertext) executeBinaryScalarOperation(rhs uint64, return nil, errors.New("16 bit scalar op deserialization failed") } scalar := C.uint16_t(rhs) - res_ptr := op16(lhs_ptr, scalar) + res_ptr, err := op16(lhs_ptr, scalar) + if (err != nil) { + return nil, err; + } C.destroy_fhe_uint16(lhs_ptr) if res_ptr == nil { return nil, errors.New("16 bit scalar op failed") @@ -619,7 +902,10 @@ func (lhs *TfheCiphertext) executeBinaryScalarOperation(rhs uint64, return nil, errors.New("32 bit scalar op deserialization failed") } scalar := C.uint32_t(rhs) - res_ptr := op32(lhs_ptr, scalar) + res_ptr, err := op32(lhs_ptr, scalar) + if (err != nil) { + return nil, err; + } C.destroy_fhe_uint32(lhs_ptr) if res_ptr == nil { return nil, errors.New("32 bit scalar op failed") @@ -637,7 +923,10 @@ func (lhs *TfheCiphertext) executeBinaryScalarOperation(rhs uint64, return nil, errors.New("64 bit scalar op deserialization failed") } scalar := C.uint64_t(rhs) - res_ptr := op64(lhs_ptr, scalar) + res_ptr, err := op64(lhs_ptr, scalar) + if (err != nil) { + return nil, err; + } C.destroy_fhe_uint64(lhs_ptr) if res_ptr == nil { return nil, errors.New("64 bit scalar op failed") @@ -658,535 +947,676 @@ func (lhs *TfheCiphertext) executeBinaryScalarOperation(rhs uint64, func (lhs *TfheCiphertext) Add(rhs *TfheCiphertext) (*TfheCiphertext, error) { return lhs.executeBinaryCiphertextOperation(rhs, - func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { - return C.add_fhe_uint8(lhs, rhs, sks) + boolBinaryNotSupportedOp, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.add_fhe_uint4(lhs, rhs, sks), nil + }, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.add_fhe_uint8(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { - return C.add_fhe_uint16(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.add_fhe_uint16(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { - return C.add_fhe_uint32(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.add_fhe_uint32(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { - return C.add_fhe_uint64(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.add_fhe_uint64(lhs, rhs, sks), nil }) } func (lhs *TfheCiphertext) ScalarAdd(rhs uint64) (*TfheCiphertext, error) { return lhs.executeBinaryScalarOperation(rhs, - func(lhs unsafe.Pointer, rhs C.uint8_t) unsafe.Pointer { - return C.scalar_add_fhe_uint8(lhs, rhs, sks) + boolBinaryScalarNotSupportedOp, + func(lhs unsafe.Pointer, rhs C.uint8_t) (unsafe.Pointer, error) { + return C.scalar_add_fhe_uint4(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs C.uint16_t) unsafe.Pointer { - return C.scalar_add_fhe_uint16(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs C.uint8_t) (unsafe.Pointer, error) { + return C.scalar_add_fhe_uint8(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs C.uint32_t) unsafe.Pointer { - return C.scalar_add_fhe_uint32(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs C.uint16_t) (unsafe.Pointer, error) { + return C.scalar_add_fhe_uint16(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs C.uint64_t) unsafe.Pointer { - return C.scalar_add_fhe_uint64(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs C.uint32_t) (unsafe.Pointer, error) { + return C.scalar_add_fhe_uint32(lhs, rhs, sks), nil + }, + func(lhs unsafe.Pointer, rhs C.uint64_t) (unsafe.Pointer, error) { + return C.scalar_add_fhe_uint64(lhs, rhs, sks), nil }) } func (lhs *TfheCiphertext) Sub(rhs *TfheCiphertext) (*TfheCiphertext, error) { return lhs.executeBinaryCiphertextOperation(rhs, - func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { - return C.sub_fhe_uint8(lhs, rhs, sks) + boolBinaryNotSupportedOp, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.sub_fhe_uint4(lhs, rhs, sks), nil + }, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.sub_fhe_uint8(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { - return C.sub_fhe_uint16(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.sub_fhe_uint16(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { - return C.sub_fhe_uint32(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.sub_fhe_uint32(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { - return C.sub_fhe_uint64(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.sub_fhe_uint64(lhs, rhs, sks), nil }) } func (lhs *TfheCiphertext) ScalarSub(rhs uint64) (*TfheCiphertext, error) { return lhs.executeBinaryScalarOperation(rhs, - func(lhs unsafe.Pointer, rhs C.uint8_t) unsafe.Pointer { - return C.scalar_sub_fhe_uint8(lhs, rhs, sks) + boolBinaryScalarNotSupportedOp, + func(lhs unsafe.Pointer, rhs C.uint8_t) (unsafe.Pointer, error) { + return C.scalar_sub_fhe_uint4(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs C.uint16_t) unsafe.Pointer { - return C.scalar_sub_fhe_uint16(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs C.uint8_t) (unsafe.Pointer, error) { + return C.scalar_sub_fhe_uint8(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs C.uint32_t) unsafe.Pointer { - return C.scalar_sub_fhe_uint32(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs C.uint16_t) (unsafe.Pointer, error) { + return C.scalar_sub_fhe_uint16(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs C.uint64_t) unsafe.Pointer { - return C.scalar_sub_fhe_uint64(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs C.uint32_t) (unsafe.Pointer, error) { + return C.scalar_sub_fhe_uint32(lhs, rhs, sks), nil + }, + func(lhs unsafe.Pointer, rhs C.uint64_t) (unsafe.Pointer, error) { + return C.scalar_sub_fhe_uint64(lhs, rhs, sks), nil }) } func (lhs *TfheCiphertext) Mul(rhs *TfheCiphertext) (*TfheCiphertext, error) { return lhs.executeBinaryCiphertextOperation(rhs, - func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { - return C.mul_fhe_uint8(lhs, rhs, sks) + boolBinaryNotSupportedOp, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.mul_fhe_uint4(lhs, rhs, sks), nil + }, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.mul_fhe_uint8(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { - return C.mul_fhe_uint16(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.mul_fhe_uint16(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { - return C.mul_fhe_uint32(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.mul_fhe_uint32(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { - return C.mul_fhe_uint64(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.mul_fhe_uint64(lhs, rhs, sks), nil }) } func (lhs *TfheCiphertext) ScalarMul(rhs uint64) (*TfheCiphertext, error) { return lhs.executeBinaryScalarOperation(rhs, - func(lhs unsafe.Pointer, rhs C.uint8_t) unsafe.Pointer { - return C.scalar_mul_fhe_uint8(lhs, rhs, sks) + boolBinaryScalarNotSupportedOp, + func(lhs unsafe.Pointer, rhs C.uint8_t) (unsafe.Pointer, error) { + return C.scalar_mul_fhe_uint4(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs C.uint16_t) unsafe.Pointer { - return C.scalar_mul_fhe_uint16(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs C.uint8_t) (unsafe.Pointer, error) { + return C.scalar_mul_fhe_uint8(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs C.uint32_t) unsafe.Pointer { - return C.scalar_mul_fhe_uint32(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs C.uint16_t) (unsafe.Pointer, error) { + return C.scalar_mul_fhe_uint16(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs C.uint64_t) unsafe.Pointer { - return C.scalar_mul_fhe_uint64(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs C.uint32_t) (unsafe.Pointer, error) { + return C.scalar_mul_fhe_uint32(lhs, rhs, sks), nil + }, + func(lhs unsafe.Pointer, rhs C.uint64_t) (unsafe.Pointer, error) { + return C.scalar_mul_fhe_uint64(lhs, rhs, sks), nil }) } func (lhs *TfheCiphertext) ScalarDiv(rhs uint64) (*TfheCiphertext, error) { return lhs.executeBinaryScalarOperation(rhs, - func(lhs unsafe.Pointer, rhs C.uint8_t) unsafe.Pointer { - return C.scalar_div_fhe_uint8(lhs, rhs, sks) + boolBinaryScalarNotSupportedOp, + func(lhs unsafe.Pointer, rhs C.uint8_t) (unsafe.Pointer, error) { + return C.scalar_div_fhe_uint4(lhs, rhs, sks), nil + }, + func(lhs unsafe.Pointer, rhs C.uint8_t) (unsafe.Pointer, error) { + return C.scalar_div_fhe_uint8(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs C.uint16_t) unsafe.Pointer { - return C.scalar_div_fhe_uint16(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs C.uint16_t) (unsafe.Pointer, error) { + return C.scalar_div_fhe_uint16(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs C.uint32_t) unsafe.Pointer { - return C.scalar_div_fhe_uint32(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs C.uint32_t) (unsafe.Pointer, error) { + return C.scalar_div_fhe_uint32(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs C.uint64_t) unsafe.Pointer { - return C.scalar_div_fhe_uint64(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs C.uint64_t) (unsafe.Pointer, error) { + return C.scalar_div_fhe_uint64(lhs, rhs, sks), nil }) } func (lhs *TfheCiphertext) ScalarRem(rhs uint64) (*TfheCiphertext, error) { return lhs.executeBinaryScalarOperation(rhs, - func(lhs unsafe.Pointer, rhs C.uint8_t) unsafe.Pointer { - return C.scalar_rem_fhe_uint8(lhs, rhs, sks) + boolBinaryScalarNotSupportedOp, + func(lhs unsafe.Pointer, rhs C.uint8_t) (unsafe.Pointer, error) { + return C.scalar_rem_fhe_uint4(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs C.uint16_t) unsafe.Pointer { - return C.scalar_rem_fhe_uint16(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs C.uint8_t) (unsafe.Pointer, error) { + return C.scalar_rem_fhe_uint8(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs C.uint32_t) unsafe.Pointer { - return C.scalar_rem_fhe_uint32(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs C.uint16_t) (unsafe.Pointer, error) { + return C.scalar_rem_fhe_uint16(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs C.uint64_t) unsafe.Pointer { - return C.scalar_rem_fhe_uint64(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs C.uint32_t) (unsafe.Pointer, error) { + return C.scalar_rem_fhe_uint32(lhs, rhs, sks), nil + }, + func(lhs unsafe.Pointer, rhs C.uint64_t) (unsafe.Pointer, error) { + return C.scalar_rem_fhe_uint64(lhs, rhs, sks), nil }) } func (lhs *TfheCiphertext) Bitand(rhs *TfheCiphertext) (*TfheCiphertext, error) { return lhs.executeBinaryCiphertextOperation(rhs, - func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { - return C.bitand_fhe_uint8(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.bitand_fhe_bool(lhs, rhs, sks), nil + }, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.bitand_fhe_uint4(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { - return C.bitand_fhe_uint16(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.bitand_fhe_uint8(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { - return C.bitand_fhe_uint32(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.bitand_fhe_uint16(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { - return C.bitand_fhe_uint64(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.bitand_fhe_uint32(lhs, rhs, sks), nil + }, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.bitand_fhe_uint64(lhs, rhs, sks), nil }) } func (lhs *TfheCiphertext) Bitor(rhs *TfheCiphertext) (*TfheCiphertext, error) { return lhs.executeBinaryCiphertextOperation(rhs, - func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { - return C.bitor_fhe_uint8(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.bitor_fhe_bool(lhs, rhs, sks), nil + }, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.bitor_fhe_uint4(lhs, rhs, sks), nil + }, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.bitor_fhe_uint8(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { - return C.bitor_fhe_uint16(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.bitor_fhe_uint16(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { - return C.bitor_fhe_uint32(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.bitor_fhe_uint32(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { - return C.bitor_fhe_uint64(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.bitor_fhe_uint64(lhs, rhs, sks), nil }) } func (lhs *TfheCiphertext) Bitxor(rhs *TfheCiphertext) (*TfheCiphertext, error) { return lhs.executeBinaryCiphertextOperation(rhs, - func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { - return C.bitxor_fhe_uint8(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.bitxor_fhe_bool(lhs, rhs, sks), nil + }, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.bitxor_fhe_uint4(lhs, rhs, sks), nil + }, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.bitxor_fhe_uint8(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { - return C.bitxor_fhe_uint16(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.bitxor_fhe_uint16(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { - return C.bitxor_fhe_uint32(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.bitxor_fhe_uint32(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { - return C.bitxor_fhe_uint64(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.bitxor_fhe_uint64(lhs, rhs, sks), nil }) } func (lhs *TfheCiphertext) Shl(rhs *TfheCiphertext) (*TfheCiphertext, error) { return lhs.executeBinaryCiphertextOperation(rhs, - func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { - return C.shl_fhe_uint8(lhs, rhs, sks) + boolBinaryNotSupportedOp, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.shl_fhe_uint4(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { - return C.shl_fhe_uint16(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.shl_fhe_uint8(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { - return C.shl_fhe_uint32(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.shl_fhe_uint16(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { - return C.shl_fhe_uint64(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.shl_fhe_uint32(lhs, rhs, sks), nil + }, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.shl_fhe_uint64(lhs, rhs, sks), nil }) } func (lhs *TfheCiphertext) ScalarShl(rhs uint64) (*TfheCiphertext, error) { return lhs.executeBinaryScalarOperation(rhs, - func(lhs unsafe.Pointer, rhs C.uint8_t) unsafe.Pointer { - return C.scalar_shl_fhe_uint8(lhs, rhs, sks) + boolBinaryScalarNotSupportedOp, + func(lhs unsafe.Pointer, rhs C.uint8_t) (unsafe.Pointer, error) { + return C.scalar_shl_fhe_uint4(lhs, rhs, sks), nil + }, + func(lhs unsafe.Pointer, rhs C.uint8_t) (unsafe.Pointer, error) { + return C.scalar_shl_fhe_uint8(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs C.uint16_t) unsafe.Pointer { - return C.scalar_shl_fhe_uint16(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs C.uint16_t) (unsafe.Pointer, error) { + return C.scalar_shl_fhe_uint16(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs C.uint32_t) unsafe.Pointer { - return C.scalar_shl_fhe_uint32(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs C.uint32_t) (unsafe.Pointer, error) { + return C.scalar_shl_fhe_uint32(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs C.uint64_t) unsafe.Pointer { - return C.scalar_shl_fhe_uint64(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs C.uint64_t) (unsafe.Pointer, error) { + return C.scalar_shl_fhe_uint64(lhs, rhs, sks), nil }) } func (lhs *TfheCiphertext) Shr(rhs *TfheCiphertext) (*TfheCiphertext, error) { return lhs.executeBinaryCiphertextOperation(rhs, - func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { - return C.shr_fhe_uint8(lhs, rhs, sks) + boolBinaryNotSupportedOp, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.shr_fhe_uint4(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { - return C.shr_fhe_uint16(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.shr_fhe_uint8(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { - return C.shr_fhe_uint32(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.shr_fhe_uint16(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { - return C.shr_fhe_uint64(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.shr_fhe_uint32(lhs, rhs, sks), nil + }, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.shr_fhe_uint64(lhs, rhs, sks), nil }) } func (lhs *TfheCiphertext) ScalarShr(rhs uint64) (*TfheCiphertext, error) { return lhs.executeBinaryScalarOperation(rhs, - func(lhs unsafe.Pointer, rhs C.uint8_t) unsafe.Pointer { - return C.scalar_shr_fhe_uint8(lhs, rhs, sks) + boolBinaryScalarNotSupportedOp, + func(lhs unsafe.Pointer, rhs C.uint8_t) (unsafe.Pointer, error) { + return C.scalar_shr_fhe_uint4(lhs, rhs, sks), nil + }, + func(lhs unsafe.Pointer, rhs C.uint8_t) (unsafe.Pointer, error) { + return C.scalar_shr_fhe_uint8(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs C.uint16_t) unsafe.Pointer { - return C.scalar_shr_fhe_uint16(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs C.uint16_t) (unsafe.Pointer, error) { + return C.scalar_shr_fhe_uint16(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs C.uint32_t) unsafe.Pointer { - return C.scalar_shr_fhe_uint32(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs C.uint32_t) (unsafe.Pointer, error) { + return C.scalar_shr_fhe_uint32(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs C.uint64_t) unsafe.Pointer { - return C.scalar_shr_fhe_uint64(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs C.uint64_t) (unsafe.Pointer, error) { + return C.scalar_shr_fhe_uint64(lhs, rhs, sks), nil }) } func (lhs *TfheCiphertext) Eq(rhs *TfheCiphertext) (*TfheCiphertext, error) { return lhs.executeBinaryCiphertextOperation(rhs, - func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { - return C.eq_fhe_uint8(lhs, rhs, sks) + boolBinaryNotSupportedOp, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.eq_fhe_uint4(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { - return C.eq_fhe_uint16(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.eq_fhe_uint8(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { - return C.eq_fhe_uint32(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.eq_fhe_uint16(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { - return C.eq_fhe_uint64(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.eq_fhe_uint32(lhs, rhs, sks), nil + }, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.eq_fhe_uint64(lhs, rhs, sks), nil }) } func (lhs *TfheCiphertext) ScalarEq(rhs uint64) (*TfheCiphertext, error) { return lhs.executeBinaryScalarOperation(rhs, - func(lhs unsafe.Pointer, rhs C.uint8_t) unsafe.Pointer { - return C.scalar_eq_fhe_uint8(lhs, rhs, sks) + boolBinaryScalarNotSupportedOp, + func(lhs unsafe.Pointer, rhs C.uint8_t) (unsafe.Pointer, error) { + return C.scalar_eq_fhe_uint4(lhs, rhs, sks), nil + }, + func(lhs unsafe.Pointer, rhs C.uint8_t) (unsafe.Pointer, error) { + return C.scalar_eq_fhe_uint8(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs C.uint16_t) unsafe.Pointer { - return C.scalar_eq_fhe_uint16(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs C.uint16_t) (unsafe.Pointer, error) { + return C.scalar_eq_fhe_uint16(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs C.uint32_t) unsafe.Pointer { - return C.scalar_eq_fhe_uint32(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs C.uint32_t) (unsafe.Pointer, error) { + return C.scalar_eq_fhe_uint32(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs C.uint64_t) unsafe.Pointer { - return C.scalar_eq_fhe_uint64(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs C.uint64_t) (unsafe.Pointer, error) { + return C.scalar_eq_fhe_uint64(lhs, rhs, sks), nil }) } func (lhs *TfheCiphertext) Ne(rhs *TfheCiphertext) (*TfheCiphertext, error) { return lhs.executeBinaryCiphertextOperation(rhs, - func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { - return C.ne_fhe_uint8(lhs, rhs, sks) + boolBinaryNotSupportedOp, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.ne_fhe_uint4(lhs, rhs, sks), nil + }, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.ne_fhe_uint8(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { - return C.ne_fhe_uint16(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.ne_fhe_uint16(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { - return C.ne_fhe_uint32(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.ne_fhe_uint32(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { - return C.ne_fhe_uint64(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.ne_fhe_uint64(lhs, rhs, sks), nil }) } func (lhs *TfheCiphertext) ScalarNe(rhs uint64) (*TfheCiphertext, error) { return lhs.executeBinaryScalarOperation(rhs, - func(lhs unsafe.Pointer, rhs C.uint8_t) unsafe.Pointer { - return C.scalar_ne_fhe_uint8(lhs, rhs, sks) + boolBinaryScalarNotSupportedOp, + func(lhs unsafe.Pointer, rhs C.uint8_t) (unsafe.Pointer, error) { + return C.scalar_ne_fhe_uint4(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs C.uint16_t) unsafe.Pointer { - return C.scalar_ne_fhe_uint16(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs C.uint8_t) (unsafe.Pointer, error) { + return C.scalar_ne_fhe_uint8(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs C.uint32_t) unsafe.Pointer { - return C.scalar_ne_fhe_uint32(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs C.uint16_t) (unsafe.Pointer, error) { + return C.scalar_ne_fhe_uint16(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs C.uint64_t) unsafe.Pointer { - return C.scalar_ne_fhe_uint64(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs C.uint32_t) (unsafe.Pointer, error) { + return C.scalar_ne_fhe_uint32(lhs, rhs, sks), nil + }, + func(lhs unsafe.Pointer, rhs C.uint64_t) (unsafe.Pointer, error) { + return C.scalar_ne_fhe_uint64(lhs, rhs, sks), nil }) } func (lhs *TfheCiphertext) Ge(rhs *TfheCiphertext) (*TfheCiphertext, error) { return lhs.executeBinaryCiphertextOperation(rhs, - func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { - return C.ge_fhe_uint8(lhs, rhs, sks) + boolBinaryNotSupportedOp, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.ge_fhe_uint4(lhs, rhs, sks), nil + }, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.ge_fhe_uint8(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { - return C.ge_fhe_uint16(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.ge_fhe_uint16(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { - return C.ge_fhe_uint32(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.ge_fhe_uint32(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { - return C.ge_fhe_uint64(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.ge_fhe_uint64(lhs, rhs, sks), nil }) } func (lhs *TfheCiphertext) ScalarGe(rhs uint64) (*TfheCiphertext, error) { return lhs.executeBinaryScalarOperation(rhs, - func(lhs unsafe.Pointer, rhs C.uint8_t) unsafe.Pointer { - return C.scalar_ge_fhe_uint8(lhs, rhs, sks) + boolBinaryScalarNotSupportedOp, + func(lhs unsafe.Pointer, rhs C.uint8_t) (unsafe.Pointer, error) { + return C.scalar_ge_fhe_uint4(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs C.uint16_t) unsafe.Pointer { - return C.scalar_ge_fhe_uint16(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs C.uint8_t) (unsafe.Pointer, error) { + return C.scalar_ge_fhe_uint8(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs C.uint32_t) unsafe.Pointer { - return C.scalar_ge_fhe_uint32(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs C.uint16_t) (unsafe.Pointer, error) { + return C.scalar_ge_fhe_uint16(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs C.uint64_t) unsafe.Pointer { - return C.scalar_ge_fhe_uint64(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs C.uint32_t) (unsafe.Pointer, error) { + return C.scalar_ge_fhe_uint32(lhs, rhs, sks), nil + }, + func(lhs unsafe.Pointer, rhs C.uint64_t) (unsafe.Pointer, error) { + return C.scalar_ge_fhe_uint64(lhs, rhs, sks), nil }) } func (lhs *TfheCiphertext) Gt(rhs *TfheCiphertext) (*TfheCiphertext, error) { return lhs.executeBinaryCiphertextOperation(rhs, - func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { - return C.gt_fhe_uint8(lhs, rhs, sks) + boolBinaryNotSupportedOp, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.gt_fhe_uint4(lhs, rhs, sks), nil + }, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.gt_fhe_uint8(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { - return C.gt_fhe_uint16(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.gt_fhe_uint16(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { - return C.gt_fhe_uint32(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.gt_fhe_uint32(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { - return C.gt_fhe_uint64(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.gt_fhe_uint64(lhs, rhs, sks), nil }) } func (lhs *TfheCiphertext) ScalarGt(rhs uint64) (*TfheCiphertext, error) { return lhs.executeBinaryScalarOperation(rhs, - func(lhs unsafe.Pointer, rhs C.uint8_t) unsafe.Pointer { - return C.scalar_gt_fhe_uint8(lhs, rhs, sks) + boolBinaryScalarNotSupportedOp, + func(lhs unsafe.Pointer, rhs C.uint8_t) (unsafe.Pointer, error) { + return C.scalar_gt_fhe_uint4(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs C.uint16_t) unsafe.Pointer { - return C.scalar_gt_fhe_uint16(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs C.uint8_t) (unsafe.Pointer, error) { + return C.scalar_gt_fhe_uint8(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs C.uint32_t) unsafe.Pointer { - return C.scalar_gt_fhe_uint32(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs C.uint16_t) (unsafe.Pointer, error) { + return C.scalar_gt_fhe_uint16(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs C.uint64_t) unsafe.Pointer { - return C.scalar_gt_fhe_uint64(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs C.uint32_t) (unsafe.Pointer, error) { + return C.scalar_gt_fhe_uint32(lhs, rhs, sks), nil + }, + func(lhs unsafe.Pointer, rhs C.uint64_t) (unsafe.Pointer, error) { + return C.scalar_gt_fhe_uint64(lhs, rhs, sks), nil }) } func (lhs *TfheCiphertext) Le(rhs *TfheCiphertext) (*TfheCiphertext, error) { return lhs.executeBinaryCiphertextOperation(rhs, - func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { - return C.le_fhe_uint8(lhs, rhs, sks) + boolBinaryNotSupportedOp, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.le_fhe_uint4(lhs, rhs, sks), nil + }, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.le_fhe_uint8(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { - return C.le_fhe_uint16(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.le_fhe_uint16(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { - return C.le_fhe_uint32(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.le_fhe_uint32(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { - return C.le_fhe_uint64(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.le_fhe_uint64(lhs, rhs, sks), nil }) } func (lhs *TfheCiphertext) ScalarLe(rhs uint64) (*TfheCiphertext, error) { return lhs.executeBinaryScalarOperation(rhs, - func(lhs unsafe.Pointer, rhs C.uint8_t) unsafe.Pointer { - return C.scalar_le_fhe_uint8(lhs, rhs, sks) + boolBinaryScalarNotSupportedOp, + func(lhs unsafe.Pointer, rhs C.uint8_t) (unsafe.Pointer, error) { + return C.scalar_le_fhe_uint4(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs C.uint16_t) unsafe.Pointer { - return C.scalar_le_fhe_uint16(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs C.uint8_t) (unsafe.Pointer, error) { + return C.scalar_le_fhe_uint8(lhs, rhs, sks), nil + }, + func(lhs unsafe.Pointer, rhs C.uint16_t) (unsafe.Pointer, error) { + return C.scalar_le_fhe_uint16(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs C.uint32_t) unsafe.Pointer { - return C.scalar_le_fhe_uint32(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs C.uint32_t) (unsafe.Pointer, error) { + return C.scalar_le_fhe_uint32(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs C.uint64_t) unsafe.Pointer { - return C.scalar_le_fhe_uint64(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs C.uint64_t) (unsafe.Pointer, error) { + return C.scalar_le_fhe_uint64(lhs, rhs, sks), nil }) } func (lhs *TfheCiphertext) Lt(rhs *TfheCiphertext) (*TfheCiphertext, error) { return lhs.executeBinaryCiphertextOperation(rhs, - func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { - return C.lt_fhe_uint8(lhs, rhs, sks) + boolBinaryNotSupportedOp, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.lt_fhe_uint4(lhs, rhs, sks), nil + }, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.lt_fhe_uint8(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { - return C.lt_fhe_uint16(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.lt_fhe_uint16(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { - return C.lt_fhe_uint32(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.lt_fhe_uint32(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { - return C.lt_fhe_uint64(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.lt_fhe_uint64(lhs, rhs, sks), nil }) } func (lhs *TfheCiphertext) ScalarLt(rhs uint64) (*TfheCiphertext, error) { return lhs.executeBinaryScalarOperation(rhs, - func(lhs unsafe.Pointer, rhs C.uint8_t) unsafe.Pointer { - return C.scalar_lt_fhe_uint8(lhs, rhs, sks) + boolBinaryScalarNotSupportedOp, + func(lhs unsafe.Pointer, rhs C.uint8_t) (unsafe.Pointer, error) { + return C.scalar_lt_fhe_uint4(lhs, rhs, sks), nil + }, + func(lhs unsafe.Pointer, rhs C.uint8_t) (unsafe.Pointer, error) { + return C.scalar_lt_fhe_uint8(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs C.uint16_t) unsafe.Pointer { - return C.scalar_lt_fhe_uint16(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs C.uint16_t) (unsafe.Pointer, error) { + return C.scalar_lt_fhe_uint16(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs C.uint32_t) unsafe.Pointer { - return C.scalar_lt_fhe_uint32(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs C.uint32_t) (unsafe.Pointer, error) { + return C.scalar_lt_fhe_uint32(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs C.uint64_t) unsafe.Pointer { - return C.scalar_lt_fhe_uint64(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs C.uint64_t) (unsafe.Pointer, error) { + return C.scalar_lt_fhe_uint64(lhs, rhs, sks), nil }) } func (lhs *TfheCiphertext) Min(rhs *TfheCiphertext) (*TfheCiphertext, error) { return lhs.executeBinaryCiphertextOperation(rhs, - func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { - return C.min_fhe_uint8(lhs, rhs, sks) + boolBinaryNotSupportedOp, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.min_fhe_uint4(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { - return C.min_fhe_uint16(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.min_fhe_uint8(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { - return C.min_fhe_uint32(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.min_fhe_uint16(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { - return C.min_fhe_uint64(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.min_fhe_uint32(lhs, rhs, sks), nil + }, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.min_fhe_uint64(lhs, rhs, sks), nil }) } func (lhs *TfheCiphertext) ScalarMin(rhs uint64) (*TfheCiphertext, error) { return lhs.executeBinaryScalarOperation(rhs, - func(lhs unsafe.Pointer, rhs C.uint8_t) unsafe.Pointer { - return C.scalar_min_fhe_uint8(lhs, rhs, sks) + boolBinaryScalarNotSupportedOp, + func(lhs unsafe.Pointer, rhs C.uint8_t) (unsafe.Pointer, error) { + return C.scalar_min_fhe_uint4(lhs, rhs, sks), nil + }, + func(lhs unsafe.Pointer, rhs C.uint8_t) (unsafe.Pointer, error) { + return C.scalar_min_fhe_uint8(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs C.uint16_t) unsafe.Pointer { - return C.scalar_min_fhe_uint16(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs C.uint16_t) (unsafe.Pointer, error) { + return C.scalar_min_fhe_uint16(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs C.uint32_t) unsafe.Pointer { - return C.scalar_min_fhe_uint32(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs C.uint32_t) (unsafe.Pointer, error) { + return C.scalar_min_fhe_uint32(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs C.uint64_t) unsafe.Pointer { - return C.scalar_min_fhe_uint64(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs C.uint64_t) (unsafe.Pointer, error) { + return C.scalar_min_fhe_uint64(lhs, rhs, sks), nil }) } func (lhs *TfheCiphertext) Max(rhs *TfheCiphertext) (*TfheCiphertext, error) { return lhs.executeBinaryCiphertextOperation(rhs, - func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { - return C.max_fhe_uint8(lhs, rhs, sks) + boolBinaryNotSupportedOp, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.max_fhe_uint4(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { - return C.max_fhe_uint16(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.max_fhe_uint8(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { - return C.max_fhe_uint32(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.max_fhe_uint16(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { - return C.max_fhe_uint64(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.max_fhe_uint32(lhs, rhs, sks), nil + }, + func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.max_fhe_uint64(lhs, rhs, sks), nil }) } func (lhs *TfheCiphertext) ScalarMax(rhs uint64) (*TfheCiphertext, error) { return lhs.executeBinaryScalarOperation(rhs, - func(lhs unsafe.Pointer, rhs C.uint8_t) unsafe.Pointer { - return C.scalar_max_fhe_uint8(lhs, rhs, sks) + boolBinaryScalarNotSupportedOp, + func(lhs unsafe.Pointer, rhs C.uint8_t) (unsafe.Pointer, error) { + return C.scalar_max_fhe_uint4(lhs, rhs, sks), nil + }, + func(lhs unsafe.Pointer, rhs C.uint8_t) (unsafe.Pointer, error) { + return C.scalar_max_fhe_uint8(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs C.uint16_t) unsafe.Pointer { - return C.scalar_max_fhe_uint16(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs C.uint16_t) (unsafe.Pointer, error) { + return C.scalar_max_fhe_uint16(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs C.uint32_t) unsafe.Pointer { - return C.scalar_max_fhe_uint32(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs C.uint32_t) (unsafe.Pointer, error) { + return C.scalar_max_fhe_uint32(lhs, rhs, sks), nil }, - func(lhs unsafe.Pointer, rhs C.uint64_t) unsafe.Pointer { - return C.scalar_max_fhe_uint64(lhs, rhs, sks) + func(lhs unsafe.Pointer, rhs C.uint64_t) (unsafe.Pointer, error) { + return C.scalar_max_fhe_uint64(lhs, rhs, sks), nil }) } func (lhs *TfheCiphertext) Neg() (*TfheCiphertext, error) { return lhs.executeUnaryCiphertextOperation(lhs, - func(lhs unsafe.Pointer) unsafe.Pointer { - return C.neg_fhe_uint8(lhs, sks) + boolUnaryNotSupportedOp, + func(lhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.neg_fhe_uint4(lhs, sks), nil }, - func(lhs unsafe.Pointer) unsafe.Pointer { - return C.neg_fhe_uint16(lhs, sks) + func(lhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.neg_fhe_uint8(lhs, sks), nil }, - func(lhs unsafe.Pointer) unsafe.Pointer { - return C.neg_fhe_uint32(lhs, sks) + func(lhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.neg_fhe_uint16(lhs, sks), nil }, - func(lhs unsafe.Pointer) unsafe.Pointer { - return C.neg_fhe_uint64(lhs, sks) + func(lhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.neg_fhe_uint32(lhs, sks), nil + }, + func(lhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.neg_fhe_uint64(lhs, sks), nil }) } func (lhs *TfheCiphertext) Not() (*TfheCiphertext, error) { return lhs.executeUnaryCiphertextOperation(lhs, - func(lhs unsafe.Pointer) unsafe.Pointer { - return C.not_fhe_uint8(lhs, sks) + boolUnaryNotSupportedOp, + func(lhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.not_fhe_uint4(lhs, sks), nil + }, + func(lhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.not_fhe_uint8(lhs, sks), nil }, - func(lhs unsafe.Pointer) unsafe.Pointer { - return C.not_fhe_uint16(lhs, sks) + func(lhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.not_fhe_uint16(lhs, sks), nil }, - func(lhs unsafe.Pointer) unsafe.Pointer { - return C.not_fhe_uint32(lhs, sks) + func(lhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.not_fhe_uint32(lhs, sks), nil }, - func(lhs unsafe.Pointer) unsafe.Pointer { - return C.not_fhe_uint64(lhs, sks) + func(lhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.not_fhe_uint64(lhs, sks), nil }) } func (condition *TfheCiphertext) IfThenElse(lhs *TfheCiphertext, rhs *TfheCiphertext) (*TfheCiphertext, error) { return condition.executeTernaryCiphertextOperation(lhs, rhs, + func(condition unsafe.Pointer, lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { + return C.if_then_else_fhe_uint4(condition, lhs, rhs, sks) + }, func(condition unsafe.Pointer, lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer { return C.if_then_else_fhe_uint8(condition, lhs, rhs, sks) }, @@ -1210,8 +1640,178 @@ func (ct *TfheCiphertext) CastTo(castToType FheUintType) (*TfheCiphertext, error res.fheUintType = castToType switch ct.fheUintType { + case FheBool: + switch castToType { + case FheUint4: + from_ptr := C.deserialize_fhe_bool(toDynamicBufferView(ct.serialization)) + if from_ptr == nil { + return nil, errors.New("castTo failed to deserialize FheBool ciphertext") + } + to_ptr := C.cast_bool_4(from_ptr, sks) + C.destroy_fhe_bool(from_ptr) + if to_ptr == nil { + return nil, errors.New("castTo failed to cast FheBool to FheUint8") + } + var err error + res.serialization, err = serialize(to_ptr, castToType) + C.destroy_fhe_uint4(to_ptr) + if err != nil { + return nil, err + } + case FheUint8: + from_ptr := C.deserialize_fhe_bool(toDynamicBufferView(ct.serialization)) + if from_ptr == nil { + return nil, errors.New("castTo failed to deserialize FheBool ciphertext") + } + to_ptr := C.cast_bool_8(from_ptr, sks) + C.destroy_fhe_bool(from_ptr) + if to_ptr == nil { + return nil, errors.New("castTo failed to cast FheBool to FheUint8") + } + var err error + res.serialization, err = serialize(to_ptr, castToType) + C.destroy_fhe_uint8(to_ptr) + if err != nil { + return nil, err + } + case FheUint16: + from_ptr := C.deserialize_fhe_bool(toDynamicBufferView(ct.serialization)) + if from_ptr == nil { + return nil, errors.New("castTo failed to deserialize FheBool ciphertext") + } + to_ptr := C.cast_bool_16(from_ptr, sks) + C.destroy_fhe_bool(from_ptr) + if to_ptr == nil { + return nil, errors.New("castTo failed to cast FheBool to FheUint16") + } + var err error + res.serialization, err = serialize(to_ptr, castToType) + C.destroy_fhe_uint16(to_ptr) + if err != nil { + return nil, err + } + case FheUint32: + from_ptr := C.deserialize_fhe_bool(toDynamicBufferView(ct.serialization)) + if from_ptr == nil { + return nil, errors.New("castTo failed to deserialize FheBool ciphertext") + } + to_ptr := C.cast_bool_32(from_ptr, sks) + C.destroy_fhe_bool(from_ptr) + if to_ptr == nil { + return nil, errors.New("castTo failed to cast FheBool to FheUint32") + } + var err error + res.serialization, err = serialize(to_ptr, castToType) + C.destroy_fhe_uint32(to_ptr) + if err != nil { + return nil, err + } + case FheUint64: + from_ptr := C.deserialize_fhe_bool(toDynamicBufferView(ct.serialization)) + if from_ptr == nil { + return nil, errors.New("castTo failed to deserialize FheBool ciphertext") + } + to_ptr := C.cast_bool_64(from_ptr, sks) + C.destroy_fhe_bool(from_ptr) + if to_ptr == nil { + return nil, errors.New("castTo failed to cast FheBool to FheUint64") + } + var err error + res.serialization, err = serialize(to_ptr, castToType) + C.destroy_fhe_uint64(to_ptr) + if err != nil { + return nil, err + } + default: + panic("castTo: unexpected type to cast to") + } + case FheUint4: + switch castToType { + case FheUint8: + from_ptr := C.deserialize_fhe_uint4(toDynamicBufferView(ct.serialization)) + if from_ptr == nil { + return nil, errors.New("castTo failed to deserialize FheUint4 ciphertext") + } + to_ptr := C.cast_4_8(from_ptr, sks) + C.destroy_fhe_uint4(from_ptr) + if to_ptr == nil { + return nil, errors.New("castTo failed to cast FheUint4 to FheUint16") + } + var err error + res.serialization, err = serialize(to_ptr, castToType) + C.destroy_fhe_uint8(to_ptr) + if err != nil { + return nil, err + } + case FheUint16: + from_ptr := C.deserialize_fhe_uint4(toDynamicBufferView(ct.serialization)) + if from_ptr == nil { + return nil, errors.New("castTo failed to deserialize FheUint4 ciphertext") + } + to_ptr := C.cast_4_16(from_ptr, sks) + C.destroy_fhe_uint4(from_ptr) + if to_ptr == nil { + return nil, errors.New("castTo failed to cast FheUint4 to FheUint16") + } + var err error + res.serialization, err = serialize(to_ptr, castToType) + C.destroy_fhe_uint16(to_ptr) + if err != nil { + return nil, err + } + case FheUint32: + from_ptr := C.deserialize_fhe_uint4(toDynamicBufferView(ct.serialization)) + if from_ptr == nil { + return nil, errors.New("castTo failed to deserialize FheUint4 ciphertext") + } + to_ptr := C.cast_4_32(from_ptr, sks) + C.destroy_fhe_uint4(from_ptr) + if to_ptr == nil { + return nil, errors.New("castTo failed to cast FheUint4 to FheUint32") + } + var err error + res.serialization, err = serialize(to_ptr, castToType) + C.destroy_fhe_uint32(to_ptr) + if err != nil { + return nil, err + } + case FheUint64: + from_ptr := C.deserialize_fhe_uint4(toDynamicBufferView(ct.serialization)) + if from_ptr == nil { + return nil, errors.New("castTo failed to deserialize FheUint4 ciphertext") + } + to_ptr := C.cast_4_64(from_ptr, sks) + C.destroy_fhe_uint4(from_ptr) + if to_ptr == nil { + return nil, errors.New("castTo failed to cast FheUint4 to FheUint64") + } + var err error + res.serialization, err = serialize(to_ptr, castToType) + C.destroy_fhe_uint64(to_ptr) + if err != nil { + return nil, err + } + default: + panic("castTo: unexpected type to cast to") + } case FheUint8: switch castToType { + case FheUint4: + from_ptr := C.deserialize_fhe_uint8(toDynamicBufferView(ct.serialization)) + if from_ptr == nil { + return nil, errors.New("castTo failed to deserialize FheUint8 ciphertext") + } + to_ptr := C.cast_8_4(from_ptr, sks) + C.destroy_fhe_uint8(from_ptr) + if to_ptr == nil { + return nil, errors.New("castTo failed to cast FheUint8 to FheUint4") + } + var err error + res.serialization, err = serialize(to_ptr, castToType) + C.destroy_fhe_uint4(to_ptr) + if err != nil { + return nil, err + } case FheUint16: from_ptr := C.deserialize_fhe_uint8(toDynamicBufferView(ct.serialization)) if from_ptr == nil { @@ -1265,6 +1865,22 @@ func (ct *TfheCiphertext) CastTo(castToType FheUintType) (*TfheCiphertext, error } case FheUint16: switch castToType { + case FheUint4: + from_ptr := C.deserialize_fhe_uint16(toDynamicBufferView(ct.serialization)) + if from_ptr == nil { + return nil, errors.New("castTo failed to deserialize FheUint16 ciphertext") + } + to_ptr := C.cast_16_4(from_ptr, sks) + C.destroy_fhe_uint16(from_ptr) + if to_ptr == nil { + return nil, errors.New("castTo failed to cast FheUint16 to FheUint4") + } + var err error + res.serialization, err = serialize(to_ptr, castToType) + C.destroy_fhe_uint4(to_ptr) + if err != nil { + return nil, err + } case FheUint8: from_ptr := C.deserialize_fhe_uint16(toDynamicBufferView(ct.serialization)) if from_ptr == nil { @@ -1318,6 +1934,22 @@ func (ct *TfheCiphertext) CastTo(castToType FheUintType) (*TfheCiphertext, error } case FheUint32: switch castToType { + case FheUint4: + from_ptr := C.deserialize_fhe_uint32(toDynamicBufferView(ct.serialization)) + if from_ptr == nil { + return nil, errors.New("castTo failed to deserialize FheUint32 ciphertext") + } + to_ptr := C.cast_32_4(from_ptr, sks) + C.destroy_fhe_uint32(from_ptr) + if to_ptr == nil { + return nil, errors.New("castTo failed to cast FheUint32 to FheUint4") + } + var err error + res.serialization, err = serialize(to_ptr, castToType) + C.destroy_fhe_uint4(to_ptr) + if err != nil { + return nil, err + } case FheUint8: from_ptr := C.deserialize_fhe_uint32(toDynamicBufferView(ct.serialization)) if from_ptr == nil { @@ -1371,6 +2003,22 @@ func (ct *TfheCiphertext) CastTo(castToType FheUintType) (*TfheCiphertext, error } case FheUint64: switch castToType { + case FheUint4: + from_ptr := C.deserialize_fhe_uint64(toDynamicBufferView(ct.serialization)) + if from_ptr == nil { + return nil, errors.New("castTo failed to deserialize FheUint64 ciphertext") + } + to_ptr := C.cast_64_4(from_ptr, sks) + C.destroy_fhe_uint64(from_ptr) + if to_ptr == nil { + return nil, errors.New("castTo failed to cast FheUint64 to FheUint4") + } + var err error + res.serialization, err = serialize(to_ptr, castToType) + C.destroy_fhe_uint4(to_ptr) + if err != nil { + return nil, err + } case FheUint8: from_ptr := C.deserialize_fhe_uint64(toDynamicBufferView(ct.serialization)) if from_ptr == nil { @@ -1434,6 +2082,28 @@ func (ct *TfheCiphertext) Decrypt() (big.Int, error) { var value uint64 var ret C.int switch ct.fheUintType { + case FheBool: + ptr := C.deserialize_fhe_bool(toDynamicBufferView(ct.serialization)) + if ptr == nil { + return *new(big.Int).SetUint64(0), errors.New("failed to deserialize FheBool") + } + var result C.bool + ret = C.decrypt_fhe_bool(cks, ptr, &result) + C.destroy_fhe_bool(ptr) + if (result) { + value = 1 + } else { + value = 0 + } + case FheUint4: + ptr := C.deserialize_fhe_uint4(toDynamicBufferView(ct.serialization)) + if ptr == nil { + return *new(big.Int).SetUint64(0), errors.New("failed to deserialize FheUint4") + } + var result C.uint8_t + ret = C.decrypt_fhe_uint4(cks, ptr, &result) + C.destroy_fhe_uint4(ptr) + value = uint64(result) case FheUint8: ptr := C.deserialize_fhe_uint8(toDynamicBufferView(ct.serialization)) if ptr == nil { diff --git a/fhevm/tfhe_key_management.go b/fhevm/tfhe_key_management.go index d4d4d3c..0f2045e 100644 --- a/fhevm/tfhe_key_management.go +++ b/fhevm/tfhe_key_management.go @@ -66,11 +66,15 @@ func initCiphertextSizes() { expandedFheCiphertextSize = make(map[FheUintType]uint) compactFheCiphertextSize = make(map[FheUintType]uint) + expandedFheCiphertextSize[FheBool] = uint(len(new(TfheCiphertext).TrivialEncrypt(*big.NewInt(0), FheBool).Serialize())) + expandedFheCiphertextSize[FheUint4] = uint(len(new(TfheCiphertext).TrivialEncrypt(*big.NewInt(0), FheUint4).Serialize())) expandedFheCiphertextSize[FheUint8] = uint(len(new(TfheCiphertext).TrivialEncrypt(*big.NewInt(0), FheUint8).Serialize())) expandedFheCiphertextSize[FheUint16] = uint(len(new(TfheCiphertext).TrivialEncrypt(*big.NewInt(0), FheUint16).Serialize())) expandedFheCiphertextSize[FheUint32] = uint(len(new(TfheCiphertext).TrivialEncrypt(*big.NewInt(0), FheUint32).Serialize())) expandedFheCiphertextSize[FheUint64] = uint(len(new(TfheCiphertext).TrivialEncrypt(*big.NewInt(0), FheUint64).Serialize())) + compactFheCiphertextSize[FheBool] = uint(len(encryptAndSerializeCompact(0, FheBool))) + compactFheCiphertextSize[FheUint4] = uint(len(encryptAndSerializeCompact(0, FheUint4))) compactFheCiphertextSize[FheUint8] = uint(len(encryptAndSerializeCompact(0, FheUint8))) compactFheCiphertextSize[FheUint16] = uint(len(encryptAndSerializeCompact(0, FheUint16))) compactFheCiphertextSize[FheUint32] = uint(len(encryptAndSerializeCompact(0, FheUint32))) diff --git a/fhevm/tfhe_test.go b/fhevm/tfhe_test.go index 6caa7b0..bded90a 100644 --- a/fhevm/tfhe_test.go +++ b/fhevm/tfhe_test.go @@ -25,6 +25,10 @@ func TestMain(m *testing.M) { func TfheEncryptDecrypt(t *testing.T, fheUintType FheUintType) { var val big.Int switch fheUintType { + case FheBool: + val.SetUint64(1) + case FheUint4: + val.SetUint64(2) case FheUint8: val.SetUint64(2) case FheUint16: @@ -45,6 +49,10 @@ func TfheEncryptDecrypt(t *testing.T, fheUintType FheUintType) { func TfheTrivialEncryptDecrypt(t *testing.T, fheUintType FheUintType) { var val big.Int switch fheUintType { + case FheBool: + val.SetUint64(1) + case FheUint4: + val.SetUint64(2) case FheUint8: val.SetUint64(2) case FheUint16: @@ -65,6 +73,10 @@ func TfheTrivialEncryptDecrypt(t *testing.T, fheUintType FheUintType) { func TfheSerializeDeserialize(t *testing.T, fheUintType FheUintType) { var val big.Int switch fheUintType { + case FheBool: + val = *big.NewInt(1) + case FheUint4: + val = *big.NewInt(2) case FheUint8: val = *big.NewInt(2) case FheUint16: @@ -91,6 +103,10 @@ func TfheSerializeDeserialize(t *testing.T, fheUintType FheUintType) { func TfheSerializeDeserializeCompact(t *testing.T, fheUintType FheUintType) { var val uint64 switch fheUintType { + case FheBool: + val = 1 + case FheUint4: + val = 2 case FheUint8: val = 2 case FheUint16: @@ -129,6 +145,10 @@ func TfheSerializeDeserializeCompact(t *testing.T, fheUintType FheUintType) { func TfheTrivialSerializeDeserialize(t *testing.T, fheUintType FheUintType) { var val big.Int switch fheUintType { + case FheBool: + val = *big.NewInt(1) + case FheUint4: + val = *big.NewInt(2) case FheUint8: val = *big.NewInt(2) case FheUint16: @@ -165,6 +185,10 @@ func TfheDeserializeFailure(t *testing.T, fheUintType FheUintType) { func TfheDeserializeCompact(t *testing.T, fheUintType FheUintType) { var val uint64 switch fheUintType { + case FheBool: + val = 1 + case FheUint4: + val = 2 case FheUint8: val = 2 case FheUint16: @@ -197,6 +221,9 @@ func TfheDeserializeCompactFailure(t *testing.T, fheUintType FheUintType) { func TfheAdd(t *testing.T, fheUintType FheUintType) { var a, b big.Int switch fheUintType { + case FheUint4: + a.SetUint64(2) + b.SetUint64(1) case FheUint8: a.SetUint64(2) b.SetUint64(1) @@ -225,6 +252,9 @@ func TfheAdd(t *testing.T, fheUintType FheUintType) { func TfheScalarAdd(t *testing.T, fheUintType FheUintType) { var a, b big.Int switch fheUintType { + case FheUint4: + a.SetUint64(2) + b.SetUint64(1) case FheUint8: a.SetUint64(2) b.SetUint64(1) @@ -251,6 +281,9 @@ func TfheScalarAdd(t *testing.T, fheUintType FheUintType) { func TfheSub(t *testing.T, fheUintType FheUintType) { var a, b big.Int switch fheUintType { + case FheUint4: + a.SetUint64(2) + b.SetUint64(1) case FheUint8: a.SetUint64(2) b.SetUint64(1) @@ -279,6 +312,9 @@ func TfheSub(t *testing.T, fheUintType FheUintType) { func TfheScalarSub(t *testing.T, fheUintType FheUintType) { var a, b big.Int switch fheUintType { + case FheUint4: + a.SetUint64(2) + b.SetUint64(1) case FheUint8: a.SetUint64(2) b.SetUint64(1) @@ -305,6 +341,9 @@ func TfheScalarSub(t *testing.T, fheUintType FheUintType) { func TfheMul(t *testing.T, fheUintType FheUintType) { var a, b big.Int switch fheUintType { + case FheUint4: + a.SetUint64(2) + b.SetUint64(1) case FheUint8: a.SetUint64(2) b.SetUint64(1) @@ -333,6 +372,9 @@ func TfheMul(t *testing.T, fheUintType FheUintType) { func TfheScalarMul(t *testing.T, fheUintType FheUintType) { var a, b big.Int switch fheUintType { + case FheUint4: + a.SetUint64(2) + b.SetUint64(1) case FheUint8: a.SetUint64(2) b.SetUint64(1) @@ -359,6 +401,9 @@ func TfheScalarMul(t *testing.T, fheUintType FheUintType) { func TfheScalarDiv(t *testing.T, fheUintType FheUintType) { var a, b big.Int switch fheUintType { + case FheUint4: + a.SetUint64(4) + b.SetUint64(2) case FheUint8: a.SetUint64(4) b.SetUint64(2) @@ -385,6 +430,9 @@ func TfheScalarDiv(t *testing.T, fheUintType FheUintType) { func TfheScalarRem(t *testing.T, fheUintType FheUintType) { var a, b big.Int switch fheUintType { + case FheUint4: + a.SetUint64(4) + b.SetUint64(2) case FheUint8: a.SetUint64(4) b.SetUint64(2) @@ -411,6 +459,12 @@ func TfheScalarRem(t *testing.T, fheUintType FheUintType) { func TfheBitAnd(t *testing.T, fheUintType FheUintType) { var a, b big.Int switch fheUintType { + case FheBool: + a.SetUint64(1) + b.SetUint64(1) + case FheUint4: + a.SetUint64(2) + b.SetUint64(1) case FheUint8: a.SetUint64(2) b.SetUint64(1) @@ -439,6 +493,9 @@ func TfheBitAnd(t *testing.T, fheUintType FheUintType) { func TfheBitOr(t *testing.T, fheUintType FheUintType) { var a, b big.Int switch fheUintType { + case FheUint4: + a.SetUint64(2) + b.SetUint64(1) case FheUint8: a.SetUint64(2) b.SetUint64(1) @@ -467,6 +524,9 @@ func TfheBitOr(t *testing.T, fheUintType FheUintType) { func TfheBitXor(t *testing.T, fheUintType FheUintType) { var a, b big.Int switch fheUintType { + case FheUint4: + a.SetUint64(2) + b.SetUint64(1) case FheUint8: a.SetUint64(2) b.SetUint64(1) @@ -495,6 +555,9 @@ func TfheBitXor(t *testing.T, fheUintType FheUintType) { func TfheShl(t *testing.T, fheUintType FheUintType) { var a, b big.Int switch fheUintType { + case FheUint4: + a.SetUint64(2) + b.SetUint64(1) case FheUint8: a.SetUint64(2) b.SetUint64(1) @@ -523,6 +586,9 @@ func TfheShl(t *testing.T, fheUintType FheUintType) { func TfheScalarShl(t *testing.T, fheUintType FheUintType) { var a, b big.Int switch fheUintType { + case FheUint4: + a.SetUint64(2) + b.SetUint64(1) case FheUint8: a.SetUint64(2) b.SetUint64(1) @@ -549,6 +615,9 @@ func TfheScalarShl(t *testing.T, fheUintType FheUintType) { func TfheShr(t *testing.T, fheUintType FheUintType) { var a, b big.Int switch fheUintType { + case FheUint4: + a.SetUint64(2) + b.SetUint64(1) case FheUint8: a.SetUint64(2) b.SetUint64(1) @@ -577,6 +646,9 @@ func TfheShr(t *testing.T, fheUintType FheUintType) { func TfheScalarShr(t *testing.T, fheUintType FheUintType) { var a, b big.Int switch fheUintType { + case FheUint4: + a.SetUint64(2) + b.SetUint64(1) case FheUint8: a.SetUint64(2) b.SetUint64(1) @@ -603,6 +675,9 @@ func TfheScalarShr(t *testing.T, fheUintType FheUintType) { func TfheEq(t *testing.T, fheUintType FheUintType) { var a, b big.Int switch fheUintType { + case FheUint4: + a.SetUint64(2) + b.SetUint64(2) case FheUint8: a.SetUint64(2) b.SetUint64(2) @@ -637,6 +712,9 @@ func TfheEq(t *testing.T, fheUintType FheUintType) { func TfheScalarEq(t *testing.T, fheUintType FheUintType) { var a, b big.Int switch fheUintType { + case FheUint4: + a.SetUint64(2) + b.SetUint64(1) case FheUint8: a.SetUint64(2) b.SetUint64(1) @@ -669,6 +747,9 @@ func TfheScalarEq(t *testing.T, fheUintType FheUintType) { func TfheNe(t *testing.T, fheUintType FheUintType) { var a, b big.Int switch fheUintType { + case FheUint4: + a.SetUint64(2) + b.SetUint64(2) case FheUint8: a.SetUint64(2) b.SetUint64(2) @@ -703,6 +784,9 @@ func TfheNe(t *testing.T, fheUintType FheUintType) { func TfheScalarNe(t *testing.T, fheUintType FheUintType) { var a, b big.Int switch fheUintType { + case FheUint4: + a.SetUint64(2) + b.SetUint64(1) case FheUint8: a.SetUint64(2) b.SetUint64(1) @@ -735,6 +819,9 @@ func TfheScalarNe(t *testing.T, fheUintType FheUintType) { func TfheGe(t *testing.T, fheUintType FheUintType) { var a, b big.Int switch fheUintType { + case FheUint4: + a.SetUint64(2) + b.SetUint64(1) case FheUint8: a.SetUint64(2) b.SetUint64(1) @@ -767,6 +854,9 @@ func TfheGe(t *testing.T, fheUintType FheUintType) { func TfheScalarGe(t *testing.T, fheUintType FheUintType) { var a, b big.Int switch fheUintType { + case FheUint4: + a.SetUint64(2) + b.SetUint64(1) case FheUint8: a.SetUint64(2) b.SetUint64(1) @@ -792,6 +882,9 @@ func TfheScalarGe(t *testing.T, fheUintType FheUintType) { func TfheGt(t *testing.T, fheUintType FheUintType) { var a, b big.Int switch fheUintType { + case FheUint4: + a.SetUint64(2) + b.SetUint64(1) case FheUint8: a.SetUint64(2) b.SetUint64(1) @@ -824,6 +917,9 @@ func TfheGt(t *testing.T, fheUintType FheUintType) { func TfheScalarGt(t *testing.T, fheUintType FheUintType) { var a, b big.Int switch fheUintType { + case FheUint4: + a.SetUint64(2) + b.SetUint64(1) case FheUint8: a.SetUint64(2) b.SetUint64(1) @@ -849,6 +945,9 @@ func TfheScalarGt(t *testing.T, fheUintType FheUintType) { func TfheLe(t *testing.T, fheUintType FheUintType) { var a, b big.Int switch fheUintType { + case FheUint4: + a.SetUint64(2) + b.SetUint64(1) case FheUint8: a.SetUint64(2) b.SetUint64(1) @@ -881,6 +980,9 @@ func TfheLe(t *testing.T, fheUintType FheUintType) { func TfheScalarLe(t *testing.T, fheUintType FheUintType) { var a, b big.Int switch fheUintType { + case FheUint4: + a.SetUint64(2) + b.SetUint64(1) case FheUint8: a.SetUint64(2) b.SetUint64(1) @@ -906,6 +1008,9 @@ func TfheScalarLe(t *testing.T, fheUintType FheUintType) { func TfheLt(t *testing.T, fheUintType FheUintType) { var a, b big.Int switch fheUintType { + case FheUint4: + a.SetUint64(2) + b.SetUint64(1) case FheUint8: a.SetUint64(2) b.SetUint64(1) @@ -938,6 +1043,9 @@ func TfheLt(t *testing.T, fheUintType FheUintType) { func TfheScalarLt(t *testing.T, fheUintType FheUintType) { var a, b big.Int switch fheUintType { + case FheUint4: + a.SetUint64(2) + b.SetUint64(1) case FheUint8: a.SetUint64(2) b.SetUint64(1) @@ -963,6 +1071,9 @@ func TfheScalarLt(t *testing.T, fheUintType FheUintType) { func TfheMin(t *testing.T, fheUintType FheUintType) { var a, b big.Int switch fheUintType { + case FheUint4: + a.SetUint64(2) + b.SetUint64(1) case FheUint8: a.SetUint64(2) b.SetUint64(1) @@ -995,6 +1106,9 @@ func TfheMin(t *testing.T, fheUintType FheUintType) { func TfheScalarMin(t *testing.T, fheUintType FheUintType) { var a, b big.Int switch fheUintType { + case FheUint4: + a.SetUint64(2) + b.SetUint64(1) case FheUint8: a.SetUint64(2) b.SetUint64(1) @@ -1020,6 +1134,9 @@ func TfheScalarMin(t *testing.T, fheUintType FheUintType) { func TfheMax(t *testing.T, fheUintType FheUintType) { var a, b big.Int switch fheUintType { + case FheUint4: + a.SetUint64(4) + b.SetUint64(2) case FheUint8: a.SetUint64(2) b.SetUint64(1) @@ -1052,6 +1169,9 @@ func TfheMax(t *testing.T, fheUintType FheUintType) { func TfheScalarMax(t *testing.T, fheUintType FheUintType) { var a, b big.Int switch fheUintType { + case FheUint4: + a.SetUint64(2) + b.SetUint64(1) case FheUint8: a.SetUint64(2) b.SetUint64(1) @@ -1079,6 +1199,9 @@ func TfheNeg(t *testing.T, fheUintType FheUintType) { var expected uint64 switch fheUintType { + case FheUint4: + a.SetUint64(2) + expected = uint64(uint8(16 - a.Uint64())) case FheUint8: a.SetUint64(2) expected = uint64(-uint8(a.Uint64())) @@ -1105,6 +1228,9 @@ func TfheNot(t *testing.T, fheUintType FheUintType) { var a big.Int var expected uint64 switch fheUintType { + case FheUint4: + a.SetUint64(2) + expected = uint64(^uint8(a.Uint64())) case FheUint8: a.SetUint64(2) expected = uint64(^uint8(a.Uint64())) @@ -1133,6 +1259,9 @@ func TfheIfThenElse(t *testing.T, fheUintType FheUintType) { condition.SetUint64(1) condition2.SetUint64(0) switch fheUintType { + case FheUint4: + a.SetUint64(2) + b.SetUint64(1) case FheUint8: a.SetUint64(2) b.SetUint64(1) @@ -1169,6 +1298,8 @@ func TfheIfThenElse(t *testing.T, fheUintType FheUintType) { func TfheCast(t *testing.T, fheUintTypeFrom FheUintType, fheUintTypeTo FheUintType) { var a big.Int switch fheUintTypeFrom { + case FheUint4: + a.SetUint64(2) case FheUint8: a.SetUint64(2) case FheUint16: @@ -1181,6 +1312,8 @@ func TfheCast(t *testing.T, fheUintTypeFrom FheUintType, fheUintTypeTo FheUintTy var modulus uint64 switch fheUintTypeTo { + case FheUint4: + modulus = uint64(math.Pow(2, 4)) case FheUint8: modulus = uint64(math.Pow(2, 8)) case FheUint16: @@ -1208,6 +1341,14 @@ func TfheCast(t *testing.T, fheUintTypeFrom FheUintType, fheUintTypeTo FheUintTy } } +func TestTfheEncryptDecryptBool(t *testing.T) { + TfheEncryptDecrypt(t, FheBool) +} + +func TestTfheEncryptDecrypt4(t *testing.T) { + TfheEncryptDecrypt(t, FheUint4) +} + func TestTfheEncryptDecrypt8(t *testing.T) { TfheEncryptDecrypt(t, FheUint8) } @@ -1224,6 +1365,14 @@ func TestTfheEncryptDecrypt64(t *testing.T) { TfheEncryptDecrypt(t, FheUint64) } +func TestTfheTrivialEncryptDecryptBool(t *testing.T) { + TfheTrivialEncryptDecrypt(t, FheBool) +} + +func TestTfheTrivialEncryptDecrypt4(t *testing.T) { + TfheTrivialEncryptDecrypt(t, FheUint4) +} + func TestTfheTrivialEncryptDecrypt8(t *testing.T) { TfheTrivialEncryptDecrypt(t, FheUint8) } @@ -1240,6 +1389,14 @@ func TestTfheTrivialEncryptDecrypt64(t *testing.T) { TfheTrivialEncryptDecrypt(t, FheUint64) } +func TestTfheSerializeDeserializeBool(t *testing.T) { + TfheSerializeDeserialize(t, FheBool) +} + +func TestTfheSerializeDeserialize4(t *testing.T) { + TfheSerializeDeserialize(t, FheUint4) +} + func TestTfheSerializeDeserialize8(t *testing.T) { TfheSerializeDeserialize(t, FheUint8) } @@ -1256,8 +1413,12 @@ func TestTfheSerializeDeserialize64(t *testing.T) { TfheSerializeDeserialize(t, FheUint64) } -func TestTfheSerializeDeserializeCompact8(t *testing.T) { - TfheSerializeDeserializeCompact(t, FheUint8) +func TestTfheSerializeDeserializeCompactBool(t *testing.T) { + TfheSerializeDeserializeCompact(t, FheBool) +} + +func TestTfheSerializeDeserializeCompact4(t *testing.T) { + TfheSerializeDeserializeCompact(t, FheUint4) } func TestTfheSerializeDeserializeCompact16(t *testing.T) { @@ -1272,6 +1433,14 @@ func TestTfheSerializeDeserializeCompact64(t *testing.T) { TfheSerializeDeserializeCompact(t, FheUint64) } +func TestTfheTrivialSerializeDeserializeBool(t *testing.T) { + TfheTrivialSerializeDeserialize(t, FheBool) +} + +func TestTfheTrivialSerializeDeserialize4(t *testing.T) { + TfheTrivialSerializeDeserialize(t, FheUint4) +} + func TestTfheTrivialSerializeDeserialize8(t *testing.T) { TfheTrivialSerializeDeserialize(t, FheUint8) } @@ -1288,6 +1457,14 @@ func TestTfheTrivialSerializeDeserialize64(t *testing.T) { TfheTrivialSerializeDeserialize(t, FheUint64) } +func TestTfheDeserializeFailureBool(t *testing.T) { + TfheDeserializeFailure(t, FheBool) +} + +func TestTfheDeserializeFailure4(t *testing.T) { + TfheDeserializeFailure(t, FheUint4) +} + func TestTfheDeserializeFailure8(t *testing.T) { TfheDeserializeFailure(t, FheUint8) } @@ -1304,6 +1481,14 @@ func TestTfheDeserializeFailure64(t *testing.T) { TfheDeserializeFailure(t, FheUint64) } +func TestTfheDeserializeCompactBool(t *testing.T) { + TfheDeserializeCompact(t, FheBool) +} + +func TestTfheDeserializeCompact4(t *testing.T) { + TfheDeserializeCompact(t, FheUint4) +} + func TestTfheDeserializeCompact8(t *testing.T) { TfheDeserializeCompact(t, FheUint8) } @@ -1320,6 +1505,14 @@ func TestTfheDeserializeCompact64(t *testing.T) { TfheDeserializeCompact(t, FheUint64) } +func TestTfheDeserializeCompactFailureBool(t *testing.T) { + TfheDeserializeCompactFailure(t, FheBool) +} + +func TestTfheDeserializeCompactFailure4(t *testing.T) { + TfheDeserializeCompactFailure(t, FheUint4) +} + func TestTfheDeserializeCompactFailure8(t *testing.T) { TfheDeserializeCompactFailure(t, FheUint8) } @@ -1336,6 +1529,10 @@ func TestTfheDeserializeCompatcFailure64(t *testing.T) { TfheDeserializeCompactFailure(t, FheUint64) } +func TestTfheAdd4(t *testing.T) { + TfheAdd(t, FheUint4) +} + func TestTfheAdd8(t *testing.T) { TfheAdd(t, FheUint8) } @@ -1352,6 +1549,10 @@ func TestTfheAdd64(t *testing.T) { TfheAdd(t, FheUint64) } +func TestTfheScalarAdd4(t *testing.T) { + TfheScalarAdd(t, FheUint4) +} + func TestTfheScalarAdd8(t *testing.T) { TfheScalarAdd(t, FheUint8) } @@ -1368,6 +1569,10 @@ func TestTfheScalarAdd64(t *testing.T) { TfheScalarAdd(t, FheUint32) } +func TestTfheSub4(t *testing.T) { + TfheSub(t, FheUint4) +} + func TestTfheSub8(t *testing.T) { TfheSub(t, FheUint8) } @@ -1384,6 +1589,10 @@ func TestTfheSub64(t *testing.T) { TfheSub(t, FheUint64) } +func TestTfheScalarSub4(t *testing.T) { + TfheScalarSub(t, FheUint4) +} + func TestTfheScalarSub8(t *testing.T) { TfheScalarSub(t, FheUint8) } @@ -1400,6 +1609,10 @@ func TestTfheScalarSub64(t *testing.T) { TfheScalarSub(t, FheUint64) } +func TestTfheMul4(t *testing.T) { + TfheMul(t, FheUint4) +} + func TestTfheMul8(t *testing.T) { TfheMul(t, FheUint8) } @@ -1416,6 +1629,10 @@ func TestTfheMul64(t *testing.T) { TfheMul(t, FheUint64) } +func TestTfheScalarMul4(t *testing.T) { + TfheScalarMul(t, FheUint4) +} + func TestTfheScalarMul8(t *testing.T) { TfheScalarMul(t, FheUint8) } @@ -1432,6 +1649,10 @@ func TestTfheScalarMul64(t *testing.T) { TfheScalarMul(t, FheUint64) } +func TestTfheScalarDiv4(t *testing.T) { + TfheScalarDiv(t, FheUint4) +} + func TestTfheScalarDiv8(t *testing.T) { TfheScalarDiv(t, FheUint8) } @@ -1448,6 +1669,10 @@ func TestTfheScalarDiv64(t *testing.T) { TfheScalarDiv(t, FheUint64) } +func TestTfheScalarRem4(t *testing.T) { + TfheScalarRem(t, FheUint4) +} + func TestTfheScalarRem8(t *testing.T) { TfheScalarRem(t, FheUint8) } @@ -1464,6 +1689,10 @@ func TestTfheScalarRem64(t *testing.T) { TfheScalarRem(t, FheUint64) } +func TestTfheBitAnd4(t *testing.T) { + TfheBitAnd(t, FheUint4) +} + func TestTfheBitAnd8(t *testing.T) { TfheBitAnd(t, FheUint8) } @@ -1480,6 +1709,10 @@ func TestTfheBitAnd64(t *testing.T) { TfheBitAnd(t, FheUint64) } +func TestTfheBitOr4(t *testing.T) { + TfheBitOr(t, FheUint4) +} + func TestTfheBitOr8(t *testing.T) { TfheBitOr(t, FheUint8) } @@ -1496,6 +1729,10 @@ func TestTfheBitOr64(t *testing.T) { TfheBitOr(t, FheUint64) } +func TestTfheBitXor4(t *testing.T) { + TfheBitXor(t, FheUint4) +} + func TestTfheBitXor8(t *testing.T) { TfheBitXor(t, FheUint8) } @@ -1512,6 +1749,10 @@ func TestTfheBitXor64(t *testing.T) { TfheBitXor(t, FheUint64) } +func TestTfheShl4(t *testing.T) { + TfheShl(t, FheUint4) +} + func TestTfheShl8(t *testing.T) { TfheShl(t, FheUint8) } @@ -1528,6 +1769,10 @@ func TestTfheShl64(t *testing.T) { TfheShl(t, FheUint64) } +func TestTfheScalarShl4(t *testing.T) { + TfheScalarShl(t, FheUint4) +} + func TestTfheScalarShl8(t *testing.T) { TfheScalarShl(t, FheUint8) } @@ -1544,6 +1789,10 @@ func TestTfheScalarShl64(t *testing.T) { TfheScalarShl(t, FheUint64) } +func TestTfheShr4(t *testing.T) { + TfheShr(t, FheUint4) +} + func TestTfheShr8(t *testing.T) { TfheShr(t, FheUint8) } @@ -1576,6 +1825,10 @@ func TestTfheScalarShr64(t *testing.T) { TfheScalarShr(t, FheUint64) } +func TestTfheEq4(t *testing.T) { + TfheEq(t, FheUint4) +} + func TestTfheEq8(t *testing.T) { TfheEq(t, FheUint8) } @@ -1592,6 +1845,10 @@ func TestTfheEq64(t *testing.T) { TfheEq(t, FheUint64) } +func TestTfheScalarEq4(t *testing.T) { + TfheScalarEq(t, FheUint4) +} + func TestTfheScalarEq8(t *testing.T) { TfheScalarEq(t, FheUint8) } @@ -1608,6 +1865,10 @@ func TestTfheScalarEq64(t *testing.T) { TfheScalarEq(t, FheUint64) } +func TestTfheNe4(t *testing.T) { + TfheNe(t, FheUint8) +} + func TestTfheNe8(t *testing.T) { TfheNe(t, FheUint8) } @@ -1624,6 +1885,10 @@ func TestTfheNe64(t *testing.T) { TfheNe(t, FheUint64) } +func TestTfheScalarNe4(t *testing.T) { + TfheScalarNe(t, FheUint4) +} + func TestTfheScalarNe8(t *testing.T) { TfheScalarNe(t, FheUint8) } @@ -1640,6 +1905,10 @@ func TestTfheScalarNe64(t *testing.T) { TfheScalarNe(t, FheUint64) } +func TestTfheGe4(t *testing.T) { + TfheGe(t, FheUint4) +} + func TestTfheGe8(t *testing.T) { TfheGe(t, FheUint8) } @@ -1656,6 +1925,10 @@ func TestTfheGe64(t *testing.T) { TfheGe(t, FheUint64) } +func TestTfheScalarGe4(t *testing.T) { + TfheScalarGe(t, FheUint4) +} + func TestTfheScalarGe8(t *testing.T) { TfheScalarGe(t, FheUint8) } @@ -1672,6 +1945,10 @@ func TestTfheScalarGe64(t *testing.T) { TfheScalarGe(t, FheUint64) } +func TestTfheGt4(t *testing.T) { + TfheGt(t, FheUint4) +} + func TestTfheGt8(t *testing.T) { TfheGt(t, FheUint8) } @@ -1688,6 +1965,10 @@ func TestTfheGt64(t *testing.T) { TfheGt(t, FheUint64) } +func TestTfheScalarGt4(t *testing.T) { + TfheScalarGt(t, FheUint4) +} + func TestTfheScalarGt8(t *testing.T) { TfheScalarGt(t, FheUint8) } @@ -1704,6 +1985,10 @@ func TestTfheScalarGt64(t *testing.T) { TfheScalarGt(t, FheUint64) } +func TestTfheLe4(t *testing.T) { + TfheLe(t, FheUint4) +} + func TestTfheLe8(t *testing.T) { TfheLe(t, FheUint8) } @@ -1720,6 +2005,10 @@ func TestTfheLe64(t *testing.T) { TfheLe(t, FheUint64) } +func TestTfheScalarLe4(t *testing.T) { + TfheScalarLe(t, FheUint4) +} + func TestTfheScalarLe8(t *testing.T) { TfheScalarLe(t, FheUint8) } @@ -1736,6 +2025,10 @@ func TestTfheScalarLe64(t *testing.T) { TfheScalarLe(t, FheUint64) } +func TestTfheLt4(t *testing.T) { + TfheLt(t, FheUint4) +} + func TestTfheLt8(t *testing.T) { TfheLt(t, FheUint8) } @@ -1750,6 +2043,10 @@ func TestTfheLt64(t *testing.T) { TfheLt(t, FheUint64) } +func TestTfheScalarLt4(t *testing.T) { + TfheScalarLt(t, FheUint4) +} + func TestTfheScalarLt8(t *testing.T) { TfheScalarLt(t, FheUint8) } @@ -1766,6 +2063,10 @@ func TestTfheScalarLt64(t *testing.T) { TfheScalarLt(t, FheUint64) } +func TestTfheMin4(t *testing.T) { + TfheMin(t, FheUint4) +} + func TestTfheMin8(t *testing.T) { TfheMin(t, FheUint8) } @@ -1780,6 +2081,10 @@ func TestTfheMin64(t *testing.T) { TfheMin(t, FheUint64) } +func TestTfheScalarMin4(t *testing.T) { + TfheScalarMin(t, FheUint4) +} + func TestTfheScalarMin8(t *testing.T) { TfheScalarMin(t, FheUint8) } @@ -1796,6 +2101,10 @@ func TestTfheScalarMin64(t *testing.T) { TfheScalarMin(t, FheUint64) } +func TestTfheMax4(t *testing.T) { + TfheMax(t, FheUint4) +} + func TestTfheMax8(t *testing.T) { TfheMax(t, FheUint8) } @@ -1810,6 +2119,10 @@ func TestTfheMax64(t *testing.T) { TfheMax(t, FheUint64) } +func TestTfheScalarMax4(t *testing.T) { + TfheScalarMax(t, FheUint4) +} + func TestTfheScalarMax8(t *testing.T) { TfheScalarMax(t, FheUint8) } @@ -1826,6 +2139,10 @@ func TestTfheScalarMax64(t *testing.T) { TfheScalarMax(t, FheUint64) } +func TestTfheNeg4(t *testing.T) { + TfheNeg(t, FheUint4) +} + func TestTfheNeg8(t *testing.T) { TfheNeg(t, FheUint8) } @@ -1840,6 +2157,10 @@ func TestTfheNeg64(t *testing.T) { TfheNeg(t, FheUint64) } +func TestTfheNot4(t *testing.T) { + TfheNot(t, FheUint8) +} + func TestTfheNot8(t *testing.T) { TfheNot(t, FheUint8) } @@ -1854,6 +2175,10 @@ func TestTfheNot64(t *testing.T) { TfheNot(t, FheUint64) } +func TestTfheIfThenElse4(t *testing.T) { + TfheIfThenElse(t, FheUint4) +} + func TestTfheIfThenElse8(t *testing.T) { TfheIfThenElse(t, FheUint8) } @@ -1868,8 +2193,28 @@ func TestTfheIfThenElse64(t *testing.T) { TfheIfThenElse(t, FheUint64) } +func TestTfhe4Cast8(t *testing.T) { + TfheCast(t, FheUint4, FheUint8) +} + +func TestTfhe4Cast16(t *testing.T) { + TfheCast(t, FheUint4, FheUint16) +} + +func TestTfhe4Cast32(t *testing.T) { + TfheCast(t, FheUint4, FheUint32) +} + +func TestTfhe4Cast64(t *testing.T) { + TfheCast(t, FheUint4, FheUint64) +} + +func TestTfhe8Cast4(t *testing.T) { + TfheCast(t, FheUint8, FheUint4) +} + func TestTfhe8Cast16(t *testing.T) { - TfheCast(t, FheUint8, FheUint16) + TfheCast(t, FheUint4, FheUint16) } func TestTfhe8Cast32(t *testing.T) { @@ -1880,6 +2225,10 @@ func TestTfhe8Cast64(t *testing.T) { TfheCast(t, FheUint8, FheUint64) } +func TestTfhe16Cast4(t *testing.T) { + TfheCast(t, FheUint16, FheUint4) +} + func TestTfhe16Cast8(t *testing.T) { TfheCast(t, FheUint16, FheUint8) } @@ -1892,6 +2241,10 @@ func TestTfhe16Cast64(t *testing.T) { TfheCast(t, FheUint16, FheUint64) } +func TestTfhe32Cast4(t *testing.T) { + TfheCast(t, FheUint32, FheUint4) +} + func TestTfhe32Cast8(t *testing.T) { TfheCast(t, FheUint32, FheUint8) } @@ -1904,6 +2257,10 @@ func TestTfhe32Cast64(t *testing.T) { TfheCast(t, FheUint32, FheUint64) } +func TestTfhe64Cast4(t *testing.T) { + TfheCast(t, FheUint64, FheUint4) +} + func TestTfhe64Cast8(t *testing.T) { TfheCast(t, FheUint64, FheUint8) } diff --git a/fhevm/tfhe_wrappers.c b/fhevm/tfhe_wrappers.c index 090cb90..1567f26 100644 --- a/fhevm/tfhe_wrappers.c +++ b/fhevm/tfhe_wrappers.c @@ -53,6 +53,26 @@ void checked_set_server_key(void *sks) { assert(r == 0); } +void* cast_4_bool(void* ct, void* sks) { + FheBool* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint4_scalar_ne(ct, 0, &result); + if(r != 0) return NULL; + return result; +} + +void* cast_bool_4(void* ct, void* sks) { + FheUint4* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_bool_cast_into_fhe_uint4(ct, &result); + if(r != 0) return NULL; + return result; +} + void* cast_8_bool(void* ct, void* sks) { FheBool* result = NULL; @@ -103,6 +123,82 @@ void* cast_bool_64(void* ct, void* sks) { return result; } +int serialize_fhe_bool(void *ct, DynamicBuffer* out) { + return fhe_bool_serialize(ct, out); +} + +void* deserialize_fhe_bool(DynamicBufferView in) { + FheBool* ct = NULL; + const int r = fhe_bool_deserialize(in, &ct); + if(r != 0) { + return NULL; + } + return ct; +} + +void* deserialize_compact_fhe_bool(DynamicBufferView in) { + CompactFheBoolList* list = NULL; + FheBool* ct = NULL; + + int r = compact_fhe_bool_list_deserialize(in, &list); + if(r != 0) { + return NULL; + } + size_t len = 0; + r = compact_fhe_bool_list_len(list, &len); + // Expect only 1 ciphertext in the list. + if(r != 0 || len != 1) { + r = compact_fhe_bool_list_destroy(list); + assert(r == 0); + return NULL; + } + r = compact_fhe_bool_list_expand(list, &ct, 1); + if(r != 0) { + ct = NULL; + } + r = compact_fhe_bool_list_destroy(list); + assert(r == 0); + return ct; +} + +int serialize_fhe_uint4(void *ct, DynamicBuffer* out) { + return fhe_uint4_serialize(ct, out); +} + +void* deserialize_fhe_uint4(DynamicBufferView in) { + FheUint4* ct = NULL; + const int r = fhe_uint4_deserialize(in, &ct); + if(r != 0) { + return NULL; + } + return ct; +} + +void* deserialize_compact_fhe_uint4(DynamicBufferView in) { + CompactFheUint4List* list = NULL; + FheUint4* ct = NULL; + + int r = compact_fhe_uint4_list_deserialize(in, &list); + if(r != 0) { + return NULL; + } + size_t len = 0; + r = compact_fhe_uint4_list_len(list, &len); + // Expect only 1 ciphertext in the list. + if(r != 0 || len != 1) { + r = compact_fhe_uint4_list_destroy(list); + assert(r == 0); + return NULL; + } + r = compact_fhe_uint4_list_expand(list, &ct, 1); + if(r != 0) { + ct = NULL; + } + r = compact_fhe_uint4_list_destroy(list); + assert(r == 0); + return ct; +} + int serialize_fhe_uint8(void *ct, DynamicBuffer* out) { return fhe_uint8_serialize(ct, out); } @@ -255,6 +351,16 @@ void* deserialize_compact_fhe_uint64(DynamicBufferView in) { return ct; } +void destroy_fhe_bool(void* ct) { + const int r = fhe_bool_destroy(ct); + assert(r == 0); +} + +void destroy_fhe_uint4(void* ct) { + const int r = fhe_uint4_destroy(ct); + assert(r == 0); +} + void destroy_fhe_uint8(void* ct) { const int r = fhe_uint8_destroy(ct); assert(r == 0); @@ -275,6 +381,17 @@ void destroy_fhe_uint64(void* ct) { assert(r == 0); } +void* add_fhe_uint4(void* ct1, void* ct2, void* sks) +{ + FheUint4* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint4_add(ct1, ct2, &result); + if(r != 0) return NULL; + return result; +} + void* add_fhe_uint8(void* ct1, void* ct2, void* sks) { FheUint8* result = NULL; @@ -319,6 +436,17 @@ void* add_fhe_uint64(void* ct1, void* ct2, void* sks) return result; } +void* scalar_add_fhe_uint4(void* ct, uint8_t pt, void* sks) +{ + FheUint4* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint4_scalar_add(ct, pt, &result); + if(r != 0) return NULL; + return result; +} + void* scalar_add_fhe_uint8(void* ct, uint8_t pt, void* sks) { FheUint8* result = NULL; @@ -363,6 +491,17 @@ void* scalar_add_fhe_uint64(void* ct, uint64_t pt, void* sks) return result; } +void* sub_fhe_uint4(void* ct1, void* ct2, void* sks) +{ + FheUint4* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint4_sub(ct1, ct2, &result); + if(r != 0) return NULL; + return result; +} + void* sub_fhe_uint8(void* ct1, void* ct2, void* sks) { FheUint8* result = NULL; @@ -407,6 +546,17 @@ void* sub_fhe_uint64(void* ct1, void* ct2, void* sks) return result; } +void* scalar_sub_fhe_uint4(void* ct, uint8_t pt, void* sks) +{ + FheUint4* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint4_scalar_sub(ct, pt, &result); + if(r != 0) return NULL; + return result; +} + void* scalar_sub_fhe_uint8(void* ct, uint8_t pt, void* sks) { FheUint8* result = NULL; @@ -451,6 +601,17 @@ void* scalar_sub_fhe_uint64(void* ct, uint64_t pt, void* sks) return result; } +void* mul_fhe_uint4(void* ct1, void* ct2, void* sks) +{ + FheUint4* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint4_mul(ct1, ct2, &result); + if(r != 0) return NULL; + return result; +} + void* mul_fhe_uint8(void* ct1, void* ct2, void* sks) { FheUint8* result = NULL; @@ -495,6 +656,17 @@ void* mul_fhe_uint64(void* ct1, void* ct2, void* sks) return result; } +void* scalar_mul_fhe_uint4(void* ct, uint8_t pt, void* sks) +{ + FheUint4* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint4_scalar_mul(ct, pt, &result); + if(r != 0) return NULL; + return result; +} + void* scalar_mul_fhe_uint8(void* ct, uint8_t pt, void* sks) { FheUint8* result = NULL; @@ -539,6 +711,17 @@ void* scalar_mul_fhe_uint64(void* ct, uint64_t pt, void* sks) return result; } +void* scalar_div_fhe_uint4(void* ct, uint8_t pt, void* sks) +{ + FheUint4* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint4_scalar_div(ct, pt, &result); + if(r != 0) return NULL; + return result; +} + void* scalar_div_fhe_uint8(void* ct, uint8_t pt, void* sks) { FheUint8* result = NULL; @@ -583,6 +766,17 @@ void* scalar_div_fhe_uint64(void* ct, uint64_t pt, void* sks) return result; } +void* scalar_rem_fhe_uint4(void* ct, uint8_t pt, void* sks) +{ + FheUint4* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint4_scalar_rem(ct, pt, &result); + if(r != 0) return NULL; + return result; +} + void* scalar_rem_fhe_uint8(void* ct, uint8_t pt, void* sks) { FheUint8* result = NULL; @@ -627,6 +821,28 @@ void* scalar_rem_fhe_uint64(void* ct, uint64_t pt, void* sks) return result; } +void* bitand_fhe_bool(void* ct1, void* ct2, void* sks) +{ + FheBool* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_bool_bitand(ct1, ct2, &result); + if(r != 0) return NULL; + return result; +} + +void* bitand_fhe_uint4(void* ct1, void* ct2, void* sks) +{ + FheUint4* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint4_bitand(ct1, ct2, &result); + if(r != 0) return NULL; + return result; +} + void* bitand_fhe_uint8(void* ct1, void* ct2, void* sks) { FheUint8* result = NULL; @@ -671,6 +887,28 @@ void* bitand_fhe_uint64(void* ct1, void* ct2, void* sks) return result; } +void* bitor_fhe_bool(void* ct1, void* ct2, void* sks) +{ + FheBool* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_bool_bitor(ct1, ct2, &result); + if(r != 0) return NULL; + return result; +} + +void* bitor_fhe_uint4(void* ct1, void* ct2, void* sks) +{ + FheUint4* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint4_bitor(ct1, ct2, &result); + if(r != 0) return NULL; + return result; +} + void* bitor_fhe_uint8(void* ct1, void* ct2, void* sks) { FheUint8* result = NULL; @@ -715,6 +953,28 @@ void* bitor_fhe_uint64(void* ct1, void* ct2, void* sks) return result; } +void* bitxor_fhe_bool(void* ct1, void* ct2, void* sks) +{ + FheBool* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_bool_bitxor(ct1, ct2, &result); + if(r != 0) return NULL; + return result; +} + +void* bitxor_fhe_uint4(void* ct1, void* ct2, void* sks) +{ + FheUint4* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint4_bitxor(ct1, ct2, &result); + if(r != 0) return NULL; + return result; +} + void* bitxor_fhe_uint8(void* ct1, void* ct2, void* sks) { FheUint8* result = NULL; @@ -759,6 +1019,17 @@ void* bitxor_fhe_uint64(void* ct1, void* ct2, void* sks) return result; } +void* shl_fhe_uint4(void* ct1, void* ct2, void* sks) +{ + FheUint4* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint4_shl(ct1, ct2, &result); + if(r != 0) return NULL; + return result; +} + void* shl_fhe_uint8(void* ct1, void* ct2, void* sks) { FheUint8* result = NULL; @@ -803,6 +1074,17 @@ void* shl_fhe_uint64(void* ct1, void* ct2, void* sks) return result; } +void* scalar_shl_fhe_uint4(void* ct, uint8_t pt, void* sks) +{ + FheUint4* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint4_scalar_shl(ct, pt, &result); + if(r != 0) return NULL; + return result; +} + void* scalar_shl_fhe_uint8(void* ct, uint8_t pt, void* sks) { FheUint8* result = NULL; @@ -847,6 +1129,17 @@ void* scalar_shl_fhe_uint64(void* ct, uint64_t pt, void* sks) return result; } +void* shr_fhe_uint4(void* ct1, void* ct2, void* sks) +{ + FheUint4* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint4_shr(ct1, ct2, &result); + if(r != 0) return NULL; + return result; +} + void* shr_fhe_uint8(void* ct1, void* ct2, void* sks) { FheUint8* result = NULL; @@ -891,6 +1184,17 @@ void* shr_fhe_uint64(void* ct1, void* ct2, void* sks) return result; } +void* scalar_shr_fhe_uint4(void* ct, uint8_t pt, void* sks) +{ + FheUint4* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint4_scalar_shr(ct, pt, &result); + if(r != 0) return NULL; + return result; +} + void* scalar_shr_fhe_uint8(void* ct, uint8_t pt, void* sks) { FheUint8* result = NULL; @@ -935,6 +1239,18 @@ void* scalar_shr_fhe_uint64(void* ct, uint64_t pt, void* sks) return result; } +void* eq_fhe_uint4(void* ct1, void* ct2, void* sks) +{ + FheBool* bool_result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint4_eq(ct1, ct2, &bool_result); + if(r != 0) return NULL; + FheUint4* result = cast_bool_4(bool_result, sks); + return result; +} + void* eq_fhe_uint8(void* ct1, void* ct2, void* sks) { FheBool* bool_result = NULL; @@ -983,6 +1299,18 @@ void* eq_fhe_uint64(void* ct1, void* ct2, void* sks) return result; } +void* scalar_eq_fhe_uint4(void* ct, uint8_t pt, void* sks) +{ + FheBool* bool_result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint4_scalar_eq(ct, pt, &bool_result); + if(r != 0) return NULL; + FheUint4* result = cast_bool_4(bool_result, sks); + return result; +} + void* scalar_eq_fhe_uint8(void* ct, uint8_t pt, void* sks) { FheBool* bool_result = NULL; @@ -1031,6 +1359,18 @@ void* scalar_eq_fhe_uint64(void* ct, uint64_t pt, void* sks) return result; } +void* ne_fhe_uint4(void* ct1, void* ct2, void* sks) +{ + FheBool* bool_result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint4_ne(ct1, ct2, &bool_result); + if(r != 0) return NULL; + FheUint4* result = cast_bool_4(bool_result, sks); + return result; +} + void* ne_fhe_uint8(void* ct1, void* ct2, void* sks) { FheBool* bool_result = NULL; @@ -1079,6 +1419,18 @@ void* ne_fhe_uint64(void* ct1, void* ct2, void* sks) return result; } +void* scalar_ne_fhe_uint4(void* ct, uint8_t pt, void* sks) +{ + FheBool* bool_result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint4_scalar_ne(ct, pt, &bool_result); + if(r != 0) return NULL; + FheUint4* result = cast_bool_4(bool_result, sks); + return result; +} + void* scalar_ne_fhe_uint8(void* ct, uint8_t pt, void* sks) { FheBool* bool_result = NULL; @@ -1127,6 +1479,18 @@ void* scalar_ne_fhe_uint64(void* ct, uint64_t pt, void* sks) return result; } +void* ge_fhe_uint4(void* ct1, void* ct2, void* sks) +{ + FheBool* bool_result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint4_ge(ct1, ct2, &bool_result); + if(r != 0) return NULL; + FheUint4* result = cast_bool_4(bool_result, sks); + return result; +} + void* ge_fhe_uint8(void* ct1, void* ct2, void* sks) { FheBool* bool_result = NULL; @@ -1175,6 +1539,18 @@ void* ge_fhe_uint64(void* ct1, void* ct2, void* sks) return result; } +void* scalar_ge_fhe_uint4(void* ct, uint8_t pt, void* sks) +{ + FheBool* bool_result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint4_scalar_ge(ct, pt, &bool_result); + if(r != 0) return NULL; + FheUint4* result = cast_bool_4(bool_result, sks); + return result; +} + void* scalar_ge_fhe_uint8(void* ct, uint8_t pt, void* sks) { FheBool* bool_result = NULL; @@ -1223,6 +1599,18 @@ void* scalar_ge_fhe_uint64(void* ct, uint64_t pt, void* sks) return result; } +void* gt_fhe_uint4(void* ct1, void* ct2, void* sks) +{ + FheBool* bool_result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint4_gt(ct1, ct2, &bool_result); + if(r != 0) return NULL; + FheUint4* result = cast_bool_4(bool_result, sks); + return result; +} + void* gt_fhe_uint8(void* ct1, void* ct2, void* sks) { FheBool* bool_result = NULL; @@ -1271,6 +1659,18 @@ void* gt_fhe_uint64(void* ct1, void* ct2, void* sks) return result; } +void* scalar_gt_fhe_uint4(void* ct, uint8_t pt, void* sks) +{ + FheBool* bool_result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint4_scalar_gt(ct, pt, &bool_result); + if(r != 0) return NULL; + FheUint4* result = cast_bool_4(bool_result, sks); + return result; +} + void* scalar_gt_fhe_uint8(void* ct, uint8_t pt, void* sks) { FheBool* bool_result = NULL; @@ -1319,6 +1719,18 @@ void* scalar_gt_fhe_uint64(void* ct, uint64_t pt, void* sks) return result; } +void* le_fhe_uint4(void* ct1, void* ct2, void* sks) +{ + FheBool* bool_result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint4_le(ct1, ct2, &bool_result); + if(r != 0) return NULL; + FheUint4* result = cast_bool_4(bool_result, sks); + return result; +} + void* le_fhe_uint8(void* ct1, void* ct2, void* sks) { FheBool* bool_result = NULL; @@ -1367,6 +1779,18 @@ void* le_fhe_uint64(void* ct1, void* ct2, void* sks) return result; } +void* scalar_le_fhe_uint4(void* ct, uint8_t pt, void* sks) +{ + FheBool* bool_result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint4_scalar_le(ct, pt, &bool_result); + if(r != 0) return NULL; + FheUint4* result = cast_bool_4(bool_result, sks); + return result; +} + void* scalar_le_fhe_uint8(void* ct, uint8_t pt, void* sks) { FheBool* bool_result = NULL; @@ -1415,6 +1839,18 @@ void* scalar_le_fhe_uint64(void* ct, uint64_t pt, void* sks) return result; } +void* lt_fhe_uint4(void* ct1, void* ct2, void* sks) +{ + FheBool* bool_result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint4_lt(ct1, ct2, &bool_result); + if(r != 0) return NULL; + FheUint4* result = cast_bool_4(bool_result, sks); + return result; +} + void* lt_fhe_uint8(void* ct1, void* ct2, void* sks) { FheBool* bool_result = NULL; @@ -1463,6 +1899,18 @@ void* lt_fhe_uint64(void* ct1, void* ct2, void* sks) return result; } +void* scalar_lt_fhe_uint4(void* ct, uint8_t pt, void* sks) +{ + FheBool* bool_result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint4_scalar_lt(ct, pt, &bool_result); + if(r != 0) return NULL; + FheUint4* result = cast_bool_4(bool_result, sks); + return result; +} + void* scalar_lt_fhe_uint8(void* ct, uint8_t pt, void* sks) { FheBool* bool_result = NULL; @@ -1511,6 +1959,17 @@ void* scalar_lt_fhe_uint64(void* ct, uint64_t pt, void* sks) return result; } +void* min_fhe_uint4(void* ct1, void* ct2, void* sks) +{ + FheUint4* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint4_min(ct1, ct2, &result); + if(r != 0) return NULL; + return result; +} + void* min_fhe_uint8(void* ct1, void* ct2, void* sks) { FheUint8* result = NULL; @@ -1555,6 +2014,17 @@ void* min_fhe_uint64(void* ct1, void* ct2, void* sks) return result; } +void* scalar_min_fhe_uint4(void* ct, uint8_t pt, void* sks) +{ + FheUint4* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint4_scalar_min(ct, pt, &result); + if(r != 0) return NULL; + return result; +} + void* scalar_min_fhe_uint8(void* ct, uint8_t pt, void* sks) { FheUint8* result = NULL; @@ -1599,6 +2069,17 @@ void* scalar_min_fhe_uint64(void* ct, uint64_t pt, void* sks) return result; } +void* max_fhe_uint4(void* ct1, void* ct2, void* sks) +{ + FheUint4* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint4_max(ct1, ct2, &result); + if(r != 0) return NULL; + return result; +} + void* max_fhe_uint8(void* ct1, void* ct2, void* sks) { FheUint8* result = NULL; @@ -1643,6 +2124,17 @@ void* max_fhe_uint64(void* ct1, void* ct2, void* sks) return result; } +void* scalar_max_fhe_uint4(void* ct, uint8_t pt, void* sks) +{ + FheUint4* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint4_scalar_max(ct, pt, &result); + if(r != 0) return NULL; + return result; +} + void* scalar_max_fhe_uint8(void* ct, uint8_t pt, void* sks) { FheUint8* result = NULL; @@ -1687,6 +2179,16 @@ void* scalar_max_fhe_uint64(void* ct, uint64_t pt, void* sks) return result; } +void* neg_fhe_uint4(void* ct, void* sks) { + FheUint4* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint4_neg(ct, &result); + if(r != 0) return NULL; + return result; +} + void* neg_fhe_uint8(void* ct, void* sks) { FheUint8* result = NULL; @@ -1727,6 +2229,26 @@ void* neg_fhe_uint64(void* ct, void* sks) { return result; } +void* not_fhe_bool(void* ct, void* sks) { + FheBool* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_bool_not(ct, &result); + if(r != 0) return NULL; + return result; +} + +void* not_fhe_uint4(void* ct, void* sks) { + FheUint4* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint4_not(ct, &result); + if(r != 0) return NULL; + return result; +} + void* not_fhe_uint8(void* ct, void* sks) { FheUint8* result = NULL; @@ -1767,6 +2289,19 @@ void* not_fhe_uint64(void* ct, void* sks) { return result; } +void* if_then_else_fhe_uint4(void* condition, void* ct1, void* ct2, void* sks) +{ + FheUint4* result = NULL; + + checked_set_server_key(sks); + + FheBool* cond = cast_4_bool(condition, sks); + + const int r = fhe_uint4_if_then_else(cond, ct1, ct2, &result); + if(r != 0) return NULL; + return result; +} + void* if_then_else_fhe_uint8(void* condition, void* ct1, void* ct2, void* sks) { FheUint8* result = NULL; @@ -1819,6 +2354,18 @@ void* if_then_else_fhe_uint64(void* condition, void* ct1, void* ct2, void* sks) return result; } +int decrypt_fhe_bool(void* cks, void* ct, bool* res) +{ + *res = false; + return fhe_bool_decrypt(ct, cks, res); +} + +int decrypt_fhe_uint4(void* cks, void* ct, uint8_t* res) +{ + *res = 0; + return fhe_uint4_decrypt(ct, cks, res); +} + int decrypt_fhe_uint8(void* cks, void* ct, uint8_t* res) { *res = 0; @@ -1843,6 +2390,38 @@ int decrypt_fhe_uint64(void* cks, void* ct, uint64_t* res) return fhe_uint64_decrypt(ct, cks, res); } +void* public_key_encrypt_fhe_bool(void* pks, bool value) { + CompactFheBoolList* list = NULL; + FheBool* ct = NULL; + + int r = compact_fhe_bool_list_try_encrypt_with_compact_public_key_bool(&value, 1, pks, &list); + assert(r == 0); + + r = compact_fhe_bool_list_expand(list, &ct, 1); + assert(r == 0); + + r = compact_fhe_bool_list_destroy(list); + assert(r == 0); + + return ct; +} + +void* public_key_encrypt_fhe_uint4(void* pks, uint8_t value) { + CompactFheUint4List* list = NULL; + FheUint4* ct = NULL; + + int r = compact_fhe_uint4_list_try_encrypt_with_compact_public_key_u8(&value, 1, pks, &list); + assert(r == 0); + + r = compact_fhe_uint4_list_expand(list, &ct, 1); + assert(r == 0); + + r = compact_fhe_uint4_list_destroy(list); + assert(r == 0); + + return ct; +} + void* public_key_encrypt_fhe_uint8(void* pks, uint8_t value) { CompactFheUint8List* list = NULL; FheUint8* ct = NULL; @@ -1907,6 +2486,28 @@ void* public_key_encrypt_fhe_uint64(void* pks, uint64_t value) { return ct; } +void* trivial_encrypt_fhe_bool(void* sks, bool value) { + FheBool* ct = NULL; + + checked_set_server_key(sks); + + int r = fhe_bool_try_encrypt_trivial_bool(value, &ct); + assert(r == 0); + + return ct; +} + +void* trivial_encrypt_fhe_uint4(void* sks, uint8_t value) { + FheUint4* ct = NULL; + + checked_set_server_key(sks); + + int r = fhe_uint4_try_encrypt_trivial_u8(value, &ct); + assert(r == 0); + + return ct; +} + void* trivial_encrypt_fhe_uint8(void* sks, uint8_t value) { FheUint8* ct = NULL; @@ -1951,6 +2552,32 @@ void* trivial_encrypt_fhe_uint64(void* sks, uint64_t value) { return ct; } +void public_key_encrypt_and_serialize_fhe_bool_list(void* pks, bool value, DynamicBuffer* out) { + CompactFheBoolList* list = NULL; + + int r = compact_fhe_bool_list_try_encrypt_with_compact_public_key_bool(&value, 1, pks, &list); + assert(r == 0); + + r = compact_fhe_bool_list_serialize(list, out); + assert(r == 0); + + r = compact_fhe_bool_list_destroy(list); + assert(r == 0); +} + +void public_key_encrypt_and_serialize_fhe_uint4_list(void* pks, uint8_t value, DynamicBuffer* out) { + CompactFheUint4List* list = NULL; + + int r = compact_fhe_uint4_list_try_encrypt_with_compact_public_key_u8(&value, 1, pks, &list); + assert(r == 0); + + r = compact_fhe_uint4_list_serialize(list, out); + assert(r == 0); + + r = compact_fhe_uint4_list_destroy(list); + assert(r == 0); +} + void public_key_encrypt_and_serialize_fhe_uint8_list(void* pks, uint8_t value, DynamicBuffer* out) { CompactFheUint8List* list = NULL; @@ -2003,6 +2630,57 @@ void public_key_encrypt_and_serialize_fhe_uint64_list(void* pks, uint64_t value, assert(r == 0); } + +void* cast_4_8(void* ct, void* sks) { + FheUint8* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint4_cast_into_fhe_uint8(ct, &result); + if(r != 0) return NULL; + return result; +} + +void* cast_4_16(void* ct, void* sks) { + FheUint16* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint4_cast_into_fhe_uint16(ct, &result); + if(r != 0) return NULL; + return result; +} + +void* cast_4_32(void* ct, void* sks) { + FheUint32* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint4_cast_into_fhe_uint32(ct, &result); + if(r != 0) return NULL; + return result; +} + +void* cast_4_64(void* ct, void* sks) { + FheUint64* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint4_cast_into_fhe_uint64(ct, &result); + if(r != 0) return NULL; + return result; +} + +void* cast_8_4(void* ct, void* sks) { + FheUint4* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint8_cast_into_fhe_uint4(ct, &result); + if(r != 0) return NULL; + return result; +} + void* cast_8_16(void* ct, void* sks) { FheUint16* result = NULL; @@ -2033,6 +2711,16 @@ void* cast_8_64(void* ct, void* sks) { return result; } +void* cast_16_4(void* ct, void* sks) { + FheUint4* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint16_cast_into_fhe_uint4(ct, &result); + if(r != 0) return NULL; + return result; +} + void* cast_16_8(void* ct, void* sks) { FheUint8* result = NULL; @@ -2063,6 +2751,16 @@ void* cast_16_64(void* ct, void* sks) { return result; } +void* cast_32_4(void* ct, void* sks) { + FheUint4* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint32_cast_into_fhe_uint4(ct, &result); + if(r != 0) return NULL; + return result; +} + void* cast_32_8(void* ct, void* sks) { FheUint8* result = NULL; @@ -2093,6 +2791,16 @@ void* cast_32_64(void* ct, void* sks) { return result; } +void* cast_64_4(void* ct, void* sks) { + FheUint4* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint64_cast_into_fhe_uint4(ct, &result); + if(r != 0) return NULL; + return result; +} + void* cast_64_8(void* ct, void* sks) { FheUint8* result = NULL; diff --git a/fhevm/tfhe_wrappers.go b/fhevm/tfhe_wrappers.go index bfb2f89..4d6b509 100644 --- a/fhevm/tfhe_wrappers.go +++ b/fhevm/tfhe_wrappers.go @@ -28,6 +28,10 @@ func serialize(ptr unsafe.Pointer, t FheUintType) ([]byte, error) { out := &C.DynamicBuffer{} var ret C.int switch t { + case FheBool: + ret = C.serialize_fhe_bool(ptr, out) + case FheUint4: + ret = C.serialize_fhe_uint4(ptr, out) case FheUint8: ret = C.serialize_fhe_uint8(ptr, out) case FheUint16: @@ -64,6 +68,14 @@ func serializePublicKey() ([]byte, error) { func encryptAndSerializeCompact(value uint64, fheUintType FheUintType) []byte { out := &C.DynamicBuffer{} switch fheUintType { + case FheBool: + val := false + if value == 1 { + val = true + } + C.public_key_encrypt_and_serialize_fhe_bool_list(pks, C.bool(val), out) + case FheUint4: + C.public_key_encrypt_and_serialize_fhe_uint4_list(pks, C.uint8_t(value), out) case FheUint8: C.public_key_encrypt_and_serialize_fhe_uint8_list(pks, C.uint8_t(value), out) case FheUint16: diff --git a/fhevm/tfhe_wrappers.h b/fhevm/tfhe_wrappers.h index 3be8b1f..78588a5 100644 --- a/fhevm/tfhe_wrappers.h +++ b/fhevm/tfhe_wrappers.h @@ -19,15 +19,17 @@ void* deserialize_compact_public_key(DynamicBufferView in); void checked_set_server_key(void *sks); -void* cast_8_bool(void* ct, void* sks); +int serialize_fhe_bool(void *ct, DynamicBuffer* out); -void* cast_bool_8(void* ct, void* sks); +void* deserialize_fhe_bool(DynamicBufferView in); -void* cast_bool_16(void* ct, void* sks); +void* deserialize_compact_fhe_bool(DynamicBufferView in); -void* cast_bool_32(void* ct, void* sks); +int serialize_fhe_uint4(void *ct, DynamicBuffer* out); -void* cast_bool_64(void* ct, void* sks); +void* deserialize_fhe_uint4(DynamicBufferView in); + +void* deserialize_compact_fhe_uint4(DynamicBufferView in); int serialize_fhe_uint8(void *ct, DynamicBuffer* out); @@ -53,6 +55,10 @@ void* deserialize_fhe_uint64(DynamicBufferView in); void* deserialize_compact_fhe_uint64(DynamicBufferView in); +void destroy_fhe_bool(void* ct); + +void destroy_fhe_uint4(void* ct); + void destroy_fhe_uint8(void* ct); void destroy_fhe_uint16(void* ct); @@ -61,6 +67,8 @@ void destroy_fhe_uint32(void* ct); void destroy_fhe_uint64(void* ct); +void* add_fhe_uint4(void* ct1, void* ct2, void* sks); + void* add_fhe_uint8(void* ct1, void* ct2, void* sks); void* add_fhe_uint16(void* ct1, void* ct2, void* sks); @@ -69,6 +77,8 @@ void* add_fhe_uint32(void* ct1, void* ct2, void* sks); void* add_fhe_uint64(void* ct1, void* ct2, void* sks); +void* scalar_add_fhe_uint4(void* ct, uint8_t pt, void* sks); + void* scalar_add_fhe_uint8(void* ct, uint8_t pt, void* sks); void* scalar_add_fhe_uint16(void* ct, uint16_t pt, void* sks); @@ -77,6 +87,8 @@ void* scalar_add_fhe_uint32(void* ct, uint32_t pt, void* sks); void* scalar_add_fhe_uint64(void* ct, uint64_t pt, void* sks); +void* sub_fhe_uint4(void* ct1, void* ct2, void* sks); + void* sub_fhe_uint8(void* ct1, void* ct2, void* sks); void* sub_fhe_uint16(void* ct1, void* ct2, void* sks); @@ -85,6 +97,8 @@ void* sub_fhe_uint32(void* ct1, void* ct2, void* sks); void* sub_fhe_uint64(void* ct1, void* ct2, void* sks); +void* scalar_sub_fhe_uint4(void* ct, uint8_t pt, void* sks); + void* scalar_sub_fhe_uint8(void* ct, uint8_t pt, void* sks); void* scalar_sub_fhe_uint16(void* ct, uint16_t pt, void* sks); @@ -93,6 +107,8 @@ void* scalar_sub_fhe_uint32(void* ct, uint32_t pt, void* sks); void* scalar_sub_fhe_uint64(void* ct, uint64_t pt, void* sks); +void* mul_fhe_uint4(void* ct1, void* ct2, void* sks); + void* mul_fhe_uint8(void* ct1, void* ct2, void* sks); void* mul_fhe_uint16(void* ct1, void* ct2, void* sks); @@ -101,6 +117,8 @@ void* mul_fhe_uint32(void* ct1, void* ct2, void* sks); void* mul_fhe_uint64(void* ct1, void* ct2, void* sks); +void* scalar_mul_fhe_uint4(void* ct, uint8_t pt, void* sks); + void* scalar_mul_fhe_uint8(void* ct, uint8_t pt, void* sks); void* scalar_mul_fhe_uint16(void* ct, uint16_t pt, void* sks); @@ -109,6 +127,8 @@ void* scalar_mul_fhe_uint32(void* ct, uint32_t pt, void* sks); void* scalar_mul_fhe_uint64(void* ct, uint64_t pt, void* sks); +void* scalar_div_fhe_uint4(void* ct, uint8_t pt, void* sks); + void* scalar_div_fhe_uint8(void* ct, uint8_t pt, void* sks); void* scalar_div_fhe_uint16(void* ct, uint16_t pt, void* sks); @@ -117,6 +137,8 @@ void* scalar_div_fhe_uint32(void* ct, uint32_t pt, void* sks); void* scalar_div_fhe_uint64(void* ct, uint64_t pt, void* sks); +void* scalar_rem_fhe_uint4(void* ct, uint8_t pt, void* sks); + void* scalar_rem_fhe_uint8(void* ct, uint8_t pt, void* sks); void* scalar_rem_fhe_uint16(void* ct, uint16_t pt, void* sks); @@ -125,6 +147,10 @@ void* scalar_rem_fhe_uint32(void* ct, uint32_t pt, void* sks); void* scalar_rem_fhe_uint64(void* ct, uint64_t pt, void* sks); +void* bitand_fhe_bool(void* ct1, void* ct2, void* sks); + +void* bitand_fhe_uint4(void* ct1, void* ct2, void* sks); + void* bitand_fhe_uint8(void* ct1, void* ct2, void* sks); void* bitand_fhe_uint16(void* ct1, void* ct2, void* sks); @@ -133,6 +159,10 @@ void* bitand_fhe_uint32(void* ct1, void* ct2, void* sks); void* bitand_fhe_uint64(void* ct1, void* ct2, void* sks); +void* bitor_fhe_bool(void* ct1, void* ct2, void* sks); + +void* bitor_fhe_uint4(void* ct1, void* ct2, void* sks); + void* bitor_fhe_uint8(void* ct1, void* ct2, void* sks); void* bitor_fhe_uint16(void* ct1, void* ct2, void* sks); @@ -141,6 +171,10 @@ void* bitor_fhe_uint32(void* ct1, void* ct2, void* sks); void* bitor_fhe_uint64(void* ct1, void* ct2, void* sks); +void* bitxor_fhe_bool(void* ct1, void* ct2, void* sks); + +void* bitxor_fhe_uint4(void* ct1, void* ct2, void* sks); + void* bitxor_fhe_uint8(void* ct1, void* ct2, void* sks); void* bitxor_fhe_uint16(void* ct1, void* ct2, void* sks); @@ -149,6 +183,8 @@ void* bitxor_fhe_uint32(void* ct1, void* ct2, void* sks); void* bitxor_fhe_uint64(void* ct1, void* ct2, void* sks); +void* shl_fhe_uint4(void* ct1, void* ct2, void* sks); + void* shl_fhe_uint8(void* ct1, void* ct2, void* sks); void* shl_fhe_uint16(void* ct1, void* ct2, void* sks); @@ -157,6 +193,8 @@ void* shl_fhe_uint32(void* ct1, void* ct2, void* sks); void* shl_fhe_uint64(void* ct1, void* ct2, void* sks); +void* scalar_shl_fhe_uint4(void* ct, uint8_t pt, void* sks); + void* scalar_shl_fhe_uint8(void* ct, uint8_t pt, void* sks); void* scalar_shl_fhe_uint16(void* ct, uint16_t pt, void* sks); @@ -165,6 +203,8 @@ void* scalar_shl_fhe_uint32(void* ct, uint32_t pt, void* sks); void* scalar_shl_fhe_uint64(void* ct, uint64_t pt, void* sks); +void* shr_fhe_uint4(void* ct1, void* ct2, void* sks); + void* shr_fhe_uint8(void* ct1, void* ct2, void* sks); void* shr_fhe_uint16(void* ct1, void* ct2, void* sks); @@ -173,6 +213,8 @@ void* shr_fhe_uint32(void* ct1, void* ct2, void* sks); void* shr_fhe_uint64(void* ct1, void* ct2, void* sks); +void* scalar_shr_fhe_uint4(void* ct, uint8_t pt, void* sks); + void* scalar_shr_fhe_uint8(void* ct, uint8_t pt, void* sks); void* scalar_shr_fhe_uint16(void* ct, uint16_t pt, void* sks); @@ -181,6 +223,8 @@ void* scalar_shr_fhe_uint32(void* ct, uint32_t pt, void* sks); void* scalar_shr_fhe_uint64(void* ct, uint64_t pt, void* sks); +void* eq_fhe_uint4(void* ct1, void* ct2, void* sks); + void* eq_fhe_uint8(void* ct1, void* ct2, void* sks); void* eq_fhe_uint16(void* ct1, void* ct2, void* sks); @@ -189,6 +233,8 @@ void* eq_fhe_uint32(void* ct1, void* ct2, void* sks); void* eq_fhe_uint64(void* ct1, void* ct2, void* sks); +void* scalar_eq_fhe_uint4(void* ct, uint8_t pt, void* sks); + void* scalar_eq_fhe_uint8(void* ct, uint8_t pt, void* sks); void* scalar_eq_fhe_uint16(void* ct, uint16_t pt, void* sks); @@ -197,6 +243,8 @@ void* scalar_eq_fhe_uint32(void* ct, uint32_t pt, void* sks); void* scalar_eq_fhe_uint64(void* ct, uint64_t pt, void* sks); +void* ne_fhe_uint4(void* ct1, void* ct2, void* sks); + void* ne_fhe_uint8(void* ct1, void* ct2, void* sks); void* ne_fhe_uint16(void* ct1, void* ct2, void* sks); @@ -205,6 +253,8 @@ void* ne_fhe_uint32(void* ct1, void* ct2, void* sks); void* ne_fhe_uint64(void* ct1, void* ct2, void* sks); +void* scalar_ne_fhe_uint4(void* ct, uint8_t pt, void* sks); + void* scalar_ne_fhe_uint8(void* ct, uint8_t pt, void* sks); void* scalar_ne_fhe_uint16(void* ct, uint16_t pt, void* sks); @@ -213,6 +263,8 @@ void* scalar_ne_fhe_uint32(void* ct, uint32_t pt, void* sks); void* scalar_ne_fhe_uint64(void* ct, uint64_t pt, void* sks); +void* ge_fhe_uint4(void* ct1, void* ct2, void* sks); + void* ge_fhe_uint8(void* ct1, void* ct2, void* sks); void* ge_fhe_uint16(void* ct1, void* ct2, void* sks); @@ -221,6 +273,8 @@ void* ge_fhe_uint32(void* ct1, void* ct2, void* sks); void* ge_fhe_uint64(void* ct1, void* ct2, void* sks); +void* scalar_ge_fhe_uint4(void* ct, uint8_t pt, void* sks); + void* scalar_ge_fhe_uint8(void* ct, uint8_t pt, void* sks); void* scalar_ge_fhe_uint16(void* ct, uint16_t pt, void* sks); @@ -229,6 +283,8 @@ void* scalar_ge_fhe_uint32(void* ct, uint32_t pt, void* sks); void* scalar_ge_fhe_uint64(void* ct, uint64_t pt, void* sks); +void* gt_fhe_uint4(void* ct1, void* ct2, void* sks); + void* gt_fhe_uint8(void* ct1, void* ct2, void* sks); void* gt_fhe_uint16(void* ct1, void* ct2, void* sks); @@ -237,6 +293,8 @@ void* gt_fhe_uint32(void* ct1, void* ct2, void* sks); void* gt_fhe_uint64(void* ct1, void* ct2, void* sks); +void* scalar_gt_fhe_uint4(void* ct, uint8_t pt, void* sks); + void* scalar_gt_fhe_uint8(void* ct, uint8_t pt, void* sks); void* scalar_gt_fhe_uint16(void* ct, uint16_t pt, void* sks); @@ -245,6 +303,8 @@ void* scalar_gt_fhe_uint32(void* ct, uint32_t pt, void* sks); void* scalar_gt_fhe_uint64(void* ct, uint64_t pt, void* sks); +void* le_fhe_uint4(void* ct1, void* ct2, void* sks); + void* le_fhe_uint8(void* ct1, void* ct2, void* sks); void* le_fhe_uint16(void* ct1, void* ct2, void* sks); @@ -253,6 +313,8 @@ void* le_fhe_uint32(void* ct1, void* ct2, void* sks); void* le_fhe_uint64(void* ct1, void* ct2, void* sks); +void* scalar_le_fhe_uint4(void* ct, uint8_t pt, void* sks); + void* scalar_le_fhe_uint8(void* ct, uint8_t pt, void* sks); void* scalar_le_fhe_uint16(void* ct, uint16_t pt, void* sks); @@ -261,6 +323,8 @@ void* scalar_le_fhe_uint32(void* ct, uint32_t pt, void* sks); void* scalar_le_fhe_uint64(void* ct, uint64_t pt, void* sks); +void* lt_fhe_uint4(void* ct1, void* ct2, void* sks); + void* lt_fhe_uint8(void* ct1, void* ct2, void* sks); void* lt_fhe_uint16(void* ct1, void* ct2, void* sks); @@ -269,6 +333,8 @@ void* lt_fhe_uint32(void* ct1, void* ct2, void* sks); void* lt_fhe_uint64(void* ct1, void* ct2, void* sks); +void* scalar_lt_fhe_uint4(void* ct, uint8_t pt, void* sks); + void* scalar_lt_fhe_uint8(void* ct, uint8_t pt, void* sks); void* scalar_lt_fhe_uint16(void* ct, uint16_t pt, void* sks); @@ -277,6 +343,8 @@ void* scalar_lt_fhe_uint32(void* ct, uint32_t pt, void* sks); void* scalar_lt_fhe_uint64(void* ct, uint64_t pt, void* sks); +void* min_fhe_uint4(void* ct1, void* ct2, void* sks); + void* min_fhe_uint8(void* ct1, void* ct2, void* sks); void* min_fhe_uint16(void* ct1, void* ct2, void* sks); @@ -285,6 +353,8 @@ void* min_fhe_uint32(void* ct1, void* ct2, void* sks); void* min_fhe_uint64(void* ct1, void* ct2, void* sks); +void* scalar_min_fhe_uint4(void* ct, uint8_t pt, void* sks); + void* scalar_min_fhe_uint8(void* ct, uint8_t pt, void* sks); void* scalar_min_fhe_uint16(void* ct, uint16_t pt, void* sks); @@ -293,6 +363,8 @@ void* scalar_min_fhe_uint32(void* ct, uint32_t pt, void* sks); void* scalar_min_fhe_uint64(void* ct, uint64_t pt, void* sks); +void* max_fhe_uint4(void* ct1, void* ct2, void* sks); + void* max_fhe_uint8(void* ct1, void* ct2, void* sks); void* max_fhe_uint16(void* ct1, void* ct2, void* sks); @@ -301,6 +373,8 @@ void* max_fhe_uint32(void* ct1, void* ct2, void* sks); void* max_fhe_uint64(void* ct1, void* ct2, void* sks); +void* scalar_max_fhe_uint4(void* ct, uint8_t pt, void* sks); + void* scalar_max_fhe_uint8(void* ct, uint8_t pt, void* sks); void* scalar_max_fhe_uint16(void* ct, uint16_t pt, void* sks); @@ -309,6 +383,8 @@ void* scalar_max_fhe_uint32(void* ct, uint32_t pt, void* sks); void* scalar_max_fhe_uint64(void* ct, uint64_t pt, void* sks); +void* neg_fhe_uint4(void* ct, void* sks); + void* neg_fhe_uint8(void* ct, void* sks); void* neg_fhe_uint16(void* ct, void* sks); @@ -317,6 +393,8 @@ void* neg_fhe_uint32(void* ct, void* sks); void* neg_fhe_uint64(void* ct, void* sks); +void* not_fhe_uint4(void* ct, void* sks); + void* not_fhe_uint8(void* ct, void* sks); void* not_fhe_uint16(void* ct, void* sks); @@ -325,6 +403,8 @@ void* not_fhe_uint32(void* ct, void* sks); void* not_fhe_uint64(void* ct, void* sks); +void* if_then_else_fhe_uint4(void* condition, void* ct1, void* ct2, void* sks); + void* if_then_else_fhe_uint8(void* condition, void* ct1, void* ct2, void* sks); void* if_then_else_fhe_uint16(void* condition, void* ct1, void* ct2, void* sks); @@ -333,6 +413,10 @@ void* if_then_else_fhe_uint32(void* condition, void* ct1, void* ct2, void* sks); void* if_then_else_fhe_uint64(void* condition, void* ct1, void* ct2, void* sks); +int decrypt_fhe_bool(void* cks, void* ct, bool* res); + +int decrypt_fhe_uint4(void* cks, void* ct, uint8_t* res); + int decrypt_fhe_uint8(void* cks, void* ct, uint8_t* res); int decrypt_fhe_uint16(void* cks, void* ct, uint16_t* res); @@ -341,6 +425,10 @@ int decrypt_fhe_uint32(void* cks, void* ct, uint32_t* res); int decrypt_fhe_uint64(void* cks, void* ct, uint64_t* res); +void* public_key_encrypt_fhe_bool(void* pks, bool value); + +void* public_key_encrypt_fhe_uint4(void* pks, uint8_t value); + void* public_key_encrypt_fhe_uint8(void* pks, uint8_t value); void* public_key_encrypt_fhe_uint16(void* pks, uint16_t value); @@ -349,6 +437,10 @@ void* public_key_encrypt_fhe_uint32(void* pks, uint32_t value); void* public_key_encrypt_fhe_uint64(void* pks, uint64_t value); +void* trivial_encrypt_fhe_bool(void* sks, bool value); + +void* trivial_encrypt_fhe_uint4(void* sks, uint8_t value); + void* trivial_encrypt_fhe_uint8(void* sks, uint8_t value); void* trivial_encrypt_fhe_uint16(void* sks, uint16_t value); @@ -357,6 +449,10 @@ void* trivial_encrypt_fhe_uint32(void* sks, uint32_t value); void* trivial_encrypt_fhe_uint64(void* sks, uint64_t value); +void public_key_encrypt_and_serialize_fhe_bool_list(void* pks, bool value, DynamicBuffer* out); + +void public_key_encrypt_and_serialize_fhe_uint4_list(void* pks, uint8_t value, DynamicBuffer* out); + void public_key_encrypt_and_serialize_fhe_uint8_list(void* pks, uint8_t value, DynamicBuffer* out); void public_key_encrypt_and_serialize_fhe_uint16_list(void* pks, uint16_t value, DynamicBuffer* out); @@ -365,24 +461,60 @@ void public_key_encrypt_and_serialize_fhe_uint32_list(void* pks, uint32_t value, void public_key_encrypt_and_serialize_fhe_uint64_list(void* pks, uint64_t value, DynamicBuffer* out); +void* cast_bool_4(void* ct, void* sks); + +void* cast_bool_8(void* ct, void* sks); + +void* cast_bool_16(void* ct, void* sks); + +void* cast_bool_32(void* ct, void* sks); + +void* cast_bool_64(void* ct, void* sks); + +void* cast_4_bool(void* ct, void* sks); + +void* cast_4_8(void* ct, void* sks); + +void* cast_4_16(void* ct, void* sks); + +void* cast_4_32(void* ct, void* sks); + +void* cast_4_64(void* ct, void* sks); + +void* cast_4_bool(void* ct, void* sks); + +void* cast_8_4(void* ct, void* sks); + void* cast_8_16(void* ct, void* sks); void* cast_8_32(void* ct, void* sks); void* cast_8_64(void* ct, void* sks); +void* cast_16_bool(void* ct, void* sks); + +void* cast_16_4(void* ct, void* sks); + void* cast_16_8(void* ct, void* sks); void* cast_16_32(void* ct, void* sks); void* cast_16_64(void* ct, void* sks); +void* cast_32_bool(void* ct, void* sks); + +void* cast_32_4(void* ct, void* sks); + void* cast_32_8(void* ct, void* sks); void* cast_32_16(void* ct, void* sks); void* cast_32_64(void* ct, void* sks); +void* cast_64_bool(void* ct, void* sks); + +void* cast_64_4(void* ct, void* sks); + void* cast_64_8(void* ct, void* sks); void* cast_64_16(void* ct, void* sks); diff --git a/kms/kms.pb.go b/kms/kms.pb.go index 6c467ce..a294afb 100644 --- a/kms/kms.pb.go +++ b/kms/kms.pb.go @@ -24,27 +24,30 @@ type FheType int32 const ( FheType_Bool FheType = 0 - FheType_Euint8 FheType = 1 - FheType_Euint16 FheType = 2 - FheType_Euint32 FheType = 3 - FheType_Euint64 FheType = 4 + FheType_Euint4 FheType = 1 + FheType_Euint8 FheType = 2 + FheType_Euint16 FheType = 3 + FheType_Euint32 FheType = 4 + FheType_Euint64 FheType = 5 ) // Enum value maps for FheType. var ( FheType_name = map[int32]string{ 0: "Bool", - 1: "Euint8", - 2: "Euint16", - 3: "Euint32", - 4: "Euint64", + 1: "Euint4", + 2: "Euint8", + 3: "Euint16", + 4: "Euint32", + 5: "Euint64", } FheType_value = map[string]int32{ "Bool": 0, - "Euint8": 1, - "Euint16": 2, - "Euint32": 3, - "Euint64": 4, + "Euint4": 1, + "Euint8": 2, + "Euint16": 3, + "Euint32": 4, + "Euint64": 5, } ) diff --git a/tfhe-rs b/tfhe-rs index ad1ae0c..5b65386 160000 --- a/tfhe-rs +++ b/tfhe-rs @@ -1 +1 @@ -Subproject commit ad1ae0c8c206533fcb5668f57a93777251214a59 +Subproject commit 5b653864b7e2c865c85c9d83efb2b85f08f53045 From 183683b88795c493970e4751bbe4e986c3cdc73c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20=27birdy=27=20Danjou?= Date: Tue, 27 Feb 2024 17:47:26 +0100 Subject: [PATCH 2/9] fix: fix ifthenelse types --- fhevm/contracts_test.go | 9 +++++++-- fhevm/fhelib_required_gas.go | 2 +- fhevm/tfhe_ciphertext.go | 32 ++++++++++++++++---------------- fhevm/tfhe_test.go | 4 ++-- fhevm/tfhe_wrappers.c | 20 +++++--------------- 5 files changed, 31 insertions(+), 36 deletions(-) diff --git a/fhevm/contracts_test.go b/fhevm/contracts_test.go index 7cef262..993d530 100644 --- a/fhevm/contracts_test.go +++ b/fhevm/contracts_test.go @@ -1497,7 +1497,7 @@ func FheLibIfThenElse(t *testing.T, fheUintType FheUintType, condition uint64) { environment.depth = depth addr := common.Address{} readOnly := false - firstHash := verifyCiphertextInTestMemory(environment, condition, depth, FheUint8).GetHash() + firstHash := verifyCiphertextInTestMemory(environment, condition, depth, FheBool).GetHash() secondHash := verifyCiphertextInTestMemory(environment, second, depth, fheUintType).GetHash() thirdHash := verifyCiphertextInTestMemory(environment, third, depth, fheUintType).GetHash() input := toLibPrecompileInputNoScalar(signature, firstHash, secondHash, thirdHash) @@ -2772,7 +2772,7 @@ func FheIfThenElse(t *testing.T, fheUintType FheUintType, condition uint64) { environment.depth = depth addr := common.Address{} readOnly := false - conditionHash := verifyCiphertextInTestMemory(environment, condition, depth, fheUintType).GetHash() + conditionHash := verifyCiphertextInTestMemory(environment, condition, depth, FheBool).GetHash() lhsHash := verifyCiphertextInTestMemory(environment, lhs, depth, fheUintType).GetHash() rhsHash := verifyCiphertextInTestMemory(environment, rhs, depth, fheUintType).GetHash() @@ -3840,6 +3840,11 @@ func TestFheNot64(t *testing.T) { FheNot(t, FheUint64, false) } +func TestFheIfThenElse4(t *testing.T) { + FheIfThenElse(t, FheUint4, 1) + FheIfThenElse(t, FheUint4, 0) +} + func TestFheIfThenElse8(t *testing.T) { FheIfThenElse(t, FheUint8, 1) FheIfThenElse(t, FheUint8, 0) diff --git a/fhevm/fhelib_required_gas.go b/fhevm/fhelib_required_gas.go index 7497a7b..4a3e95d 100644 --- a/fhevm/fhelib_required_gas.go +++ b/fhevm/fhelib_required_gas.go @@ -367,7 +367,7 @@ func fheIfThenElseRequiredGas(environment EVMEnvironment, input []byte) uint64 { logger.Error("IfThenElse op RequiredGas() inputs not verified", "err", err, "input", hex.EncodeToString(input)) return 0 } - if first.fheUintType() != FheUint8 { + if first.fheUintType() != FheBool { logger.Error("IfThenElse op RequiredGas() invalid type for condition", "first", first.fheUintType()) return 0 } diff --git a/fhevm/tfhe_ciphertext.go b/fhevm/tfhe_ciphertext.go index e841f5b..ea93ef4 100644 --- a/fhevm/tfhe_ciphertext.go +++ b/fhevm/tfhe_ciphertext.go @@ -659,11 +659,11 @@ func (first *TfheCiphertext) executeTernaryCiphertextOperation(lhs *TfheCipherte C.destroy_fhe_uint4(lhs_ptr) return nil, errors.New("4 bit binary op deserialization failed") } - first_ptr := C.deserialize_fhe_uint4(toDynamicBufferView((first.serialization))) + first_ptr := C.deserialize_fhe_bool(toDynamicBufferView((first.serialization))) if first_ptr == nil { C.destroy_fhe_uint4(lhs_ptr) C.destroy_fhe_uint4(rhs_ptr) - return nil, errors.New("8 bit binary op deserialization failed") + return nil, errors.New("Bool binary op deserialization failed") } res_ptr := op4(first_ptr, lhs_ptr, rhs_ptr) C.destroy_fhe_uint4(lhs_ptr) @@ -688,11 +688,11 @@ func (first *TfheCiphertext) executeTernaryCiphertextOperation(lhs *TfheCipherte C.destroy_fhe_uint8(lhs_ptr) return nil, errors.New("8 bit binary op deserialization failed") } - first_ptr := C.deserialize_fhe_uint8(toDynamicBufferView((first.serialization))) + first_ptr := C.deserialize_fhe_bool(toDynamicBufferView((first.serialization))) if first_ptr == nil { C.destroy_fhe_uint8(lhs_ptr) C.destroy_fhe_uint8(rhs_ptr) - return nil, errors.New("8 bit binary op deserialization failed") + return nil, errors.New("Bool binary op deserialization failed") } res_ptr := op8(first_ptr, lhs_ptr, rhs_ptr) C.destroy_fhe_uint8(lhs_ptr) @@ -717,11 +717,11 @@ func (first *TfheCiphertext) executeTernaryCiphertextOperation(lhs *TfheCipherte C.destroy_fhe_uint16(lhs_ptr) return nil, errors.New("16 bit binary op deserialization failed") } - first_ptr := C.deserialize_fhe_uint8(toDynamicBufferView((first.serialization))) + first_ptr := C.deserialize_fhe_bool(toDynamicBufferView((first.serialization))) if first_ptr == nil { - C.destroy_fhe_uint8(lhs_ptr) - C.destroy_fhe_uint8(rhs_ptr) - return nil, errors.New("8 bit binary op deserialization failed") + C.destroy_fhe_uint16(lhs_ptr) + C.destroy_fhe_uint16(rhs_ptr) + return nil, errors.New("Bool binary op deserialization failed") } res_ptr := op16(first_ptr, lhs_ptr, rhs_ptr) C.destroy_fhe_uint16(lhs_ptr) @@ -746,11 +746,11 @@ func (first *TfheCiphertext) executeTernaryCiphertextOperation(lhs *TfheCipherte C.destroy_fhe_uint32(lhs_ptr) return nil, errors.New("32 bit binary op deserialization failed") } - first_ptr := C.deserialize_fhe_uint8(toDynamicBufferView((first.serialization))) + first_ptr := C.deserialize_fhe_bool(toDynamicBufferView((first.serialization))) if first_ptr == nil { - C.destroy_fhe_uint8(lhs_ptr) - C.destroy_fhe_uint8(rhs_ptr) - return nil, errors.New("8 bit binary op deserialization failed") + C.destroy_fhe_uint32(lhs_ptr) + C.destroy_fhe_uint32(rhs_ptr) + return nil, errors.New("Bool binary op deserialization failed") } res_ptr := op32(first_ptr, lhs_ptr, rhs_ptr) C.destroy_fhe_uint32(lhs_ptr) @@ -775,11 +775,11 @@ func (first *TfheCiphertext) executeTernaryCiphertextOperation(lhs *TfheCipherte C.destroy_fhe_uint64(lhs_ptr) return nil, errors.New("64 bit binary op deserialization failed") } - first_ptr := C.deserialize_fhe_uint8(toDynamicBufferView((first.serialization))) + first_ptr := C.deserialize_fhe_bool(toDynamicBufferView((first.serialization))) if first_ptr == nil { - C.destroy_fhe_uint8(lhs_ptr) - C.destroy_fhe_uint8(rhs_ptr) - return nil, errors.New("8 bit binary op deserialization failed") + C.destroy_fhe_uint64(lhs_ptr) + C.destroy_fhe_uint64(rhs_ptr) + return nil, errors.New("Bool binary op deserialization failed") } res_ptr := op64(first_ptr, lhs_ptr, rhs_ptr) C.destroy_fhe_uint64(lhs_ptr) diff --git a/fhevm/tfhe_test.go b/fhevm/tfhe_test.go index bded90a..f3eba01 100644 --- a/fhevm/tfhe_test.go +++ b/fhevm/tfhe_test.go @@ -1276,9 +1276,9 @@ func TfheIfThenElse(t *testing.T, fheUintType FheUintType) { b.SetUint64(133337) } ctCondition := new(TfheCiphertext) - ctCondition.Encrypt(condition, fheUintType) + ctCondition.Encrypt(condition, FheBool) ctCondition2 := new(TfheCiphertext) - ctCondition2.Encrypt(condition2, fheUintType) + ctCondition2.Encrypt(condition2, FheBool) ctA := new(TfheCiphertext) ctA.Encrypt(a, fheUintType) ctB := new(TfheCiphertext) diff --git a/fhevm/tfhe_wrappers.c b/fhevm/tfhe_wrappers.c index 1567f26..754061f 100644 --- a/fhevm/tfhe_wrappers.c +++ b/fhevm/tfhe_wrappers.c @@ -2295,9 +2295,7 @@ void* if_then_else_fhe_uint4(void* condition, void* ct1, void* ct2, void* sks) checked_set_server_key(sks); - FheBool* cond = cast_4_bool(condition, sks); - - const int r = fhe_uint4_if_then_else(cond, ct1, ct2, &result); + const int r = fhe_uint4_if_then_else(condition, ct1, ct2, &result); if(r != 0) return NULL; return result; } @@ -2308,9 +2306,7 @@ void* if_then_else_fhe_uint8(void* condition, void* ct1, void* ct2, void* sks) checked_set_server_key(sks); - FheBool* cond = cast_8_bool(condition, sks); - - const int r = fhe_uint8_if_then_else(cond, ct1, ct2, &result); + const int r = fhe_uint8_if_then_else(condition, ct1, ct2, &result); if(r != 0) return NULL; return result; } @@ -2321,9 +2317,7 @@ void* if_then_else_fhe_uint16(void* condition, void* ct1, void* ct2, void* sks) checked_set_server_key(sks); - FheBool* cond = cast_8_bool(condition, sks); - - const int r = fhe_uint16_if_then_else(cond, ct1, ct2, &result); + const int r = fhe_uint16_if_then_else(condition, ct1, ct2, &result); if(r != 0) return NULL; return result; } @@ -2334,9 +2328,7 @@ void* if_then_else_fhe_uint32(void* condition, void* ct1, void* ct2, void* sks) checked_set_server_key(sks); - FheBool* cond = cast_8_bool(condition, sks); - - const int r = fhe_uint32_if_then_else(cond, ct1, ct2, &result); + const int r = fhe_uint32_if_then_else(condition, ct1, ct2, &result); if(r != 0) return NULL; return result; } @@ -2347,9 +2339,7 @@ void* if_then_else_fhe_uint64(void* condition, void* ct1, void* ct2, void* sks) checked_set_server_key(sks); - FheBool* cond = cast_8_bool(condition, sks); - - const int r = fhe_uint64_if_then_else(cond, ct1, ct2, &result); + const int r = fhe_uint64_if_then_else(condition, ct1, ct2, &result); if(r != 0) return NULL; return result; } From f49e1e57bc03d073eb40e6dd7b013a2efe7246cf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20=27birdy=27=20Danjou?= Date: Tue, 27 Feb 2024 19:13:38 +0100 Subject: [PATCH 3/9] fix: fix return type of le/gt/ne/eq/... --- fhevm/contracts_test.go | 6 +- fhevm/tfhe_ciphertext.go | 247 ++++++++++++++++++++++---------- fhevm/tfhe_wrappers.c | 300 ++++++++++++++++----------------------- proto/kms.proto | 9 +- 4 files changed, 298 insertions(+), 264 deletions(-) diff --git a/fhevm/contracts_test.go b/fhevm/contracts_test.go index 993d530..1bded17 100644 --- a/fhevm/contracts_test.go +++ b/fhevm/contracts_test.go @@ -463,7 +463,7 @@ func FheLibLe(t *testing.T, fheUintType FheUintType, scalar bool) { } decrypted, err := res.ciphertext.Decrypt() if err != nil || decrypted.Uint64() != 0 { - t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), 0) + t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), err) } // Inverting operands is only possible in the non scalar case as scalar @@ -2439,7 +2439,7 @@ func FheLe(t *testing.T, fheUintType FheUintType, scalar bool) { } decrypted, err := res.ciphertext.Decrypt() if err != nil || decrypted.Uint64() != 0 { - t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), 0) + t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), err) } // Inverting operands is only possible in the non scalar case as scalar @@ -2794,6 +2794,8 @@ func FheIfThenElse(t *testing.T, fheUintType FheUintType, condition uint64) { func Decrypt(t *testing.T, fheUintType FheUintType) { var value uint64 switch fheUintType { + case FheBool: + value = 1 case FheUint4: value = 2 case FheUint8: diff --git a/fhevm/tfhe_ciphertext.go b/fhevm/tfhe_ciphertext.go index ea93ef4..fab0755 100644 --- a/fhevm/tfhe_ciphertext.go +++ b/fhevm/tfhe_ciphertext.go @@ -459,17 +459,22 @@ func (ct *TfheCiphertext) executeUnaryCiphertextOperation(rhs *TfheCiphertext, func (lhs *TfheCiphertext) executeBinaryCiphertextOperation(rhs *TfheCiphertext, opBool func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error), - op4 func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error), - op8 func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error), + op4 func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error), + op8 func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error), op16 func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error), op32 func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error), - op64 func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error)) (*TfheCiphertext, error) { + op64 func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error), + returnBool bool) (*TfheCiphertext, error) { if lhs.fheUintType != rhs.fheUintType { return nil, errors.New("binary operations are only well-defined for identical types") } res := new(TfheCiphertext) - res.fheUintType = lhs.fheUintType + if(returnBool) { + res.fheUintType = FheBool + } else { + res.fheUintType = lhs.fheUintType + } res_ser := &C.DynamicBuffer{} switch lhs.fheUintType { case FheBool: @@ -517,10 +522,18 @@ func (lhs *TfheCiphertext) executeBinaryCiphertextOperation(rhs *TfheCiphertext, if res_ptr == nil { return nil, errors.New("4 bit binary op failed") } - ret := C.serialize_fhe_uint4(res_ptr, res_ser) - C.destroy_fhe_uint4(res_ptr) - if ret != 0 { - return nil, errors.New("4 bit binary op serialization failed") + if (returnBool) { + ret := C.serialize_fhe_bool(res_ptr, res_ser) + C.destroy_fhe_bool(res_ptr) + if ret != 0 { + return nil, errors.New("bool binary op serialization failed") + } + } else { + ret := C.serialize_fhe_uint4(res_ptr, res_ser) + C.destroy_fhe_uint4(res_ptr) + if ret != 0 { + return nil, errors.New("4 bit binary op serialization failed") + } } res.serialization = C.GoBytes(unsafe.Pointer(res_ser.pointer), C.int(res_ser.length)) C.destroy_dynamic_buffer(res_ser) @@ -543,10 +556,18 @@ func (lhs *TfheCiphertext) executeBinaryCiphertextOperation(rhs *TfheCiphertext, if res_ptr == nil { return nil, errors.New("8 bit binary op failed") } - ret := C.serialize_fhe_uint8(res_ptr, res_ser) - C.destroy_fhe_uint8(res_ptr) - if ret != 0 { - return nil, errors.New("8 bit binary op serialization failed") + if (returnBool) { + ret := C.serialize_fhe_bool(res_ptr, res_ser) + C.destroy_fhe_bool(res_ptr) + if ret != 0 { + return nil, errors.New("bool binary op serialization failed") + } + } else { + ret := C.serialize_fhe_uint8(res_ptr, res_ser) + C.destroy_fhe_uint8(res_ptr) + if ret != 0 { + return nil, errors.New("8 bit binary op serialization failed") + } } res.serialization = C.GoBytes(unsafe.Pointer(res_ser.pointer), C.int(res_ser.length)) C.destroy_dynamic_buffer(res_ser) @@ -569,10 +590,18 @@ func (lhs *TfheCiphertext) executeBinaryCiphertextOperation(rhs *TfheCiphertext, if res_ptr == nil { return nil, errors.New("16 bit binary op failed") } - ret := C.serialize_fhe_uint16(res_ptr, res_ser) - C.destroy_fhe_uint16(res_ptr) - if ret != 0 { - return nil, errors.New("16 bit binary op serialization failed") + if (returnBool) { + ret := C.serialize_fhe_bool(res_ptr, res_ser) + C.destroy_fhe_bool(res_ptr) + if ret != 0 { + return nil, errors.New("bool binary op serialization failed") + } + } else { + ret := C.serialize_fhe_uint16(res_ptr, res_ser) + C.destroy_fhe_uint16(res_ptr) + if ret != 0 { + return nil, errors.New("8 bit binary op serialization failed") + } } res.serialization = C.GoBytes(unsafe.Pointer(res_ser.pointer), C.int(res_ser.length)) C.destroy_dynamic_buffer(res_ser) @@ -595,10 +624,19 @@ func (lhs *TfheCiphertext) executeBinaryCiphertextOperation(rhs *TfheCiphertext, if res_ptr == nil { return nil, errors.New("32 bit binary op failed") } - ret := C.serialize_fhe_uint32(res_ptr, res_ser) - C.destroy_fhe_uint32(res_ptr) - if ret != 0 { - return nil, errors.New("32 bit binary op serialization failed") + + if (returnBool) { + ret := C.serialize_fhe_bool(res_ptr, res_ser) + C.destroy_fhe_bool(res_ptr) + if ret != 0 { + return nil, errors.New("bool binary op serialization failed") + } + } else { + ret := C.serialize_fhe_uint32(res_ptr, res_ser) + C.destroy_fhe_uint32(res_ptr) + if ret != 0 { + return nil, errors.New("8 bit binary op serialization failed") + } } res.serialization = C.GoBytes(unsafe.Pointer(res_ser.pointer), C.int(res_ser.length)) C.destroy_dynamic_buffer(res_ser) @@ -621,10 +659,18 @@ func (lhs *TfheCiphertext) executeBinaryCiphertextOperation(rhs *TfheCiphertext, if res_ptr == nil { return nil, errors.New("64 bit binary op failed") } - ret := C.serialize_fhe_uint64(res_ptr, res_ser) - C.destroy_fhe_uint64(res_ptr) - if ret != 0 { - return nil, errors.New("64 bit binary op serialization failed") + if (returnBool) { + ret := C.serialize_fhe_bool(res_ptr, res_ser) + C.destroy_fhe_bool(res_ptr) + if ret != 0 { + return nil, errors.New("bool binary op serialization failed") + } + } else { + ret := C.serialize_fhe_uint64(res_ptr, res_ser) + C.destroy_fhe_uint64(res_ptr) + if ret != 0 { + return nil, errors.New("8 bit binary op serialization failed") + } } res.serialization = C.GoBytes(unsafe.Pointer(res_ser.pointer), C.int(res_ser.length)) C.destroy_dynamic_buffer(res_ser) @@ -807,9 +853,14 @@ func (lhs *TfheCiphertext) executeBinaryScalarOperation(rhs uint64, op8 func(lhs unsafe.Pointer, rhs C.uint8_t) (unsafe.Pointer, error), op16 func(lhs unsafe.Pointer, rhs C.uint16_t) (unsafe.Pointer, error), op32 func(lhs unsafe.Pointer, rhs C.uint32_t) (unsafe.Pointer, error), - op64 func(lhs unsafe.Pointer, rhs C.uint64_t) (unsafe.Pointer, error)) (*TfheCiphertext, error) { + op64 func(lhs unsafe.Pointer, rhs C.uint64_t) (unsafe.Pointer, error), + returnBool bool) (*TfheCiphertext, error) { res := new(TfheCiphertext) - res.fheUintType = lhs.fheUintType + if(returnBool) { + res.fheUintType = FheBool + } else { + res.fheUintType = lhs.fheUintType + } res_ser := &C.DynamicBuffer{} switch lhs.fheUintType { case FheBool: @@ -829,7 +880,7 @@ func (lhs *TfheCiphertext) executeBinaryScalarOperation(rhs uint64, ret := C.serialize_fhe_bool(res_ptr, res_ser) C.destroy_fhe_bool(res_ptr) if ret != 0 { - return nil, errors.New("8 bit scalar op serialization failed") + return nil, errors.New("Bool scalar op serialization failed") } res.serialization = C.GoBytes(unsafe.Pointer(res_ser.pointer), C.int(res_ser.length)) C.destroy_dynamic_buffer(res_ser) @@ -847,10 +898,18 @@ func (lhs *TfheCiphertext) executeBinaryScalarOperation(rhs uint64, if res_ptr == nil { return nil, errors.New("4 bit scalar op failed") } - ret := C.serialize_fhe_uint8(res_ptr, res_ser) - C.destroy_fhe_uint4(res_ptr) - if ret != 0 { - return nil, errors.New("4 bit scalar op serialization failed") + if (returnBool) { + ret := C.serialize_fhe_bool(res_ptr, res_ser) + C.destroy_fhe_bool(res_ptr) + if ret != 0 { + return nil, errors.New("Bool scalar op serialization failed") + } + } else { + ret := C.serialize_fhe_uint4(res_ptr, res_ser) + C.destroy_fhe_uint4(res_ptr) + if ret != 0 { + return nil, errors.New("4 bit scalar op serialization failed") + } } res.serialization = C.GoBytes(unsafe.Pointer(res_ser.pointer), C.int(res_ser.length)) C.destroy_dynamic_buffer(res_ser) @@ -868,10 +927,18 @@ func (lhs *TfheCiphertext) executeBinaryScalarOperation(rhs uint64, if res_ptr == nil { return nil, errors.New("8 bit scalar op failed") } - ret := C.serialize_fhe_uint8(res_ptr, res_ser) - C.destroy_fhe_uint8(res_ptr) - if ret != 0 { - return nil, errors.New("8 bit scalar op serialization failed") + if (returnBool) { + ret := C.serialize_fhe_bool(res_ptr, res_ser) + C.destroy_fhe_bool(res_ptr) + if ret != 0 { + return nil, errors.New("Bool scalar op serialization failed") + } + } else { + ret := C.serialize_fhe_uint8(res_ptr, res_ser) + C.destroy_fhe_uint8(res_ptr) + if ret != 0 { + return nil, errors.New("8 bit scalar op serialization failed") + } } res.serialization = C.GoBytes(unsafe.Pointer(res_ser.pointer), C.int(res_ser.length)) C.destroy_dynamic_buffer(res_ser) @@ -889,10 +956,18 @@ func (lhs *TfheCiphertext) executeBinaryScalarOperation(rhs uint64, if res_ptr == nil { return nil, errors.New("16 bit scalar op failed") } - ret := C.serialize_fhe_uint16(res_ptr, res_ser) - C.destroy_fhe_uint16(res_ptr) - if ret != 0 { - return nil, errors.New("16 bit scalar op serialization failed") + if (returnBool) { + ret := C.serialize_fhe_bool(res_ptr, res_ser) + C.destroy_fhe_bool(res_ptr) + if ret != 0 { + return nil, errors.New("Bool scalar op serialization failed") + } + } else { + ret := C.serialize_fhe_uint16(res_ptr, res_ser) + C.destroy_fhe_uint16(res_ptr) + if ret != 0 { + return nil, errors.New("16 bit scalar op serialization failed") + } } res.serialization = C.GoBytes(unsafe.Pointer(res_ser.pointer), C.int(res_ser.length)) C.destroy_dynamic_buffer(res_ser) @@ -910,10 +985,18 @@ func (lhs *TfheCiphertext) executeBinaryScalarOperation(rhs uint64, if res_ptr == nil { return nil, errors.New("32 bit scalar op failed") } - ret := C.serialize_fhe_uint32(res_ptr, res_ser) - C.destroy_fhe_uint32(res_ptr) - if ret != 0 { - return nil, errors.New("32 bit scalar op serialization failed") + if (returnBool) { + ret := C.serialize_fhe_bool(res_ptr, res_ser) + C.destroy_fhe_bool(res_ptr) + if ret != 0 { + return nil, errors.New("Bool scalar op serialization failed") + } + } else { + ret := C.serialize_fhe_uint32(res_ptr, res_ser) + C.destroy_fhe_uint32(res_ptr) + if ret != 0 { + return nil, errors.New("32 bit scalar op serialization failed") + } } res.serialization = C.GoBytes(unsafe.Pointer(res_ser.pointer), C.int(res_ser.length)) C.destroy_dynamic_buffer(res_ser) @@ -931,10 +1014,18 @@ func (lhs *TfheCiphertext) executeBinaryScalarOperation(rhs uint64, if res_ptr == nil { return nil, errors.New("64 bit scalar op failed") } - ret := C.serialize_fhe_uint64(res_ptr, res_ser) - C.destroy_fhe_uint64(res_ptr) - if ret != 0 { - return nil, errors.New("64 bit scalar op serialization failed") + if (returnBool) { + ret := C.serialize_fhe_bool(res_ptr, res_ser) + C.destroy_fhe_bool(res_ptr) + if ret != 0 { + return nil, errors.New("Bool scalar op serialization failed") + } + } else { + ret := C.serialize_fhe_uint64(res_ptr, res_ser) + C.destroy_fhe_uint64(res_ptr) + if ret != 0 { + return nil, errors.New("64 bit scalar op serialization failed") + } } res.serialization = C.GoBytes(unsafe.Pointer(res_ser.pointer), C.int(res_ser.length)) C.destroy_dynamic_buffer(res_ser) @@ -962,7 +1053,7 @@ func (lhs *TfheCiphertext) Add(rhs *TfheCiphertext) (*TfheCiphertext, error) { }, func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { return C.add_fhe_uint64(lhs, rhs, sks), nil - }) + }, false) } func (lhs *TfheCiphertext) ScalarAdd(rhs uint64) (*TfheCiphertext, error) { @@ -982,7 +1073,7 @@ func (lhs *TfheCiphertext) ScalarAdd(rhs uint64) (*TfheCiphertext, error) { }, func(lhs unsafe.Pointer, rhs C.uint64_t) (unsafe.Pointer, error) { return C.scalar_add_fhe_uint64(lhs, rhs, sks), nil - }) + }, false) } func (lhs *TfheCiphertext) Sub(rhs *TfheCiphertext) (*TfheCiphertext, error) { @@ -1002,7 +1093,7 @@ func (lhs *TfheCiphertext) Sub(rhs *TfheCiphertext) (*TfheCiphertext, error) { }, func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { return C.sub_fhe_uint64(lhs, rhs, sks), nil - }) + }, false) } func (lhs *TfheCiphertext) ScalarSub(rhs uint64) (*TfheCiphertext, error) { @@ -1022,7 +1113,7 @@ func (lhs *TfheCiphertext) ScalarSub(rhs uint64) (*TfheCiphertext, error) { }, func(lhs unsafe.Pointer, rhs C.uint64_t) (unsafe.Pointer, error) { return C.scalar_sub_fhe_uint64(lhs, rhs, sks), nil - }) + }, false) } func (lhs *TfheCiphertext) Mul(rhs *TfheCiphertext) (*TfheCiphertext, error) { @@ -1042,7 +1133,7 @@ func (lhs *TfheCiphertext) Mul(rhs *TfheCiphertext) (*TfheCiphertext, error) { }, func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { return C.mul_fhe_uint64(lhs, rhs, sks), nil - }) + }, false) } func (lhs *TfheCiphertext) ScalarMul(rhs uint64) (*TfheCiphertext, error) { @@ -1062,7 +1153,7 @@ func (lhs *TfheCiphertext) ScalarMul(rhs uint64) (*TfheCiphertext, error) { }, func(lhs unsafe.Pointer, rhs C.uint64_t) (unsafe.Pointer, error) { return C.scalar_mul_fhe_uint64(lhs, rhs, sks), nil - }) + }, false) } func (lhs *TfheCiphertext) ScalarDiv(rhs uint64) (*TfheCiphertext, error) { @@ -1082,7 +1173,7 @@ func (lhs *TfheCiphertext) ScalarDiv(rhs uint64) (*TfheCiphertext, error) { }, func(lhs unsafe.Pointer, rhs C.uint64_t) (unsafe.Pointer, error) { return C.scalar_div_fhe_uint64(lhs, rhs, sks), nil - }) + }, false) } func (lhs *TfheCiphertext) ScalarRem(rhs uint64) (*TfheCiphertext, error) { @@ -1102,7 +1193,7 @@ func (lhs *TfheCiphertext) ScalarRem(rhs uint64) (*TfheCiphertext, error) { }, func(lhs unsafe.Pointer, rhs C.uint64_t) (unsafe.Pointer, error) { return C.scalar_rem_fhe_uint64(lhs, rhs, sks), nil - }) + }, false) } func (lhs *TfheCiphertext) Bitand(rhs *TfheCiphertext) (*TfheCiphertext, error) { @@ -1124,7 +1215,7 @@ func (lhs *TfheCiphertext) Bitand(rhs *TfheCiphertext) (*TfheCiphertext, error) }, func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { return C.bitand_fhe_uint64(lhs, rhs, sks), nil - }) + }, false) } func (lhs *TfheCiphertext) Bitor(rhs *TfheCiphertext) (*TfheCiphertext, error) { @@ -1146,7 +1237,7 @@ func (lhs *TfheCiphertext) Bitor(rhs *TfheCiphertext) (*TfheCiphertext, error) { }, func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { return C.bitor_fhe_uint64(lhs, rhs, sks), nil - }) + }, false) } func (lhs *TfheCiphertext) Bitxor(rhs *TfheCiphertext) (*TfheCiphertext, error) { @@ -1168,7 +1259,7 @@ func (lhs *TfheCiphertext) Bitxor(rhs *TfheCiphertext) (*TfheCiphertext, error) }, func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { return C.bitxor_fhe_uint64(lhs, rhs, sks), nil - }) + }, false) } func (lhs *TfheCiphertext) Shl(rhs *TfheCiphertext) (*TfheCiphertext, error) { @@ -1188,7 +1279,7 @@ func (lhs *TfheCiphertext) Shl(rhs *TfheCiphertext) (*TfheCiphertext, error) { }, func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { return C.shl_fhe_uint64(lhs, rhs, sks), nil - }) + }, false) } func (lhs *TfheCiphertext) ScalarShl(rhs uint64) (*TfheCiphertext, error) { @@ -1208,7 +1299,7 @@ func (lhs *TfheCiphertext) ScalarShl(rhs uint64) (*TfheCiphertext, error) { }, func(lhs unsafe.Pointer, rhs C.uint64_t) (unsafe.Pointer, error) { return C.scalar_shl_fhe_uint64(lhs, rhs, sks), nil - }) + }, false) } func (lhs *TfheCiphertext) Shr(rhs *TfheCiphertext) (*TfheCiphertext, error) { @@ -1228,7 +1319,7 @@ func (lhs *TfheCiphertext) Shr(rhs *TfheCiphertext) (*TfheCiphertext, error) { }, func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { return C.shr_fhe_uint64(lhs, rhs, sks), nil - }) + }, false) } func (lhs *TfheCiphertext) ScalarShr(rhs uint64) (*TfheCiphertext, error) { @@ -1248,7 +1339,7 @@ func (lhs *TfheCiphertext) ScalarShr(rhs uint64) (*TfheCiphertext, error) { }, func(lhs unsafe.Pointer, rhs C.uint64_t) (unsafe.Pointer, error) { return C.scalar_shr_fhe_uint64(lhs, rhs, sks), nil - }) + }, false) } func (lhs *TfheCiphertext) Eq(rhs *TfheCiphertext) (*TfheCiphertext, error) { @@ -1268,7 +1359,7 @@ func (lhs *TfheCiphertext) Eq(rhs *TfheCiphertext) (*TfheCiphertext, error) { }, func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { return C.eq_fhe_uint64(lhs, rhs, sks), nil - }) + }, true) } func (lhs *TfheCiphertext) ScalarEq(rhs uint64) (*TfheCiphertext, error) { @@ -1288,7 +1379,7 @@ func (lhs *TfheCiphertext) ScalarEq(rhs uint64) (*TfheCiphertext, error) { }, func(lhs unsafe.Pointer, rhs C.uint64_t) (unsafe.Pointer, error) { return C.scalar_eq_fhe_uint64(lhs, rhs, sks), nil - }) + }, true) } func (lhs *TfheCiphertext) Ne(rhs *TfheCiphertext) (*TfheCiphertext, error) { @@ -1308,7 +1399,7 @@ func (lhs *TfheCiphertext) Ne(rhs *TfheCiphertext) (*TfheCiphertext, error) { }, func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { return C.ne_fhe_uint64(lhs, rhs, sks), nil - }) + }, true) } func (lhs *TfheCiphertext) ScalarNe(rhs uint64) (*TfheCiphertext, error) { @@ -1328,7 +1419,7 @@ func (lhs *TfheCiphertext) ScalarNe(rhs uint64) (*TfheCiphertext, error) { }, func(lhs unsafe.Pointer, rhs C.uint64_t) (unsafe.Pointer, error) { return C.scalar_ne_fhe_uint64(lhs, rhs, sks), nil - }) + }, true) } func (lhs *TfheCiphertext) Ge(rhs *TfheCiphertext) (*TfheCiphertext, error) { @@ -1348,7 +1439,7 @@ func (lhs *TfheCiphertext) Ge(rhs *TfheCiphertext) (*TfheCiphertext, error) { }, func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { return C.ge_fhe_uint64(lhs, rhs, sks), nil - }) + }, true) } func (lhs *TfheCiphertext) ScalarGe(rhs uint64) (*TfheCiphertext, error) { @@ -1368,7 +1459,7 @@ func (lhs *TfheCiphertext) ScalarGe(rhs uint64) (*TfheCiphertext, error) { }, func(lhs unsafe.Pointer, rhs C.uint64_t) (unsafe.Pointer, error) { return C.scalar_ge_fhe_uint64(lhs, rhs, sks), nil - }) + }, true) } func (lhs *TfheCiphertext) Gt(rhs *TfheCiphertext) (*TfheCiphertext, error) { @@ -1388,7 +1479,7 @@ func (lhs *TfheCiphertext) Gt(rhs *TfheCiphertext) (*TfheCiphertext, error) { }, func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { return C.gt_fhe_uint64(lhs, rhs, sks), nil - }) + }, true) } func (lhs *TfheCiphertext) ScalarGt(rhs uint64) (*TfheCiphertext, error) { @@ -1408,7 +1499,7 @@ func (lhs *TfheCiphertext) ScalarGt(rhs uint64) (*TfheCiphertext, error) { }, func(lhs unsafe.Pointer, rhs C.uint64_t) (unsafe.Pointer, error) { return C.scalar_gt_fhe_uint64(lhs, rhs, sks), nil - }) + }, true) } func (lhs *TfheCiphertext) Le(rhs *TfheCiphertext) (*TfheCiphertext, error) { @@ -1428,7 +1519,7 @@ func (lhs *TfheCiphertext) Le(rhs *TfheCiphertext) (*TfheCiphertext, error) { }, func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { return C.le_fhe_uint64(lhs, rhs, sks), nil - }) + }, true) } func (lhs *TfheCiphertext) ScalarLe(rhs uint64) (*TfheCiphertext, error) { @@ -1449,7 +1540,7 @@ func (lhs *TfheCiphertext) ScalarLe(rhs uint64) (*TfheCiphertext, error) { }, func(lhs unsafe.Pointer, rhs C.uint64_t) (unsafe.Pointer, error) { return C.scalar_le_fhe_uint64(lhs, rhs, sks), nil - }) + }, true) } func (lhs *TfheCiphertext) Lt(rhs *TfheCiphertext) (*TfheCiphertext, error) { @@ -1469,7 +1560,7 @@ func (lhs *TfheCiphertext) Lt(rhs *TfheCiphertext) (*TfheCiphertext, error) { }, func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { return C.lt_fhe_uint64(lhs, rhs, sks), nil - }) + }, true) } func (lhs *TfheCiphertext) ScalarLt(rhs uint64) (*TfheCiphertext, error) { @@ -1489,7 +1580,7 @@ func (lhs *TfheCiphertext) ScalarLt(rhs uint64) (*TfheCiphertext, error) { }, func(lhs unsafe.Pointer, rhs C.uint64_t) (unsafe.Pointer, error) { return C.scalar_lt_fhe_uint64(lhs, rhs, sks), nil - }) + }, true) } func (lhs *TfheCiphertext) Min(rhs *TfheCiphertext) (*TfheCiphertext, error) { @@ -1509,7 +1600,7 @@ func (lhs *TfheCiphertext) Min(rhs *TfheCiphertext) (*TfheCiphertext, error) { }, func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { return C.min_fhe_uint64(lhs, rhs, sks), nil - }) + }, false) } func (lhs *TfheCiphertext) ScalarMin(rhs uint64) (*TfheCiphertext, error) { @@ -1529,7 +1620,7 @@ func (lhs *TfheCiphertext) ScalarMin(rhs uint64) (*TfheCiphertext, error) { }, func(lhs unsafe.Pointer, rhs C.uint64_t) (unsafe.Pointer, error) { return C.scalar_min_fhe_uint64(lhs, rhs, sks), nil - }) + }, false) } func (lhs *TfheCiphertext) Max(rhs *TfheCiphertext) (*TfheCiphertext, error) { @@ -1549,7 +1640,7 @@ func (lhs *TfheCiphertext) Max(rhs *TfheCiphertext) (*TfheCiphertext, error) { }, func(lhs unsafe.Pointer, rhs unsafe.Pointer) (unsafe.Pointer, error) { return C.max_fhe_uint64(lhs, rhs, sks), nil - }) + }, false) } func (lhs *TfheCiphertext) ScalarMax(rhs uint64) (*TfheCiphertext, error) { @@ -1569,7 +1660,7 @@ func (lhs *TfheCiphertext) ScalarMax(rhs uint64) (*TfheCiphertext, error) { }, func(lhs unsafe.Pointer, rhs C.uint64_t) (unsafe.Pointer, error) { return C.scalar_max_fhe_uint64(lhs, rhs, sks), nil - }) + }, false) } func (lhs *TfheCiphertext) Neg() (*TfheCiphertext, error) { diff --git a/fhevm/tfhe_wrappers.c b/fhevm/tfhe_wrappers.c index 754061f..31efeec 100644 --- a/fhevm/tfhe_wrappers.c +++ b/fhevm/tfhe_wrappers.c @@ -1241,721 +1241,661 @@ void* scalar_shr_fhe_uint64(void* ct, uint64_t pt, void* sks) void* eq_fhe_uint4(void* ct1, void* ct2, void* sks) { - FheBool* bool_result = NULL; + FheBool* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint4_eq(ct1, ct2, &bool_result); + const int r = fhe_uint4_eq(ct1, ct2, &result); if(r != 0) return NULL; - FheUint4* result = cast_bool_4(bool_result, sks); return result; } void* eq_fhe_uint8(void* ct1, void* ct2, void* sks) { - FheBool* bool_result = NULL; + FheBool* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint8_eq(ct1, ct2, &bool_result); + const int r = fhe_uint8_eq(ct1, ct2, &result); if(r != 0) return NULL; - FheUint8* result = cast_bool_8(bool_result, sks); return result; } void* eq_fhe_uint16(void* ct1, void* ct2, void* sks) { - FheBool* bool_result = NULL; + FheBool* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint16_eq(ct1, ct2, &bool_result); + const int r = fhe_uint16_eq(ct1, ct2, &result); if(r != 0) return NULL; - FheUint16* result = cast_bool_16(bool_result, sks); return result; } void* eq_fhe_uint32(void* ct1, void* ct2, void* sks) { - FheBool* bool_result = NULL; + FheBool* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint16_eq(ct1, ct2, &bool_result); + const int r = fhe_uint16_eq(ct1, ct2, &result); if(r != 0) return NULL; - FheUint32* result = cast_bool_32(bool_result, sks); return result; } void* eq_fhe_uint64(void* ct1, void* ct2, void* sks) { - FheBool* bool_result = NULL; + FheBool* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint64_eq(ct1, ct2, &bool_result); + const int r = fhe_uint64_eq(ct1, ct2, &result); if(r != 0) return NULL; - FheUint64* result = cast_bool_64(bool_result, sks); return result; } void* scalar_eq_fhe_uint4(void* ct, uint8_t pt, void* sks) { - FheBool* bool_result = NULL; + FheBool* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint4_scalar_eq(ct, pt, &bool_result); + const int r = fhe_uint4_scalar_eq(ct, pt, &result); if(r != 0) return NULL; - FheUint4* result = cast_bool_4(bool_result, sks); return result; } void* scalar_eq_fhe_uint8(void* ct, uint8_t pt, void* sks) { - FheBool* bool_result = NULL; + FheBool* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint8_scalar_eq(ct, pt, &bool_result); + const int r = fhe_uint8_scalar_eq(ct, pt, &result); if(r != 0) return NULL; - FheUint8* result = cast_bool_8(bool_result, sks); return result; } void* scalar_eq_fhe_uint16(void* ct, uint16_t pt, void* sks) { - FheBool* bool_result = NULL; + FheBool* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint16_scalar_eq(ct, pt, &bool_result); + const int r = fhe_uint16_scalar_eq(ct, pt, &result); if(r != 0) return NULL; - FheUint16* result = cast_bool_16(bool_result, sks); return result; } void* scalar_eq_fhe_uint32(void* ct, uint32_t pt, void* sks) { - FheBool* bool_result = NULL; + FheBool* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint32_scalar_eq(ct, pt, &bool_result); + const int r = fhe_uint32_scalar_eq(ct, pt, &result); if(r != 0) return NULL; - FheUint32* result = cast_bool_32(bool_result, sks); return result; } void* scalar_eq_fhe_uint64(void* ct, uint64_t pt, void* sks) { - FheBool* bool_result = NULL; + FheBool* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint64_scalar_eq(ct, pt, &bool_result); + const int r = fhe_uint64_scalar_eq(ct, pt, &result); if(r != 0) return NULL; - FheUint64* result = cast_bool_64(bool_result, sks); return result; } void* ne_fhe_uint4(void* ct1, void* ct2, void* sks) { - FheBool* bool_result = NULL; + FheBool* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint4_ne(ct1, ct2, &bool_result); + const int r = fhe_uint4_ne(ct1, ct2, &result); if(r != 0) return NULL; - FheUint4* result = cast_bool_4(bool_result, sks); return result; } void* ne_fhe_uint8(void* ct1, void* ct2, void* sks) { - FheBool* bool_result = NULL; + FheBool* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint8_ne(ct1, ct2, &bool_result); + const int r = fhe_uint8_ne(ct1, ct2, &result); if(r != 0) return NULL; - FheUint8* result = cast_bool_8(bool_result, sks); return result; } void* ne_fhe_uint16(void* ct1, void* ct2, void* sks) { - FheBool* bool_result = NULL; + FheBool* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint16_ne(ct1, ct2, &bool_result); + const int r = fhe_uint16_ne(ct1, ct2, &result); if(r != 0) return NULL; - FheUint16* result = cast_bool_16(bool_result, sks); return result; } void* ne_fhe_uint32(void* ct1, void* ct2, void* sks) { - FheBool* bool_result = NULL; + FheBool* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint32_ne(ct1, ct2, &bool_result); + const int r = fhe_uint32_ne(ct1, ct2, &result); if(r != 0) return NULL; - FheUint32* result = cast_bool_32(bool_result, sks); return result; } void* ne_fhe_uint64(void* ct1, void* ct2, void* sks) { - FheBool* bool_result = NULL; + FheBool* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint64_ne(ct1, ct2, &bool_result); + const int r = fhe_uint64_ne(ct1, ct2, &result); if(r != 0) return NULL; - FheUint64* result = cast_bool_64(bool_result, sks); return result; } void* scalar_ne_fhe_uint4(void* ct, uint8_t pt, void* sks) { - FheBool* bool_result = NULL; + FheBool* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint4_scalar_ne(ct, pt, &bool_result); + const int r = fhe_uint4_scalar_ne(ct, pt, &result); if(r != 0) return NULL; - FheUint4* result = cast_bool_4(bool_result, sks); return result; } void* scalar_ne_fhe_uint8(void* ct, uint8_t pt, void* sks) { - FheBool* bool_result = NULL; + FheBool* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint8_scalar_ne(ct, pt, &bool_result); + const int r = fhe_uint8_scalar_ne(ct, pt, &result); if(r != 0) return NULL; - FheUint8* result = cast_bool_8(bool_result, sks); return result; } void* scalar_ne_fhe_uint16(void* ct, uint16_t pt, void* sks) { - FheBool* bool_result = NULL; + FheBool* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint16_scalar_ne(ct, pt, &bool_result); + const int r = fhe_uint16_scalar_ne(ct, pt, &result); if(r != 0) return NULL; - FheUint16* result = cast_bool_16(bool_result, sks); return result; } void* scalar_ne_fhe_uint32(void* ct, uint32_t pt, void* sks) { - FheBool* bool_result = NULL; + FheBool* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint32_scalar_ne(ct, pt, &bool_result); + const int r = fhe_uint32_scalar_ne(ct, pt, &result); if(r != 0) return NULL; - FheUint32* result = cast_bool_32(bool_result, sks); return result; } void* scalar_ne_fhe_uint64(void* ct, uint64_t pt, void* sks) { - FheBool* bool_result = NULL; + FheBool* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint64_scalar_ne(ct, pt, &bool_result); + const int r = fhe_uint64_scalar_ne(ct, pt, &result); if(r != 0) return NULL; - FheUint64* result = cast_bool_64(bool_result, sks); return result; } void* ge_fhe_uint4(void* ct1, void* ct2, void* sks) { - FheBool* bool_result = NULL; + FheBool* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint4_ge(ct1, ct2, &bool_result); + const int r = fhe_uint4_ge(ct1, ct2, &result); if(r != 0) return NULL; - FheUint4* result = cast_bool_4(bool_result, sks); return result; } void* ge_fhe_uint8(void* ct1, void* ct2, void* sks) { - FheBool* bool_result = NULL; + FheBool* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint8_ge(ct1, ct2, &bool_result); + const int r = fhe_uint8_ge(ct1, ct2, &result); if(r != 0) return NULL; - FheUint8* result = cast_bool_8(bool_result, sks); return result; } void* ge_fhe_uint16(void* ct1, void* ct2, void* sks) { - FheBool* bool_result = NULL; + FheBool* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint16_ge(ct1, ct2, &bool_result); + const int r = fhe_uint16_ge(ct1, ct2, &result); if(r != 0) return NULL; - FheUint16* result = cast_bool_16(bool_result, sks); return result; } void* ge_fhe_uint32(void* ct1, void* ct2, void* sks) { - FheBool* bool_result = NULL; + FheBool* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint32_ge(ct1, ct2, &bool_result); + const int r = fhe_uint32_ge(ct1, ct2, &result); if(r != 0) return NULL; - FheUint32* result = cast_bool_32(bool_result, sks); return result; } void* ge_fhe_uint64(void* ct1, void* ct2, void* sks) { - FheBool* bool_result = NULL; + FheBool* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint64_ge(ct1, ct2, &bool_result); + const int r = fhe_uint64_ge(ct1, ct2, &result); if(r != 0) return NULL; - FheUint64* result = cast_bool_64(bool_result, sks); return result; } void* scalar_ge_fhe_uint4(void* ct, uint8_t pt, void* sks) { - FheBool* bool_result = NULL; + FheBool* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint4_scalar_ge(ct, pt, &bool_result); + const int r = fhe_uint4_scalar_ge(ct, pt, &result); if(r != 0) return NULL; - FheUint4* result = cast_bool_4(bool_result, sks); return result; } void* scalar_ge_fhe_uint8(void* ct, uint8_t pt, void* sks) { - FheBool* bool_result = NULL; + FheBool* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint8_scalar_ge(ct, pt, &bool_result); + const int r = fhe_uint8_scalar_ge(ct, pt, &result); if(r != 0) return NULL; - FheUint8* result = cast_bool_8(bool_result, sks); return result; } void* scalar_ge_fhe_uint16(void* ct, uint16_t pt, void* sks) { - FheBool* bool_result = NULL; + FheBool* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint16_scalar_ge(ct, pt, &bool_result); + const int r = fhe_uint16_scalar_ge(ct, pt, &result); if(r != 0) return NULL; - FheUint16* result = cast_bool_16(bool_result, sks); return result; } void* scalar_ge_fhe_uint32(void* ct, uint32_t pt, void* sks) { - FheBool* bool_result = NULL; + FheBool* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint32_scalar_ge(ct, pt, &bool_result); + const int r = fhe_uint32_scalar_ge(ct, pt, &result); if(r != 0) return NULL; - FheUint32* result = cast_bool_32(bool_result, sks); return result; } void* scalar_ge_fhe_uint64(void* ct, uint64_t pt, void* sks) { - FheBool* bool_result = NULL; + FheBool* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint64_scalar_ge(ct, pt, &bool_result); + const int r = fhe_uint64_scalar_ge(ct, pt, &result); if(r != 0) return NULL; - FheUint64* result = cast_bool_32(bool_result, sks); return result; } void* gt_fhe_uint4(void* ct1, void* ct2, void* sks) { - FheBool* bool_result = NULL; + FheBool* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint4_gt(ct1, ct2, &bool_result); + const int r = fhe_uint4_gt(ct1, ct2, &result); if(r != 0) return NULL; - FheUint4* result = cast_bool_4(bool_result, sks); return result; } void* gt_fhe_uint8(void* ct1, void* ct2, void* sks) { - FheBool* bool_result = NULL; + FheBool* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint8_gt(ct1, ct2, &bool_result); + const int r = fhe_uint8_gt(ct1, ct2, &result); if(r != 0) return NULL; - FheUint8* result = cast_bool_8(bool_result, sks); return result; } void* gt_fhe_uint16(void* ct1, void* ct2, void* sks) { - FheBool* bool_result = NULL; + FheBool* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint16_gt(ct1, ct2, &bool_result); + const int r = fhe_uint16_gt(ct1, ct2, &result); if(r != 0) return NULL; - FheUint16* result = cast_bool_16(bool_result, sks); return result; } void* gt_fhe_uint32(void* ct1, void* ct2, void* sks) { - FheBool* bool_result = NULL; + FheBool* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint32_gt(ct1, ct2, &bool_result); + const int r = fhe_uint32_gt(ct1, ct2, &result); if(r != 0) return NULL; - FheUint32* result = cast_bool_32(bool_result, sks); return result; } void* gt_fhe_uint64(void* ct1, void* ct2, void* sks) { - FheBool* bool_result = NULL; + FheBool* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint64_gt(ct1, ct2, &bool_result); + const int r = fhe_uint64_gt(ct1, ct2, &result); if(r != 0) return NULL; - FheUint64* result = cast_bool_64(bool_result, sks); return result; } void* scalar_gt_fhe_uint4(void* ct, uint8_t pt, void* sks) { - FheBool* bool_result = NULL; + FheBool* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint4_scalar_gt(ct, pt, &bool_result); + const int r = fhe_uint4_scalar_gt(ct, pt, &result); if(r != 0) return NULL; - FheUint4* result = cast_bool_4(bool_result, sks); return result; } void* scalar_gt_fhe_uint8(void* ct, uint8_t pt, void* sks) { - FheBool* bool_result = NULL; + FheBool* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint8_scalar_gt(ct, pt, &bool_result); + const int r = fhe_uint8_scalar_gt(ct, pt, &result); if(r != 0) return NULL; - FheUint8* result = cast_bool_8(bool_result, sks); return result; } void* scalar_gt_fhe_uint16(void* ct, uint16_t pt, void* sks) { - FheBool* bool_result = NULL; + FheBool* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint16_scalar_gt(ct, pt, &bool_result); + const int r = fhe_uint16_scalar_gt(ct, pt, &result); if(r != 0) return NULL; - FheUint16* result = cast_bool_16(bool_result, sks); return result; } void* scalar_gt_fhe_uint32(void* ct, uint32_t pt, void* sks) { - FheBool* bool_result = NULL; + FheBool* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint32_scalar_gt(ct, pt, &bool_result); + const int r = fhe_uint32_scalar_gt(ct, pt, &result); if(r != 0) return NULL; - FheUint32* result = cast_bool_32(bool_result, sks); return result; } void* scalar_gt_fhe_uint64(void* ct, uint64_t pt, void* sks) { - FheBool* bool_result = NULL; + FheBool* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint64_scalar_gt(ct, pt, &bool_result); + const int r = fhe_uint64_scalar_gt(ct, pt, &result); if(r != 0) return NULL; - FheUint64* result = cast_bool_64(bool_result, sks); return result; } void* le_fhe_uint4(void* ct1, void* ct2, void* sks) { - FheBool* bool_result = NULL; + FheBool* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint4_le(ct1, ct2, &bool_result); + const int r = fhe_uint4_le(ct1, ct2, &result); if(r != 0) return NULL; - FheUint4* result = cast_bool_4(bool_result, sks); return result; } void* le_fhe_uint8(void* ct1, void* ct2, void* sks) { - FheBool* bool_result = NULL; + FheBool* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint8_le(ct1, ct2, &bool_result); + const int r = fhe_uint4_le(ct1, ct2, &result); if(r != 0) return NULL; - FheUint8* result = cast_bool_8(bool_result, sks); return result; } void* le_fhe_uint16(void* ct1, void* ct2, void* sks) { - FheBool* bool_result = NULL; + FheBool* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint16_le(ct1, ct2, &bool_result); + const int r = fhe_uint16_le(ct1, ct2, &result); if(r != 0) return NULL; - FheUint16* result = cast_bool_16(bool_result, sks); return result; } void* le_fhe_uint32(void* ct1, void* ct2, void* sks) { - FheBool* bool_result = NULL; + FheBool* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint32_le(ct1, ct2, &bool_result); + const int r = fhe_uint32_le(ct1, ct2, &result); if(r != 0) return NULL; - FheUint32* result = cast_bool_32(bool_result, sks); return result; } void* le_fhe_uint64(void* ct1, void* ct2, void* sks) { - FheBool* bool_result = NULL; + FheBool* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint64_le(ct1, ct2, &bool_result); + const int r = fhe_uint64_le(ct1, ct2, &result); if(r != 0) return NULL; - FheUint64* result = cast_bool_64(bool_result, sks); return result; } void* scalar_le_fhe_uint4(void* ct, uint8_t pt, void* sks) { - FheBool* bool_result = NULL; + FheBool* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint4_scalar_le(ct, pt, &bool_result); + const int r = fhe_uint4_scalar_le(ct, pt, &result); if(r != 0) return NULL; - FheUint4* result = cast_bool_4(bool_result, sks); return result; } void* scalar_le_fhe_uint8(void* ct, uint8_t pt, void* sks) { - FheBool* bool_result = NULL; + FheBool* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint8_scalar_le(ct, pt, &bool_result); + const int r = fhe_uint8_scalar_le(ct, pt, &result); if(r != 0) return NULL; - FheUint8* result = cast_bool_8(bool_result, sks); return result; } void* scalar_le_fhe_uint16(void* ct, uint16_t pt, void* sks) { - FheBool* bool_result = NULL; + FheBool* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint16_scalar_le(ct, pt, &bool_result); + const int r = fhe_uint16_scalar_le(ct, pt, &result); if(r != 0) return NULL; - FheUint16* result = cast_bool_16(bool_result, sks); return result; } void* scalar_le_fhe_uint32(void* ct, uint32_t pt, void* sks) { - FheBool* bool_result = NULL; + FheBool* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint32_scalar_le(ct, pt, &bool_result); + const int r = fhe_uint32_scalar_le(ct, pt, &result); if(r != 0) return NULL; - FheUint32* result = cast_bool_32(bool_result, sks); return result; } void* scalar_le_fhe_uint64(void* ct, uint64_t pt, void* sks) { - FheBool* bool_result = NULL; + FheBool* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint64_scalar_le(ct, pt, &bool_result); + const int r = fhe_uint64_scalar_le(ct, pt, &result); if(r != 0) return NULL; - FheUint64* result = cast_bool_64(bool_result, sks); return result; } void* lt_fhe_uint4(void* ct1, void* ct2, void* sks) { - FheBool* bool_result = NULL; + FheBool* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint4_lt(ct1, ct2, &bool_result); + const int r = fhe_uint4_lt(ct1, ct2, &result); if(r != 0) return NULL; - FheUint4* result = cast_bool_4(bool_result, sks); return result; } void* lt_fhe_uint8(void* ct1, void* ct2, void* sks) { - FheBool* bool_result = NULL; + FheBool* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint8_lt(ct1, ct2, &bool_result); + const int r = fhe_uint8_lt(ct1, ct2, &result); if(r != 0) return NULL; - FheUint8* result = cast_bool_8(bool_result, sks); return result; } void* lt_fhe_uint16(void* ct1, void* ct2, void* sks) { - FheBool* bool_result = NULL; + FheBool* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint16_lt(ct1, ct2, &bool_result); + const int r = fhe_uint16_lt(ct1, ct2, &result); if(r != 0) return NULL; - FheUint16* result = cast_bool_16(bool_result, sks); return result; } void* lt_fhe_uint32(void* ct1, void* ct2, void* sks) { - FheBool* bool_result = NULL; + FheBool* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint32_lt(ct1, ct2, &bool_result); + const int r = fhe_uint32_lt(ct1, ct2, &result); if(r != 0) return NULL; - FheUint32* result = cast_bool_32(bool_result, sks); return result; } void* lt_fhe_uint64(void* ct1, void* ct2, void* sks) { - FheBool* bool_result = NULL; + FheBool* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint64_lt(ct1, ct2, &bool_result); + const int r = fhe_uint64_lt(ct1, ct2, &result); if(r != 0) return NULL; - FheUint64* result = cast_bool_64(bool_result, sks); return result; } void* scalar_lt_fhe_uint4(void* ct, uint8_t pt, void* sks) { - FheBool* bool_result = NULL; + FheBool* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint4_scalar_lt(ct, pt, &bool_result); + const int r = fhe_uint4_scalar_lt(ct, pt, &result); if(r != 0) return NULL; - FheUint4* result = cast_bool_4(bool_result, sks); return result; } void* scalar_lt_fhe_uint8(void* ct, uint8_t pt, void* sks) { - FheBool* bool_result = NULL; + FheBool* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint8_scalar_lt(ct, pt, &bool_result); + const int r = fhe_uint8_scalar_lt(ct, pt, &result); if(r != 0) return NULL; - FheUint8* result = cast_bool_8(bool_result, sks); return result; } void* scalar_lt_fhe_uint16(void* ct, uint16_t pt, void* sks) { - FheBool* bool_result = NULL; + FheBool* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint16_scalar_lt(ct, pt, &bool_result); + const int r = fhe_uint16_scalar_lt(ct, pt, &result); if(r != 0) return NULL; - FheUint16* result = cast_bool_16(bool_result, sks); return result; } void* scalar_lt_fhe_uint32(void* ct, uint32_t pt, void* sks) { - FheBool* bool_result = NULL; + FheBool* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint32_scalar_lt(ct, pt, &bool_result); + const int r = fhe_uint32_scalar_lt(ct, pt, &result); if(r != 0) return NULL; - FheUint32* result = cast_bool_32(bool_result, sks); return result; } void* scalar_lt_fhe_uint64(void* ct, uint64_t pt, void* sks) { - FheBool* bool_result = NULL; + FheBool* result = NULL; checked_set_server_key(sks); - const int r = fhe_uint64_scalar_lt(ct, pt, &bool_result); + const int r = fhe_uint64_scalar_lt(ct, pt, &result); if(r != 0) return NULL; - FheUint64* result = cast_bool_64(bool_result, sks); return result; } diff --git a/proto/kms.proto b/proto/kms.proto index 79a8ce1..dff572e 100644 --- a/proto/kms.proto +++ b/proto/kms.proto @@ -12,10 +12,11 @@ service KmsEndpoint { enum FheType { Bool = 0; - Euint8 = 1; - Euint16 = 2; - Euint32 = 3; - Euint64 = 4; + Euint4 = 1; + Euint8 = 2; + Euint16 = 3; + Euint32 = 4; + Euint64 = 5; } message Proof { From 1e34c3dfb38b465d8a8915ba8a9589aaa8d10a8c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20=27birdy=27=20Danjou?= Date: Tue, 27 Feb 2024 19:30:14 +0100 Subject: [PATCH 4/9] fix: update proto --- kms/kms.pb.go | 53 ++++++++++++++++++++++++++------------------------- 1 file changed, 27 insertions(+), 26 deletions(-) diff --git a/kms/kms.pb.go b/kms/kms.pb.go index a294afb..2fa8c80 100644 --- a/kms/kms.pb.go +++ b/kms/kms.pb.go @@ -437,32 +437,33 @@ var file_kms_proto_rawDesc = []byte{ 0x68, 0x65, 0x72, 0x74, 0x65, 0x78, 0x74, 0x12, 0x27, 0x0a, 0x08, 0x66, 0x68, 0x65, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x0c, 0x2e, 0x6b, 0x6d, 0x73, 0x2e, 0x46, 0x68, 0x65, 0x54, 0x79, 0x70, 0x65, 0x52, 0x07, 0x66, 0x68, 0x65, 0x54, 0x79, 0x70, 0x65, - 0x2a, 0x46, 0x0a, 0x07, 0x46, 0x68, 0x65, 0x54, 0x79, 0x70, 0x65, 0x12, 0x08, 0x0a, 0x04, 0x42, - 0x6f, 0x6f, 0x6c, 0x10, 0x00, 0x12, 0x0a, 0x0a, 0x06, 0x45, 0x75, 0x69, 0x6e, 0x74, 0x38, 0x10, - 0x01, 0x12, 0x0b, 0x0a, 0x07, 0x45, 0x75, 0x69, 0x6e, 0x74, 0x31, 0x36, 0x10, 0x02, 0x12, 0x0b, - 0x0a, 0x07, 0x45, 0x75, 0x69, 0x6e, 0x74, 0x33, 0x32, 0x10, 0x03, 0x12, 0x0b, 0x0a, 0x07, 0x45, - 0x75, 0x69, 0x6e, 0x74, 0x36, 0x34, 0x10, 0x04, 0x32, 0xa3, 0x02, 0x0a, 0x0b, 0x4b, 0x6d, 0x73, - 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x47, 0x0a, 0x14, 0x56, 0x61, 0x6c, 0x69, - 0x64, 0x61, 0x74, 0x65, 0x5f, 0x61, 0x6e, 0x64, 0x5f, 0x64, 0x65, 0x63, 0x72, 0x79, 0x70, 0x74, - 0x12, 0x16, 0x2e, 0x6b, 0x6d, 0x73, 0x2e, 0x44, 0x65, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, - 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x17, 0x2e, 0x6b, 0x6d, 0x73, 0x2e, 0x44, - 0x65, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, - 0x65, 0x12, 0x4d, 0x0a, 0x16, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x5f, 0x61, 0x6e, - 0x64, 0x5f, 0x72, 0x65, 0x65, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x12, 0x18, 0x2e, 0x6b, 0x6d, - 0x73, 0x2e, 0x52, 0x65, 0x65, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, - 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x6b, 0x6d, 0x73, 0x2e, 0x52, 0x65, 0x65, 0x6e, - 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x12, 0x3a, 0x0a, 0x07, 0x44, 0x65, 0x63, 0x72, 0x79, 0x70, 0x74, 0x12, 0x16, 0x2e, 0x6b, 0x6d, - 0x73, 0x2e, 0x44, 0x65, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x71, 0x75, - 0x65, 0x73, 0x74, 0x1a, 0x17, 0x2e, 0x6b, 0x6d, 0x73, 0x2e, 0x44, 0x65, 0x63, 0x72, 0x79, 0x70, - 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x40, 0x0a, 0x09, - 0x52, 0x65, 0x65, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x12, 0x18, 0x2e, 0x6b, 0x6d, 0x73, 0x2e, - 0x52, 0x65, 0x65, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x71, 0x75, - 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x6b, 0x6d, 0x73, 0x2e, 0x52, 0x65, 0x65, 0x6e, 0x63, 0x72, - 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x42, 0x21, - 0x5a, 0x1f, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x7a, 0x61, 0x6d, - 0x61, 0x2d, 0x61, 0x69, 0x2f, 0x66, 0x68, 0x65, 0x76, 0x6d, 0x2d, 0x67, 0x6f, 0x2f, 0x6b, 0x6d, - 0x73, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x2a, 0x52, 0x0a, 0x07, 0x46, 0x68, 0x65, 0x54, 0x79, 0x70, 0x65, 0x12, 0x08, 0x0a, 0x04, 0x42, + 0x6f, 0x6f, 0x6c, 0x10, 0x00, 0x12, 0x0a, 0x0a, 0x06, 0x45, 0x75, 0x69, 0x6e, 0x74, 0x34, 0x10, + 0x01, 0x12, 0x0a, 0x0a, 0x06, 0x45, 0x75, 0x69, 0x6e, 0x74, 0x38, 0x10, 0x02, 0x12, 0x0b, 0x0a, + 0x07, 0x45, 0x75, 0x69, 0x6e, 0x74, 0x31, 0x36, 0x10, 0x03, 0x12, 0x0b, 0x0a, 0x07, 0x45, 0x75, + 0x69, 0x6e, 0x74, 0x33, 0x32, 0x10, 0x04, 0x12, 0x0b, 0x0a, 0x07, 0x45, 0x75, 0x69, 0x6e, 0x74, + 0x36, 0x34, 0x10, 0x05, 0x32, 0xa3, 0x02, 0x0a, 0x0b, 0x4b, 0x6d, 0x73, 0x45, 0x6e, 0x64, 0x70, + 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x47, 0x0a, 0x14, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, + 0x5f, 0x61, 0x6e, 0x64, 0x5f, 0x64, 0x65, 0x63, 0x72, 0x79, 0x70, 0x74, 0x12, 0x16, 0x2e, 0x6b, + 0x6d, 0x73, 0x2e, 0x44, 0x65, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x71, + 0x75, 0x65, 0x73, 0x74, 0x1a, 0x17, 0x2e, 0x6b, 0x6d, 0x73, 0x2e, 0x44, 0x65, 0x63, 0x72, 0x79, + 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x4d, 0x0a, + 0x16, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x5f, 0x61, 0x6e, 0x64, 0x5f, 0x72, 0x65, + 0x65, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x12, 0x18, 0x2e, 0x6b, 0x6d, 0x73, 0x2e, 0x52, 0x65, + 0x65, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x74, 0x1a, 0x19, 0x2e, 0x6b, 0x6d, 0x73, 0x2e, 0x52, 0x65, 0x65, 0x6e, 0x63, 0x72, 0x79, 0x70, + 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x3a, 0x0a, 0x07, + 0x44, 0x65, 0x63, 0x72, 0x79, 0x70, 0x74, 0x12, 0x16, 0x2e, 0x6b, 0x6d, 0x73, 0x2e, 0x44, 0x65, + 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, + 0x17, 0x2e, 0x6b, 0x6d, 0x73, 0x2e, 0x44, 0x65, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x40, 0x0a, 0x09, 0x52, 0x65, 0x65, 0x6e, + 0x63, 0x72, 0x79, 0x70, 0x74, 0x12, 0x18, 0x2e, 0x6b, 0x6d, 0x73, 0x2e, 0x52, 0x65, 0x65, 0x6e, + 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, + 0x19, 0x2e, 0x6b, 0x6d, 0x73, 0x2e, 0x52, 0x65, 0x65, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, + 0x6f, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x42, 0x21, 0x5a, 0x1f, 0x67, 0x69, + 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x7a, 0x61, 0x6d, 0x61, 0x2d, 0x61, 0x69, + 0x2f, 0x66, 0x68, 0x65, 0x76, 0x6d, 0x2d, 0x67, 0x6f, 0x2f, 0x6b, 0x6d, 0x73, 0x62, 0x06, 0x70, + 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( From 500e3e80e3e5c07d220c6c3541c4d09c0503f517 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20=27birdy=27=20Danjou?= Date: Wed, 28 Feb 2024 00:22:27 +0100 Subject: [PATCH 5/9] fix: some fix on type in C --- fhevm/contracts_test.go | 17 +++++++++++++++++ fhevm/tfhe_ciphertext.go | 8 ++++++-- 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/fhevm/contracts_test.go b/fhevm/contracts_test.go index 1bded17..5634df0 100644 --- a/fhevm/contracts_test.go +++ b/fhevm/contracts_test.go @@ -241,6 +241,8 @@ func VerifyCiphertextBadType(t *testing.T, actualType FheUintType, metadataType func TrivialEncrypt(t *testing.T, fheUintType FheUintType) { var value big.Int switch fheUintType { + case FheBool: + value = *big.NewInt(1) case FheUint4: value = *big.NewInt(2) case FheUint8: @@ -1919,6 +1921,9 @@ func FheRem(t *testing.T, fheUintType FheUintType, scalar bool) { func FheBitAnd(t *testing.T, fheUintType FheUintType, scalar bool) { var lhs, rhs uint64 switch fheUintType { + case FheBool: + lhs = 1 + rhs = 0 case FheUint4: lhs = 2 rhs = 1 @@ -3402,6 +3407,10 @@ func TestFheScalarRem64(t *testing.T) { FheRem(t, FheUint64, true) } +func TestFheBitAndBool(t *testing.T) { + FheBitAnd(t, FheBool, false) +} + func TestFheBitAnd8(t *testing.T) { FheBitAnd(t, FheUint8, false) } @@ -3694,6 +3703,10 @@ func TestFheScalarGt64(t *testing.T) { FheGt(t, FheUint64, true) } +func TestFheLe4(t *testing.T) { + FheLe(t, FheUint4, false) +} + func TestFheLe8(t *testing.T) { FheLe(t, FheUint8, false) } @@ -3710,6 +3723,10 @@ func TestFheLe64(t *testing.T) { FheLe(t, FheUint64, false) } +func TestFheScalarLe4(t *testing.T) { + FheLe(t, FheUint4, true) +} + func TestFheScalarLe8(t *testing.T) { FheLe(t, FheUint8, true) } diff --git a/fhevm/tfhe_ciphertext.go b/fhevm/tfhe_ciphertext.go index fab0755..1dad99e 100644 --- a/fhevm/tfhe_ciphertext.go +++ b/fhevm/tfhe_ciphertext.go @@ -27,6 +27,10 @@ const ( func (t FheUintType) String() string { switch t { + case FheBool: + return "fheBool" + case FheUint4: + return "fheUint4" case FheUint8: return "fheUint8" case FheUint16: @@ -635,7 +639,7 @@ func (lhs *TfheCiphertext) executeBinaryCiphertextOperation(rhs *TfheCiphertext, ret := C.serialize_fhe_uint32(res_ptr, res_ser) C.destroy_fhe_uint32(res_ptr) if ret != 0 { - return nil, errors.New("8 bit binary op serialization failed") + return nil, errors.New("32 bit binary op serialization failed") } } res.serialization = C.GoBytes(unsafe.Pointer(res_ser.pointer), C.int(res_ser.length)) @@ -669,7 +673,7 @@ func (lhs *TfheCiphertext) executeBinaryCiphertextOperation(rhs *TfheCiphertext, ret := C.serialize_fhe_uint64(res_ptr, res_ser) C.destroy_fhe_uint64(res_ptr) if ret != 0 { - return nil, errors.New("8 bit binary op serialization failed") + return nil, errors.New("64 bit binary op serialization failed") } } res.serialization = C.GoBytes(unsafe.Pointer(res_ser.pointer), C.int(res_ser.length)) From 8f3983047b3120eb42204ccc01934f60bb0f0479 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20=27birdy=27=20Danjou?= Date: Wed, 28 Feb 2024 01:31:29 +0100 Subject: [PATCH 6/9] fix: return correct type for eq/le/gt/... when in gas estimation --- fhevm/fhelib_run.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/fhevm/fhelib_run.go b/fhevm/fhelib_run.go index 99d2575..3b4de55 100644 --- a/fhevm/fhelib_run.go +++ b/fhevm/fhelib_run.go @@ -246,7 +246,7 @@ func fheLeRun(environment EVMEnvironment, caller common.Address, addr common.Add // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. if !environment.IsCommitting() && !environment.IsEthCall() { - return importRandomCiphertext(environment, lhs.fheUintType()), nil + return importRandomCiphertext(environment, FheBool), nil } result, err := lhs.ciphertext.Le(rhs.ciphertext) @@ -312,7 +312,7 @@ func fheLtRun(environment EVMEnvironment, caller common.Address, addr common.Add // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. if !environment.IsCommitting() && !environment.IsEthCall() { - return importRandomCiphertext(environment, lhs.fheUintType()), nil + return importRandomCiphertext(environment, FheBool), nil } result, err := lhs.ciphertext.Lt(rhs.ciphertext) @@ -378,7 +378,7 @@ func fheEqRun(environment EVMEnvironment, caller common.Address, addr common.Add // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. if !environment.IsCommitting() && !environment.IsEthCall() { - return importRandomCiphertext(environment, lhs.fheUintType()), nil + return importRandomCiphertext(environment, FheBool), nil } result, err := lhs.ciphertext.Eq(rhs.ciphertext) @@ -444,7 +444,7 @@ func fheGeRun(environment EVMEnvironment, caller common.Address, addr common.Add // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. if !environment.IsCommitting() && !environment.IsEthCall() { - return importRandomCiphertext(environment, lhs.fheUintType()), nil + return importRandomCiphertext(environment, FheBool), nil } result, err := lhs.ciphertext.Ge(rhs.ciphertext) @@ -510,7 +510,7 @@ func fheGtRun(environment EVMEnvironment, caller common.Address, addr common.Add // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. if !environment.IsCommitting() && !environment.IsEthCall() { - return importRandomCiphertext(environment, lhs.fheUintType()), nil + return importRandomCiphertext(environment, FheBool), nil } result, err := lhs.ciphertext.Gt(rhs.ciphertext) @@ -708,7 +708,7 @@ func fheNeRun(environment EVMEnvironment, caller common.Address, addr common.Add // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. if !environment.IsCommitting() && !environment.IsEthCall() { - return importRandomCiphertext(environment, lhs.fheUintType()), nil + return importRandomCiphertext(environment, FheBool), nil } result, err := lhs.ciphertext.Ne(rhs.ciphertext) From 35bbf84864451c9f793acf4f297be8741d53aa2d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20=27birdy=27=20Danjou?= Date: Wed, 28 Feb 2024 01:54:14 +0100 Subject: [PATCH 7/9] feat: add not for ebool --- fhevm/contracts_test.go | 7 +++++++ fhevm/tfhe_ciphertext.go | 4 +++- fhevm/tfhe_wrappers.h | 2 ++ 3 files changed, 12 insertions(+), 1 deletion(-) diff --git a/fhevm/contracts_test.go b/fhevm/contracts_test.go index 5634df0..9e7fc84 100644 --- a/fhevm/contracts_test.go +++ b/fhevm/contracts_test.go @@ -1063,6 +1063,9 @@ func FheLibNeg(t *testing.T, fheUintType FheUintType) { func FheLibNot(t *testing.T, fheUintType FheUintType) { var pt, expected uint64 switch fheUintType { + case FheBool: + pt = 1 + expected = 0 case FheUint4: pt = 5 expected = uint64(15 - uint8(pt)) @@ -3078,6 +3081,10 @@ func TestFheLibNeg4(t *testing.T) { FheLibNeg(t, FheUint4) } +func TestFheLibNotBool(t *testing.T) { + FheLibNot(t, FheBool) +} + func TestFheLibNot4(t *testing.T) { FheLibNot(t, FheUint4) } diff --git a/fhevm/tfhe_ciphertext.go b/fhevm/tfhe_ciphertext.go index 1dad99e..18f54d7 100644 --- a/fhevm/tfhe_ciphertext.go +++ b/fhevm/tfhe_ciphertext.go @@ -1689,7 +1689,9 @@ func (lhs *TfheCiphertext) Neg() (*TfheCiphertext, error) { func (lhs *TfheCiphertext) Not() (*TfheCiphertext, error) { return lhs.executeUnaryCiphertextOperation(lhs, - boolUnaryNotSupportedOp, + func(lhs unsafe.Pointer) (unsafe.Pointer, error) { + return C.not_fhe_bool(lhs, sks), nil + }, func(lhs unsafe.Pointer) (unsafe.Pointer, error) { return C.not_fhe_uint4(lhs, sks), nil }, diff --git a/fhevm/tfhe_wrappers.h b/fhevm/tfhe_wrappers.h index 78588a5..3fa8d16 100644 --- a/fhevm/tfhe_wrappers.h +++ b/fhevm/tfhe_wrappers.h @@ -393,6 +393,8 @@ void* neg_fhe_uint32(void* ct, void* sks); void* neg_fhe_uint64(void* ct, void* sks); +void* not_fhe_bool(void* ct, void* sks); + void* not_fhe_uint4(void* ct, void* sks); void* not_fhe_uint8(void* ct, void* sks); From 6fca386e3c26474bcad47c67ad8c842996eb58eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20=27birdy=27=20Danjou?= Date: Wed, 28 Feb 2024 18:48:59 +0100 Subject: [PATCH 8/9] fix: fix return for comparison on gas estimation --- fhevm/fhelib_run.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/fhevm/fhelib_run.go b/fhevm/fhelib_run.go index 3b4de55..fa67f2b 100644 --- a/fhevm/fhelib_run.go +++ b/fhevm/fhelib_run.go @@ -270,7 +270,7 @@ func fheLeRun(environment EVMEnvironment, caller common.Address, addr common.Add // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. if !environment.IsCommitting() && !environment.IsEthCall() { - return importRandomCiphertext(environment, lhs.fheUintType()), nil + return importRandomCiphertext(environment, FheBool), nil } result, err := lhs.ciphertext.ScalarLe(rhs.Uint64()) @@ -336,7 +336,7 @@ func fheLtRun(environment EVMEnvironment, caller common.Address, addr common.Add // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. if !environment.IsCommitting() && !environment.IsEthCall() { - return importRandomCiphertext(environment, lhs.fheUintType()), nil + return importRandomCiphertext(environment, FheBool), nil } result, err := lhs.ciphertext.ScalarLt(rhs.Uint64()) @@ -402,7 +402,7 @@ func fheEqRun(environment EVMEnvironment, caller common.Address, addr common.Add // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. if !environment.IsCommitting() && !environment.IsEthCall() { - return importRandomCiphertext(environment, lhs.fheUintType()), nil + return importRandomCiphertext(environment, FheBool), nil } result, err := lhs.ciphertext.ScalarEq(rhs.Uint64()) @@ -468,7 +468,7 @@ func fheGeRun(environment EVMEnvironment, caller common.Address, addr common.Add // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. if !environment.IsCommitting() && !environment.IsEthCall() { - return importRandomCiphertext(environment, lhs.fheUintType()), nil + return importRandomCiphertext(environment, FheBool), nil } result, err := lhs.ciphertext.ScalarGe(rhs.Uint64()) @@ -534,7 +534,7 @@ func fheGtRun(environment EVMEnvironment, caller common.Address, addr common.Add // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. if !environment.IsCommitting() && !environment.IsEthCall() { - return importRandomCiphertext(environment, lhs.fheUintType()), nil + return importRandomCiphertext(environment, FheBool), nil } result, err := lhs.ciphertext.ScalarGt(rhs.Uint64()) @@ -732,7 +732,7 @@ func fheNeRun(environment EVMEnvironment, caller common.Address, addr common.Add // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. if !environment.IsCommitting() && !environment.IsEthCall() { - return importRandomCiphertext(environment, lhs.fheUintType()), nil + return importRandomCiphertext(environment, FheBool), nil } result, err := lhs.ciphertext.ScalarNe(rhs.Uint64()) From 2f7fc42fa5568a28a6d193d631369e71efe542eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20=27birdy=27=20Danjou?= Date: Thu, 29 Feb 2024 17:51:00 +0100 Subject: [PATCH 9/9] fix: update gas price --- fhevm/params.go | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/fhevm/params.go b/fhevm/params.go index 47e9e70..524edeb 100644 --- a/fhevm/params.go +++ b/fhevm/params.go @@ -67,7 +67,7 @@ type GasCosts struct { func DefaultGasCosts() GasCosts { return GasCosts{ FheAddSub: map[FheUintType]uint64{ - FheUint4: 84000 + AdjustFHEGas, + FheUint4: 60000 + AdjustFHEGas, FheUint8: 84000 + AdjustFHEGas, FheUint16: 123000 + AdjustFHEGas, FheUint32: 152000 + AdjustFHEGas, @@ -81,43 +81,43 @@ func DefaultGasCosts() GasCosts { FheUint64: 500000, }, FheBitwiseOp: map[FheUintType]uint64{ - FheBool: 23000 + AdjustFHEGas, - FheUint4: 24000 + AdjustFHEGas, + FheBool: 16000 + AdjustFHEGas, + FheUint4: 23000 + AdjustFHEGas, FheUint8: 24000 + AdjustFHEGas, FheUint16: 24000 + AdjustFHEGas, FheUint32: 25000 + AdjustFHEGas, FheUint64: 28000 + AdjustFHEGas, }, FheMul: map[FheUintType]uint64{ - FheUint4: 187000 + AdjustFHEGas, + FheUint4: 140000 + AdjustFHEGas, FheUint8: 187000 + AdjustFHEGas, FheUint16: 252000 + AdjustFHEGas, FheUint32: 349000 + AdjustFHEGas, FheUint64: 631000 + AdjustFHEGas, }, FheScalarMul: map[FheUintType]uint64{ - FheUint4: 149000 + AdjustFHEGas, + FheUint4: 110000 + AdjustFHEGas, FheUint8: 149000 + AdjustFHEGas, FheUint16: 198000 + AdjustFHEGas, FheUint32: 254000 + AdjustFHEGas, FheUint64: 346000 + AdjustFHEGas, }, FheScalarDiv: map[FheUintType]uint64{ - FheUint4: 228000 + AdjustFHEGas, + FheUint4: 120000 + AdjustFHEGas, FheUint8: 228000 + AdjustFHEGas, FheUint16: 304000 + AdjustFHEGas, FheUint32: 388000 + AdjustFHEGas, FheUint64: 574000 + AdjustFHEGas, }, FheScalarRem: map[FheUintType]uint64{ - FheUint4: 450000 + AdjustFHEGas, + FheUint4: 250000 + AdjustFHEGas, FheUint8: 450000 + AdjustFHEGas, FheUint16: 612000 + AdjustFHEGas, FheUint32: 795000 + AdjustFHEGas, FheUint64: 1095000 + AdjustFHEGas, }, FheShift: map[FheUintType]uint64{ - FheUint4: 123000 + AdjustFHEGas, + FheUint4: 110000 + AdjustFHEGas, FheUint8: 123000 + AdjustFHEGas, FheUint16: 143000 + AdjustFHEGas, FheUint32: 173000 + AdjustFHEGas, @@ -138,14 +138,14 @@ func DefaultGasCosts() GasCosts { FheUint64: 76000 + AdjustFHEGas, }, FheMinMax: map[FheUintType]uint64{ - FheUint4: 94000 + AdjustFHEGas, + FheUint4: 50000 + AdjustFHEGas, FheUint8: 94000 + AdjustFHEGas, FheUint16: 120000 + AdjustFHEGas, FheUint32: 148000 + AdjustFHEGas, FheUint64: 189000 + AdjustFHEGas, }, FheScalarMinMax: map[FheUintType]uint64{ - FheUint4: 114000 + AdjustFHEGas, + FheUint4: 80000 + AdjustFHEGas, FheUint8: 114000 + AdjustFHEGas, FheUint16: 140000 + AdjustFHEGas, FheUint32: 154000 + AdjustFHEGas, @@ -159,7 +159,7 @@ func DefaultGasCosts() GasCosts { FheUint64: 27000 + AdjustFHEGas, }, FheNeg: map[FheUintType]uint64{ - FheUint4: 79000 + AdjustFHEGas, + FheUint4: 50000 + AdjustFHEGas, FheUint8: 79000 + AdjustFHEGas, FheUint16: 114000 + AdjustFHEGas, FheUint32: 150000 + AdjustFHEGas,