Skip to content

Commit

Permalink
fix: fix ifthenelse types
Browse files Browse the repository at this point in the history
  • Loading branch information
immortal-tofu committed Feb 27, 2024
1 parent 9f8901b commit 183683b
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 36 deletions.
9 changes: 7 additions & 2 deletions fhevm/contracts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1497,7 +1497,7 @@ func FheLibIfThenElse(t *testing.T, fheUintType FheUintType, condition uint64) {
environment.depth = depth
addr := common.Address{}
readOnly := false
firstHash := verifyCiphertextInTestMemory(environment, condition, depth, FheUint8).GetHash()
firstHash := verifyCiphertextInTestMemory(environment, condition, depth, FheBool).GetHash()
secondHash := verifyCiphertextInTestMemory(environment, second, depth, fheUintType).GetHash()
thirdHash := verifyCiphertextInTestMemory(environment, third, depth, fheUintType).GetHash()
input := toLibPrecompileInputNoScalar(signature, firstHash, secondHash, thirdHash)
Expand Down Expand Up @@ -2772,7 +2772,7 @@ func FheIfThenElse(t *testing.T, fheUintType FheUintType, condition uint64) {
environment.depth = depth
addr := common.Address{}
readOnly := false
conditionHash := verifyCiphertextInTestMemory(environment, condition, depth, fheUintType).GetHash()
conditionHash := verifyCiphertextInTestMemory(environment, condition, depth, FheBool).GetHash()
lhsHash := verifyCiphertextInTestMemory(environment, lhs, depth, fheUintType).GetHash()
rhsHash := verifyCiphertextInTestMemory(environment, rhs, depth, fheUintType).GetHash()

Expand Down Expand Up @@ -3840,6 +3840,11 @@ func TestFheNot64(t *testing.T) {
FheNot(t, FheUint64, false)
}

func TestFheIfThenElse4(t *testing.T) {
FheIfThenElse(t, FheUint4, 1)
FheIfThenElse(t, FheUint4, 0)
}

func TestFheIfThenElse8(t *testing.T) {
FheIfThenElse(t, FheUint8, 1)
FheIfThenElse(t, FheUint8, 0)
Expand Down
2 changes: 1 addition & 1 deletion fhevm/fhelib_required_gas.go
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ func fheIfThenElseRequiredGas(environment EVMEnvironment, input []byte) uint64 {
logger.Error("IfThenElse op RequiredGas() inputs not verified", "err", err, "input", hex.EncodeToString(input))
return 0
}
if first.fheUintType() != FheUint8 {
if first.fheUintType() != FheBool {
logger.Error("IfThenElse op RequiredGas() invalid type for condition", "first", first.fheUintType())
return 0
}
Expand Down
32 changes: 16 additions & 16 deletions fhevm/tfhe_ciphertext.go
Original file line number Diff line number Diff line change
Expand Up @@ -659,11 +659,11 @@ func (first *TfheCiphertext) executeTernaryCiphertextOperation(lhs *TfheCipherte
C.destroy_fhe_uint4(lhs_ptr)
return nil, errors.New("4 bit binary op deserialization failed")
}
first_ptr := C.deserialize_fhe_uint4(toDynamicBufferView((first.serialization)))
first_ptr := C.deserialize_fhe_bool(toDynamicBufferView((first.serialization)))
if first_ptr == nil {
C.destroy_fhe_uint4(lhs_ptr)
C.destroy_fhe_uint4(rhs_ptr)
return nil, errors.New("8 bit binary op deserialization failed")
return nil, errors.New("Bool binary op deserialization failed")
}
res_ptr := op4(first_ptr, lhs_ptr, rhs_ptr)
C.destroy_fhe_uint4(lhs_ptr)
Expand All @@ -688,11 +688,11 @@ func (first *TfheCiphertext) executeTernaryCiphertextOperation(lhs *TfheCipherte
C.destroy_fhe_uint8(lhs_ptr)
return nil, errors.New("8 bit binary op deserialization failed")
}
first_ptr := C.deserialize_fhe_uint8(toDynamicBufferView((first.serialization)))
first_ptr := C.deserialize_fhe_bool(toDynamicBufferView((first.serialization)))
if first_ptr == nil {
C.destroy_fhe_uint8(lhs_ptr)
C.destroy_fhe_uint8(rhs_ptr)
return nil, errors.New("8 bit binary op deserialization failed")
return nil, errors.New("Bool binary op deserialization failed")
}
res_ptr := op8(first_ptr, lhs_ptr, rhs_ptr)
C.destroy_fhe_uint8(lhs_ptr)
Expand All @@ -717,11 +717,11 @@ func (first *TfheCiphertext) executeTernaryCiphertextOperation(lhs *TfheCipherte
C.destroy_fhe_uint16(lhs_ptr)
return nil, errors.New("16 bit binary op deserialization failed")
}
first_ptr := C.deserialize_fhe_uint8(toDynamicBufferView((first.serialization)))
first_ptr := C.deserialize_fhe_bool(toDynamicBufferView((first.serialization)))
if first_ptr == nil {
C.destroy_fhe_uint8(lhs_ptr)
C.destroy_fhe_uint8(rhs_ptr)
return nil, errors.New("8 bit binary op deserialization failed")
C.destroy_fhe_uint16(lhs_ptr)
C.destroy_fhe_uint16(rhs_ptr)
return nil, errors.New("Bool binary op deserialization failed")
}
res_ptr := op16(first_ptr, lhs_ptr, rhs_ptr)
C.destroy_fhe_uint16(lhs_ptr)
Expand All @@ -746,11 +746,11 @@ func (first *TfheCiphertext) executeTernaryCiphertextOperation(lhs *TfheCipherte
C.destroy_fhe_uint32(lhs_ptr)
return nil, errors.New("32 bit binary op deserialization failed")
}
first_ptr := C.deserialize_fhe_uint8(toDynamicBufferView((first.serialization)))
first_ptr := C.deserialize_fhe_bool(toDynamicBufferView((first.serialization)))
if first_ptr == nil {
C.destroy_fhe_uint8(lhs_ptr)
C.destroy_fhe_uint8(rhs_ptr)
return nil, errors.New("8 bit binary op deserialization failed")
C.destroy_fhe_uint32(lhs_ptr)
C.destroy_fhe_uint32(rhs_ptr)
return nil, errors.New("Bool binary op deserialization failed")
}
res_ptr := op32(first_ptr, lhs_ptr, rhs_ptr)
C.destroy_fhe_uint32(lhs_ptr)
Expand All @@ -775,11 +775,11 @@ func (first *TfheCiphertext) executeTernaryCiphertextOperation(lhs *TfheCipherte
C.destroy_fhe_uint64(lhs_ptr)
return nil, errors.New("64 bit binary op deserialization failed")
}
first_ptr := C.deserialize_fhe_uint8(toDynamicBufferView((first.serialization)))
first_ptr := C.deserialize_fhe_bool(toDynamicBufferView((first.serialization)))
if first_ptr == nil {
C.destroy_fhe_uint8(lhs_ptr)
C.destroy_fhe_uint8(rhs_ptr)
return nil, errors.New("8 bit binary op deserialization failed")
C.destroy_fhe_uint64(lhs_ptr)
C.destroy_fhe_uint64(rhs_ptr)
return nil, errors.New("Bool binary op deserialization failed")
}
res_ptr := op64(first_ptr, lhs_ptr, rhs_ptr)
C.destroy_fhe_uint64(lhs_ptr)
Expand Down
4 changes: 2 additions & 2 deletions fhevm/tfhe_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1276,9 +1276,9 @@ func TfheIfThenElse(t *testing.T, fheUintType FheUintType) {
b.SetUint64(133337)
}
ctCondition := new(TfheCiphertext)
ctCondition.Encrypt(condition, fheUintType)
ctCondition.Encrypt(condition, FheBool)
ctCondition2 := new(TfheCiphertext)
ctCondition2.Encrypt(condition2, fheUintType)
ctCondition2.Encrypt(condition2, FheBool)
ctA := new(TfheCiphertext)
ctA.Encrypt(a, fheUintType)
ctB := new(TfheCiphertext)
Expand Down
20 changes: 5 additions & 15 deletions fhevm/tfhe_wrappers.c
Original file line number Diff line number Diff line change
Expand Up @@ -2295,9 +2295,7 @@ void* if_then_else_fhe_uint4(void* condition, void* ct1, void* ct2, void* sks)

checked_set_server_key(sks);

FheBool* cond = cast_4_bool(condition, sks);

const int r = fhe_uint4_if_then_else(cond, ct1, ct2, &result);
const int r = fhe_uint4_if_then_else(condition, ct1, ct2, &result);
if(r != 0) return NULL;
return result;
}
Expand All @@ -2308,9 +2306,7 @@ void* if_then_else_fhe_uint8(void* condition, void* ct1, void* ct2, void* sks)

checked_set_server_key(sks);

FheBool* cond = cast_8_bool(condition, sks);

const int r = fhe_uint8_if_then_else(cond, ct1, ct2, &result);
const int r = fhe_uint8_if_then_else(condition, ct1, ct2, &result);
if(r != 0) return NULL;
return result;
}
Expand All @@ -2321,9 +2317,7 @@ void* if_then_else_fhe_uint16(void* condition, void* ct1, void* ct2, void* sks)

checked_set_server_key(sks);

FheBool* cond = cast_8_bool(condition, sks);

const int r = fhe_uint16_if_then_else(cond, ct1, ct2, &result);
const int r = fhe_uint16_if_then_else(condition, ct1, ct2, &result);
if(r != 0) return NULL;
return result;
}
Expand All @@ -2334,9 +2328,7 @@ void* if_then_else_fhe_uint32(void* condition, void* ct1, void* ct2, void* sks)

checked_set_server_key(sks);

FheBool* cond = cast_8_bool(condition, sks);

const int r = fhe_uint32_if_then_else(cond, ct1, ct2, &result);
const int r = fhe_uint32_if_then_else(condition, ct1, ct2, &result);
if(r != 0) return NULL;
return result;
}
Expand All @@ -2347,9 +2339,7 @@ void* if_then_else_fhe_uint64(void* condition, void* ct1, void* ct2, void* sks)

checked_set_server_key(sks);

FheBool* cond = cast_8_bool(condition, sks);

const int r = fhe_uint64_if_then_else(cond, ct1, ct2, &result);
const int r = fhe_uint64_if_then_else(condition, ct1, ct2, &result);
if(r != 0) return NULL;
return result;
}
Expand Down

0 comments on commit 183683b

Please sign in to comment.