Skip to content

Commit

Permalink
Merge pull request ethereum#145 from zama-ai/petar/fix-mem-leak
Browse files Browse the repository at this point in the history
Fix memleaks on TFHE encryption and compact deser
  • Loading branch information
dartdart26 authored Jul 13, 2023
2 parents 1c5fa4f + 025d18a commit 7227a73
Showing 1 changed file with 44 additions and 13 deletions.
57 changes: 44 additions & 13 deletions core/vm/tfhe.go
Original file line number Diff line number Diff line change
Expand Up @@ -1304,6 +1304,9 @@ void public_key_encrypt_and_serialize_fhe_uint8_list(void* pks, uint8_t value, B
r = compact_fhe_uint8_list_serialize(list, out);
assert(r == 0);
r = compact_fhe_uint8_list_destroy(list);
assert(r == 0);
}
void public_key_encrypt_and_serialize_fhe_uint16_list(void* pks, uint16_t value, Buffer* out) {
Expand All @@ -1314,6 +1317,9 @@ void public_key_encrypt_and_serialize_fhe_uint16_list(void* pks, uint16_t value,
r = compact_fhe_uint16_list_serialize(list, out);
assert(r == 0);
r = compact_fhe_uint16_list_destroy(list);
assert(r == 0);
}
void public_key_encrypt_and_serialize_fhe_uint32_list(void* pks, uint32_t value, Buffer* out) {
Expand All @@ -1324,6 +1330,9 @@ void public_key_encrypt_and_serialize_fhe_uint32_list(void* pks, uint32_t value,
r = compact_fhe_uint32_list_serialize(list, out);
assert(r == 0);
r = compact_fhe_uint32_list_destroy(list);
assert(r == 0);
}
void* cast_8_16(void* ct, void* sks) {
Expand Down Expand Up @@ -1548,32 +1557,32 @@ func (ct *tfheCiphertext) deserializeCompact(in []byte, t fheUintType) error {
}
var err error
ct.serialization, err = serialize(ptr, t)
C.destroy_fhe_uint8(ptr)
if err != nil {
return err
}
C.destroy_fhe_uint8(ptr)
case FheUint16:
ptr := C.deserialize_compact_fhe_uint16(toBufferView((in)))
if ptr == nil {
return errors.New("compact FheUint16 ciphertext deserialization failed")
}
var err error
ct.serialization, err = serialize(ptr, t)
C.destroy_fhe_uint16(ptr)
if err != nil {
return err
}
C.destroy_fhe_uint16(ptr)
case FheUint32:
ptr := C.deserialize_compact_fhe_uint32(toBufferView((in)))
if ptr == nil {
return errors.New("compact FheUint32 ciphertext deserialization failed")
}
var err error
ct.serialization, err = serialize(ptr, t)
C.destroy_fhe_uint32(ptr)
if err != nil {
return err
}
C.destroy_fhe_uint32(ptr)
default:
panic("deserializeCompact: unexpected ciphertext type")
}
Expand All @@ -1586,43 +1595,65 @@ func (ct *tfheCiphertext) deserializeCompact(in []byte, t fheUintType) error {
// The resulting ciphertext is automaticaly expanded.
func (ct *tfheCiphertext) encrypt(value big.Int, t fheUintType) *tfheCiphertext {
var ptr unsafe.Pointer
var err error
switch t {
case FheUint8:
ptr = C.public_key_encrypt_fhe_uint8(pks, C.uint8_t(value.Uint64()))
ct.serialization, err = serialize(ptr, t)
C.destroy_fhe_uint8(ptr)
if err != nil {
panic(err)
}
case FheUint16:
ptr = C.public_key_encrypt_fhe_uint16(pks, C.uint16_t(value.Uint64()))
ct.serialization, err = serialize(ptr, t)
C.destroy_fhe_uint16(ptr)
if err != nil {
panic(err)
}
case FheUint32:
ptr = C.public_key_encrypt_fhe_uint32(pks, C.uint32_t(value.Uint64()))
ct.serialization, err = serialize(ptr, t)
C.destroy_fhe_uint32(ptr)
if err != nil {
panic(err)
}
default:
panic("encrypt: unexpected ciphertext type")
}
var err error
ct.serialization, err = serialize(ptr, t)
if err != nil {
panic(err)
}
ct.fheUintType = t
ct.computeHash()
return ct
}

func (ct *tfheCiphertext) trivialEncrypt(value big.Int, t fheUintType) *tfheCiphertext {
var ptr unsafe.Pointer
var err error
switch t {
case FheUint8:
ptr = C.trivial_encrypt_fhe_uint8(sks, C.uint8_t(value.Uint64()))
ct.serialization, err = serialize(ptr, t)
C.destroy_fhe_uint8(ptr)
if err != nil {
panic(err)
}
case FheUint16:
ptr = C.trivial_encrypt_fhe_uint16(sks, C.uint16_t(value.Uint64()))
ct.serialization, err = serialize(ptr, t)
C.destroy_fhe_uint16(ptr)
if err != nil {
panic(err)
}
case FheUint32:
ptr = C.trivial_encrypt_fhe_uint32(sks, C.uint32_t(value.Uint64()))
ct.serialization, err = serialize(ptr, t)
C.destroy_fhe_uint32(ptr)
if err != nil {
panic(err)
}
default:
panic("trivialEncrypt: unexpected ciphertext type")
}
var err error
ct.serialization, err = serialize(ptr, ct.fheUintType)
if err != nil {
panic(err)
}
ct.fheUintType = t
ct.computeHash()
return ct
Expand Down

0 comments on commit 7227a73

Please sign in to comment.