Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add paillier algorithm support with denglin's gpu card #112

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
131 changes: 57 additions & 74 deletions heu/library/algorithms/paillier_dl/cgbn_wrapper/cgbn_wrapper.cu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,6 @@ typedef typename env_t::cgbn_t bn_t;
typedef typename env_t::cgbn_local_t bn_local_t;
typedef cgbn_mem_t<BITS> gpu_mpz;

// Host-side debug helper: prints a "[name]" header followed by the digit
// storage of `d` (d->n_.dp) viewed as 32-bit words in hexadecimal.
// The number of words printed is (d->SizeUsed() + 3) / 4.
static void p_mpint(char *name, MPInt *d) {
  printf("[%s]\n", name);

  // View the digit buffer as raw 32-bit words for dumping.
  const uint32_t *words = (uint32_t *)(d->n_.dp);

  int idx = 0;
  while (idx < (d->SizeUsed() + 3) / 4) {
    printf("%08x ", words[idx]);
    ++idx;
  }

  printf("\n");
}

static __device__ void p_cgbn(char *name, cgbn_mem_t<BITS> *d) {
printf("[%s]\n", name);
for (int i=0; i<(sizeof(d->_limbs) + 3) / 4; i++) {
Expand All @@ -51,24 +43,13 @@ static void buf_cal_used(mp_digit *buf, int size, int *used) {
*used = count;
}

// Recomputes out->n_.used from the digit buffer: the result is one past the
// index of the highest non-zero digit among the first out->n_.alloc digits,
// or 0 when every one of those digits is zero.
static void mpint_cal_used(MPInt* out) {
  int highest = 0;
  // Walk down from the top of the allocation; the first non-zero digit found
  // is the most significant one.
  for (int pos = out->n_.alloc; pos > 0; --pos) {
    if (out->n_.dp[pos - 1] != 0) {
      highest = pos;
      break;
    }
  }
  out->n_.used = highest;
}

static void store2dev(dev_mem_t<BITS> *address, MPInt &z, bool handle_neg = true) {
int32_t z_size = z.SizeUsed();
if (z_size > sizeof(address->_limbs)) {
printf("%s:%d No enough memory, need: %d, real: %d\n", __FILE__, __LINE__, z_size, sizeof(address->_limbs));
static void store2dev(dev_mem_t<BITS> *address, MPInt &z) {
auto buffer = z.ToMagBytes(Endian::little);
if (buffer.size() > sizeof(address->_limbs)) {
printf("%s:%d No enough memory, need: %d, real: %d\n", __FILE__, __LINE__, buffer.size(), sizeof(address->_limbs));
abort();
}

auto buffer = z.ToMagBytes(Endian::little);

CUDA_CHECK(cudaMemset(address->_limbs, 0, sizeof(address->_limbs)));
CUDA_CHECK(cudaMemcpy(address->_limbs, buffer.data(), buffer.size(), cudaMemcpyHostToDevice));
}
Expand All @@ -81,7 +62,7 @@ static void store2dev(void *address, SecretKey *sk) {
CUDA_CHECK(cudaMemcpy(address, sk, sizeof(SecretKey), cudaMemcpyHostToDevice));
}

static void store2host(MPInt &z, dev_mem_t<BITS> *address, bool handle_neg = true) {
static void store2host(MPInt &z, dev_mem_t<BITS> *address) {
int32_t z_size = sizeof(address->_limbs);

yacl::Buffer buffer(z_size);
Expand All @@ -90,12 +71,10 @@ static void store2host(MPInt &z, dev_mem_t<BITS> *address, bool handle_neg = tru

int used = 0;
buf_cal_used((mp_digit *)buffer.data(), z_size / sizeof(mp_digit), &used);
mp_sign sign = z.n_.sign;

buffer.resize(used * sizeof(mp_digit));
Endian endian = Endian::little;
z.FromMagBytes(buffer, endian);
z.n_.sign = sign;
}

__device__ __forceinline__ void powmod(env_t &bn_env, env_t::cgbn_t &r, env_t::cgbn_t &a, env_t::cgbn_t &b, env_t::cgbn_t &c) {
Expand Down Expand Up @@ -280,7 +259,7 @@ __global__ __noinline__ void raw_encrypt(PublicKey *pub_key, cgbn_error_report_t
#endif
}

void CGBNWrapper::Encrypt(const std::vector<Plaintext> pts, const PublicKey pk, std::vector<MPInt> &rns, std::vector<Ciphertext> &cts) {
void CGBNWrapper::Encrypt(const std::vector<Plaintext>& pts, const PublicKey& pk, std::vector<MPInt>& rns, std::vector<Ciphertext>& cts) {
int32_t TPB=128;
int32_t IPB=TPB/TPI;
int count = pts.size();
Expand All @@ -296,7 +275,7 @@ void CGBNWrapper::Encrypt(const std::vector<Plaintext> pts, const PublicKey pk,
CUDA_CHECK(cudaMemset(dev_ciphers->_limbs, 0, sizeof(cgbn_mem_t<BITS>) * count));

for (int i=0; i<count; i++) {
store2dev((dev_mem_t<BITS> *)(dev_plains + i), *const_cast<Plaintext *>(&pts[i]), false);
store2dev((dev_mem_t<BITS> *)(dev_plains + i), *const_cast<Plaintext *>(&pts[i]));
}

CUDA_CHECK(cgbn_error_report_alloc(&report));
Expand All @@ -305,7 +284,7 @@ void CGBNWrapper::Encrypt(const std::vector<Plaintext> pts, const PublicKey pk,
CUDA_CHECK(cudaDeviceSynchronize());

for (int i=0; i<count; i++) {
store2host(cts[i].c_, (dev_mem_t<BITS> *)(dev_ciphers + i), false);
store2host(cts[i].c_, (dev_mem_t<BITS> *)(dev_ciphers + i));
}

CGBN_CHECK(report);
Expand Down Expand Up @@ -369,7 +348,7 @@ __global__ void raw_decrypt(SecretKey *priv_key, dev_mem_t<BITS> *pk_n, cgbn_err
#endif
}

void CGBNWrapper::Decrypt(const std::vector<Ciphertext>& cts, const SecretKey sk, const PublicKey pk, std::vector<Plaintext>& pts) {
void CGBNWrapper::Decrypt(const std::vector<Ciphertext>& cts, const SecretKey& sk, const PublicKey& pk, std::vector<Plaintext>& pts) {
int32_t TPB=128;
int32_t IPB=TPB/TPI;
int count = cts.size();
Expand All @@ -386,7 +365,7 @@ void CGBNWrapper::Decrypt(const std::vector<Ciphertext>& cts, const SecretKey sk
CUDA_CHECK(cudaMemset(dev_ciphers->_limbs, 0, sizeof(cgbn_mem_t<BITS>) * count));

for (int i=0; i<count; i++) {
store2dev((dev_mem_t<BITS> *)(dev_ciphers + i), *const_cast<MPInt *>(&cts[i].c_), false);
store2dev((dev_mem_t<BITS> *)(dev_ciphers + i), *const_cast<MPInt *>(&cts[i].c_));
}

CUDA_CHECK(cgbn_error_report_alloc(&report));
Expand All @@ -396,7 +375,7 @@ void CGBNWrapper::Decrypt(const std::vector<Ciphertext>& cts, const SecretKey sk
CGBN_CHECK(report);

for (int i=0; i<count; i++) {
store2host(pts[i], (dev_mem_t<BITS> *)(dev_plains + i), false);
store2host(pts[i], (dev_mem_t<BITS> *)(dev_plains + i));
}

CUDA_CHECK(cgbn_error_report_free(report));
Expand Down Expand Up @@ -438,7 +417,7 @@ cgbn_mont2bn(bn_env, r, r, nsquare, np0);
#endif
}

void CGBNWrapper::Add(const PublicKey pk, const std::vector<Ciphertext>& as, const std::vector<Ciphertext>& bs, std::vector<Ciphertext>& cs) {
void CGBNWrapper::Add(const PublicKey& pk, const std::vector<Ciphertext>& as, const std::vector<Ciphertext>& bs, std::vector<Ciphertext>& cs) {
int32_t TPB=128;
int32_t IPB=TPB/TPI;
int count = as.size();
Expand All @@ -457,8 +436,8 @@ void CGBNWrapper::Add(const PublicKey pk, const std::vector<Ciphertext>& as, con
CUDA_CHECK(cudaMemset(dev_cs->_limbs, 0, sizeof(cgbn_mem_t<BITS>) * count));

for (int i=0; i<count; i++) {
store2dev((dev_mem_t<BITS> *)(dev_as + i), *const_cast<MPInt *>(&as[i].c_), false);
store2dev((dev_mem_t<BITS> *)(dev_bs + i), *const_cast<MPInt *>(&bs[i].c_), false);
store2dev((dev_mem_t<BITS> *)(dev_as + i), *const_cast<MPInt *>(&as[i].c_));
store2dev((dev_mem_t<BITS> *)(dev_bs + i), *const_cast<MPInt *>(&bs[i].c_));
}

CUDA_CHECK(cgbn_error_report_alloc(&report));
Expand All @@ -468,7 +447,7 @@ void CGBNWrapper::Add(const PublicKey pk, const std::vector<Ciphertext>& as, con
CGBN_CHECK(report);

for (int i=0; i<count; i++) {
store2host(cs[i].c_, (dev_mem_t<BITS> *)(dev_cs + i), false);
store2host(cs[i].c_, (dev_mem_t<BITS> *)(dev_cs + i));
}

CUDA_CHECK(cgbn_error_report_free(report));
Expand All @@ -477,7 +456,7 @@ void CGBNWrapper::Add(const PublicKey pk, const std::vector<Ciphertext>& as, con
CUDA_CHECK(cudaFree(dev_cs));
}

void CGBNWrapper::Add(const PublicKey pk, const std::vector<Ciphertext>& as, const std::vector<Plaintext>& bs, std::vector<Ciphertext>& cs) {
void CGBNWrapper::Add(const PublicKey& pk, const std::vector<Ciphertext>& as, const std::vector<Plaintext>& bs, std::vector<Ciphertext>& cs) {
int32_t TPB=128;
int32_t IPB=TPB/TPI;
int count = as.size();
Expand All @@ -499,8 +478,8 @@ void CGBNWrapper::Add(const PublicKey pk, const std::vector<Ciphertext>& as, con
CUDA_CHECK(cudaMemset(dev_cs->_limbs, 0, sizeof(cgbn_mem_t<BITS>) * count));

for (int i=0; i<count; i++) {
store2dev((dev_mem_t<BITS> *)(dev_as + i), *const_cast<MPInt *>(&as[i].c_), false);
store2dev((dev_mem_t<BITS> *)(dev_bs + i), *const_cast<MPInt *>(&bs[i]), true);
store2dev((dev_mem_t<BITS> *)(dev_as + i), *const_cast<MPInt *>(&as[i].c_));
store2dev((dev_mem_t<BITS> *)(dev_bs + i), *const_cast<MPInt *>(&bs[i]));
}

CUDA_CHECK(cgbn_error_report_alloc(&report));
Expand All @@ -510,7 +489,7 @@ void CGBNWrapper::Add(const PublicKey pk, const std::vector<Ciphertext>& as, con
CGBN_CHECK(report);

for (int i=0; i<count; i++) {
store2host(cs[i].c_, (dev_mem_t<BITS> *)(dev_cs + i), false);
store2host(cs[i].c_, (dev_mem_t<BITS> *)(dev_cs + i));
}

CUDA_CHECK(cgbn_error_report_free(report));
Expand Down Expand Up @@ -556,7 +535,7 @@ __global__ void raw_mul(dev_mem_t<BITS> *pk_n, dev_mem_t<BITS> *pk_max_int, dev_
#endif
}

void CGBNWrapper::Mul(const PublicKey pk, const std::vector<Ciphertext>& as, const std::vector<Plaintext>& bs, std::vector<Ciphertext>& cs) {
void CGBNWrapper::Mul(const PublicKey& pk, const std::vector<Ciphertext>& as, const std::vector<Plaintext>& bs, std::vector<Ciphertext>& cs) {
int32_t TPB=128;
int32_t IPB=TPB/TPI;
int count = as.size();
Expand All @@ -575,8 +554,8 @@ void CGBNWrapper::Mul(const PublicKey pk, const std::vector<Ciphertext>& as, con
CUDA_CHECK(cudaMemset(dev_cs->_limbs, 0, sizeof(cgbn_mem_t<BITS>) * count));

for (int i=0; i<count; i++) {
store2dev((dev_mem_t<BITS> *)(dev_as + i), *const_cast<MPInt *>(&as[i].c_), false);
store2dev((dev_mem_t<BITS> *)(dev_bs + i), *const_cast<Plaintext *>(&bs[i]), false);
store2dev((dev_mem_t<BITS> *)(dev_as + i), *const_cast<MPInt *>(&as[i].c_));
store2dev((dev_mem_t<BITS> *)(dev_bs + i), *const_cast<Plaintext *>(&bs[i]));
}

CUDA_CHECK(cgbn_error_report_alloc(&report));
Expand All @@ -586,7 +565,7 @@ void CGBNWrapper::Mul(const PublicKey pk, const std::vector<Ciphertext>& as, con
CGBN_CHECK(report);

for (int i=0; i<count; i++) {
store2host(cs[i].c_, (dev_mem_t<BITS> *)(dev_cs + i), false);
store2host(cs[i].c_, (dev_mem_t<BITS> *)(dev_cs + i));
}

CUDA_CHECK(cgbn_error_report_free(report));
Expand Down Expand Up @@ -616,7 +595,7 @@ __global__ void raw_negate(dev_mem_t<BITS> *pk_nsquare, cgbn_error_report_t *re
#endif
}

void CGBNWrapper::Negate(const PublicKey pk, const std::vector<Ciphertext>& as, std::vector<Ciphertext>& cs) {
void CGBNWrapper::Negate(const PublicKey& pk, const std::vector<Ciphertext>& as, std::vector<Ciphertext>& cs) {

int32_t TPB=128;
int32_t IPB=TPB/TPI;
Expand All @@ -633,7 +612,7 @@ void CGBNWrapper::Negate(const PublicKey pk, const std::vector<Ciphertext>& as,
CUDA_CHECK(cudaMemset(dev_cs->_limbs, 0, sizeof(cgbn_mem_t<BITS>) * count));

for (int i=0; i<count; i++) {
store2dev((dev_mem_t<BITS> *)(dev_as + i), *const_cast<MPInt *>(&as[i].c_), false);
store2dev((dev_mem_t<BITS> *)(dev_as + i), *const_cast<MPInt *>(&as[i].c_));
}

CUDA_CHECK(cgbn_error_report_alloc(&report));
Expand All @@ -643,7 +622,7 @@ void CGBNWrapper::Negate(const PublicKey pk, const std::vector<Ciphertext>& as,
CGBN_CHECK(report);

for (int i=0; i<count; i++) {
store2host(cs[i].c_, (dev_mem_t<BITS> *)(dev_cs + i), false);
store2host(cs[i].c_, (dev_mem_t<BITS> *)(dev_cs + i));
}

CUDA_CHECK(cgbn_error_report_free(report));
Expand All @@ -660,11 +639,13 @@ void CGBNWrapper::DevMalloc(PublicKey *pk) {
}

// Releases the device buffers referenced by a public key.
//
// Each buffer is released exactly once: the previous text freed every pointer
// unconditionally and then freed the same pointers a second time inside the
// dev_pk_ guard. Nothing is released when pk->dev_pk_ is null, and all
// pointers are cleared afterwards so that calling DevFree again on the same
// object does not pass an already-freed address to cudaFree.
//
// pk: key whose dev_* members are released; must not be null.
void CGBNWrapper::DevFree(PublicKey *pk) {
  if (!pk->dev_pk_) {
    return;
  }

  CUDA_CHECK(cudaFree(pk->dev_g_));
  CUDA_CHECK(cudaFree(pk->dev_n_));
  CUDA_CHECK(cudaFree(pk->dev_nsquare_));
  CUDA_CHECK(cudaFree(pk->dev_max_int_));
  CUDA_CHECK(cudaFree(pk->dev_pk_));

  pk->dev_g_ = nullptr;
  pk->dev_n_ = nullptr;
  pk->dev_nsquare_ = nullptr;
  pk->dev_max_int_ = nullptr;
  pk->dev_pk_ = nullptr;
}

void CGBNWrapper::DevMalloc(SecretKey *sk) {
Expand All @@ -680,15 +661,17 @@ void CGBNWrapper::DevMalloc(SecretKey *sk) {
}

// Releases the device buffers referenced by a secret key.
//
// Each buffer is released exactly once: the previous text freed every pointer
// unconditionally and then freed the same pointers a second time inside the
// dev_sk_ guard. Nothing is released when sk->dev_sk_ is null, and all
// pointers are cleared afterwards so that calling DevFree again on the same
// object does not pass an already-freed address to cudaFree.
//
// sk: key whose dev_* members are released; must not be null.
void CGBNWrapper::DevFree(SecretKey *sk) {
  if (!sk->dev_sk_) {
    return;
  }

  CUDA_CHECK(cudaFree(sk->dev_g_));
  CUDA_CHECK(cudaFree(sk->dev_p_));
  CUDA_CHECK(cudaFree(sk->dev_q_));
  CUDA_CHECK(cudaFree(sk->dev_psquare_));
  CUDA_CHECK(cudaFree(sk->dev_qsquare_));
  CUDA_CHECK(cudaFree(sk->dev_q_inverse_));
  CUDA_CHECK(cudaFree(sk->dev_hp_));
  CUDA_CHECK(cudaFree(sk->dev_hq_));
  CUDA_CHECK(cudaFree(sk->dev_sk_));

  sk->dev_g_ = nullptr;
  sk->dev_p_ = nullptr;
  sk->dev_q_ = nullptr;
  sk->dev_psquare_ = nullptr;
  sk->dev_qsquare_ = nullptr;
  sk->dev_q_inverse_ = nullptr;
  sk->dev_hp_ = nullptr;
  sk->dev_hq_ = nullptr;
  sk->dev_sk_ = nullptr;
}


Expand All @@ -698,24 +681,24 @@ void CGBNWrapper::StoreToDev(PublicKey *pk) {
}

// Uploads the host-side secret-key values g_, p_ and q_ into their device
// buffers, then copies the SecretKey object itself to sk->dev_sk_.
//
// Each value is uploaded once; the previous text issued every big-integer
// upload twice (a three-argument call followed by an identical two-argument
// call), and only the second upload determined the device contents.
//
// sk: key to upload; its dev_* buffers must already be allocated.
void CGBNWrapper::StoreToDev(SecretKey *sk) {
  store2dev(sk->dev_g_, sk->g_);
  store2dev(sk->dev_p_, sk->p_);
  store2dev(sk->dev_q_, sk->q_);
  // The struct copy goes last, after the individual values are in place.
  store2dev(sk->dev_sk_, sk);
}

// Reads g, nsquare and max_int back from their device buffers into the
// host-side members of the public key.
//
// Each value is read once; the previous text downloaded every value twice
// (a three-argument call followed by an identical two-argument call), and
// only the second download determined the host value.
//
// pk: key whose g_, nsquare_ and max_int_ members are overwritten.
void CGBNWrapper::StoreToHost(PublicKey *pk) {
  store2host(pk->g_, pk->dev_g_);
  store2host(pk->nsquare_, pk->dev_nsquare_);
  store2host(pk->max_int_, pk->dev_max_int_);
}

// Reads psquare, qsquare, q_inverse, hp and hq back from their device buffers
// into the host-side members of the secret key.
//
// Each value is read once; the previous text downloaded every value twice
// (a three-argument call followed by an identical two-argument call), and
// only the second download determined the host value.
//
// sk: key whose psquare_, qsquare_, q_inverse_, hp_ and hq_ members are
//     overwritten.
void CGBNWrapper::StoreToHost(SecretKey *sk) {
  store2host(sk->psquare_, sk->dev_psquare_);
  store2host(sk->qsquare_, sk->dev_qsquare_);
  store2host(sk->q_inverse_, sk->dev_q_inverse_);
  store2host(sk->hp_, sk->dev_hp_);
  store2host(sk->hq_, sk->dev_hq_);
}

} // namespace heu::lib::algorithms::paillier_dl
} // namespace heu::lib::algorithms::paillier_dl
12 changes: 6 additions & 6 deletions heu/library/algorithms/paillier_dl/cgbn_wrapper/cgbn_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,12 @@ class CGBNWrapper {
public:
static void InitSK(SecretKey *sk);
static void InitPK(PublicKey *pk);
static void Encrypt(const std::vector<Plaintext> pts, const PublicKey pk, std::vector<MPInt> &rns, std::vector<Ciphertext>& cts);
static void Decrypt(const std::vector<Ciphertext>& cts, const SecretKey sk, const PublicKey pk, std::vector<Plaintext>& pts);
static void Add(const PublicKey pk, const std::vector<Ciphertext>& as, const std::vector<Ciphertext>& bs, std::vector<Ciphertext>& cs);
static void Add(const PublicKey pk, const std::vector<Ciphertext>& as, const std::vector<Plaintext>& bs, std::vector<Ciphertext>& cs);
static void Mul(const PublicKey pk, const std::vector<Ciphertext>& as, const std::vector<Plaintext>& bs, std::vector<Ciphertext>& cs);
static void Negate(const PublicKey pk, const std::vector<Ciphertext>& as, std::vector<Ciphertext>& cs);
static void Encrypt(const std::vector<Plaintext>& pts, const PublicKey& pk, std::vector<MPInt>& rns, std::vector<Ciphertext>& cts);
static void Decrypt(const std::vector<Ciphertext>& cts, const SecretKey& sk, const PublicKey& pk, std::vector<Plaintext>& pts);
static void Add(const PublicKey& pk, const std::vector<Ciphertext>& as, const std::vector<Ciphertext>& bs, std::vector<Ciphertext>& cs);
static void Add(const PublicKey& pk, const std::vector<Ciphertext>& as, const std::vector<Plaintext>& bs, std::vector<Ciphertext>& cs);
static void Mul(const PublicKey& pk, const std::vector<Ciphertext>& as, const std::vector<Plaintext>& bs, std::vector<Ciphertext>& cs);
static void Negate(const PublicKey& pk, const std::vector<Ciphertext>& as, std::vector<Ciphertext>& cs);
static void DevMalloc(PublicKey *pk);
static void DevMalloc(SecretKey *sk);
static void DevFree(PublicKey *pk);
Expand Down
6 changes: 0 additions & 6 deletions heu/library/algorithms/paillier_dl/evaluator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,6 @@ std::vector<Ciphertext> Evaluator::Sub(ConstSpan<Ciphertext> as, ConstSpan<Plain
for (int i=0; i<bs.size(); i++) {
Plaintext neg_b;
bs[i]->Negate(&neg_b);
if (neg_b.IsNegative()) {
neg_b += pk_.n_;
}
neg_bs_vec.emplace_back(neg_b);
}
std::vector<Plaintext *> neg_bs_pt;
Expand All @@ -119,9 +116,6 @@ void Evaluator::SubInplace(Span<Ciphertext> as, ConstSpan<Plaintext> bs) const {
for (int i=0; i<bs.size(); i++) {
Plaintext neg_b;
bs[i]->Negate(&neg_b);
if (neg_b.IsNegative()) {
neg_b += pk_.n_;
}
neg_bs_vec.emplace_back(neg_b);
}
std::vector<Plaintext *> neg_bs_pt;
Expand Down
2 changes: 1 addition & 1 deletion heu/library/algorithms/paillier_dl/key_generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ void KeyGenerator::Generate(size_t key_size, SecretKey* sk, PublicKey* pk) {
} while (n.BitCount() < key_size);

MPInt g;
pk->Init(n, g);
pk->Init(n, &g);
sk->Init(g, p, q);
}

Expand Down
Loading
Loading