Skip to content

Commit

Permalink
repo-sync-2024-08-02T16:42:42+0800 (secretflow#149)
Browse files Browse the repository at this point in the history
  • Loading branch information
shaojian-ant authored Aug 2, 2024
1 parent 31d212f commit b6a08f1
Show file tree
Hide file tree
Showing 27 changed files with 1,917 additions and 4 deletions.
2 changes: 1 addition & 1 deletion heu/algorithms/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

load("@yacl//bazel:yacl.bzl", "yacl_cc_library")
load("@bazel_skylib//lib:subpackages.bzl", "subpackages")
load("@yacl//bazel:yacl.bzl", "yacl_cc_library")

package(default_visibility = ["//visibility:public"])

Expand Down
21 changes: 19 additions & 2 deletions heu/algorithms/common/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,29 @@
# See the License for the specific language governing permissions and
# limitations under the License.

load("@yacl//bazel:yacl.bzl", "yacl_cc_library", "yacl_cc_test")
load("@yacl//bazel:yacl.bzl", "yacl_cc_library")

package(default_visibility = ["//visibility:public"])

yacl_cc_library(
name = "common",
deps = [
":he_assert",
":type_alias",
],
)

yacl_cc_library(
name = "type_alias",
hdrs = ["type_alias.h"],
deps = ["@yacl//yacl/math/mpint"],
deps = [
"@yacl//yacl/math/mpint",
"@yacl//yacl/math/mpint:montgomery_math",
],
)

yacl_cc_library(
name = "he_assert",
hdrs = ["he_assert.h"],
deps = ["@yacl//yacl/base:exception"],
)
23 changes: 23 additions & 0 deletions heu/algorithms/common/he_assert.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// Copyright 20244 Ant Group Co., Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "yacl/base/exception.h"

#ifdef NDEBUG
#define HE_ASSERT(condition, ...) ((void)0)
#else
#define HE_ASSERT(condition, ...) YACL_ENFORCE(condition, __VA_ARGS__)
#endif
4 changes: 4 additions & 0 deletions heu/algorithms/common/type_alias.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,15 @@

#pragma once

#include "yacl/math/mpint/montgomery_math.h"
#include "yacl/math/mpint/mp_int.h"

namespace heu::algos {

using yacl::math::MPInt;
using yacl::math::PrimeType;

using yacl::math::BaseTable;
using yacl::math::MontgomerySpace;

} // namespace heu::algos
67 changes: 67 additions & 0 deletions heu/algorithms/ou/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Copyright 2024 Ant Group Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

load("@yacl//bazel:yacl.bzl", "yacl_cc_library")

package(default_visibility = ["//visibility:public"])

yacl_cc_library(
name = "ou",
srcs = ["he_kit.cc"],
hdrs = ["he_kit.h"],
deps = [
":decryptor",
":encryptor",
":evaluator",
],
alwayslink = 1,
)

yacl_cc_library(
name = "base",
srcs = ["base.cc"],
hdrs = ["base.h"],
deps = [
"//heu/algorithms/common",
"//heu/spi/he/sketches/scalar/phe",
"@yacl//yacl/utils:serializer",
],
)

yacl_cc_library(
name = "encryptor",
srcs = ["encryptor.cc"],
hdrs = ["encryptor.h"],
deps = [
":base",
],
)

yacl_cc_library(
name = "decryptor",
srcs = ["decryptor.cc"],
hdrs = ["decryptor.h"],
deps = [
":base",
],
)

yacl_cc_library(
name = "evaluator",
srcs = ["evaluator.cc"],
hdrs = ["evaluator.h"],
deps = [
":encryptor",
],
)
68 changes: 68 additions & 0 deletions heu/algorithms/ou/base.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// Copyright 2024 Ant Group Co., Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "heu/algorithms/ou/base.h"

namespace heu::algos::ou {

namespace {
size_t kExpUnitBits = 10;
} // namespace

void PublicKey::Init() {
MPInt::InvertMod(capital_g_, n_, &capital_g_inv_);

// make cache table
m_space_ = std::make_shared<MontgomerySpace>(n_);
cg_table_ = std::make_shared<BaseTable>();
cgi_table_ = std::make_shared<BaseTable>();
ch_table_ = std::make_shared<BaseTable>();

m_space_->MakeBaseTable(capital_g_, kExpUnitBits,
PlaintextBound().BitCount() - 1, cg_table_.get());
m_space_->MakeBaseTable(capital_g_inv_, kExpUnitBits,
PlaintextBound().BitCount() - 1, cgi_table_.get());
m_space_->MakeBaseTable(capital_h_, kExpUnitBits,
internal_params::kRandomBits3072, ch_table_.get());
}

Plaintext ItemTool::Clone(const Plaintext &pt) const { return pt; }

Ciphertext ItemTool::Clone(const Ciphertext &ct) const {
return Ciphertext(ct.c_);
}

size_t ItemTool::Serialize(const Plaintext &pt, uint8_t *buf,
size_t buf_len) const {
return pt.Serialize(buf, buf_len);
}

size_t ItemTool::Serialize(const Ciphertext &ct, uint8_t *buf,
size_t buf_len) const {
return ct.c_.Serialize(buf, buf_len);
}

Plaintext ItemTool::DeserializePT(yacl::ByteContainerView buffer) const {
Plaintext pt;
pt.Deserialize(buffer);
return pt;
}

Ciphertext ItemTool::DeserializeCT(yacl::ByteContainerView buffer) const {
Ciphertext ct;
ct.c_.Deserialize(buffer);
return ct;
}

} // namespace heu::algos::ou
156 changes: 156 additions & 0 deletions heu/algorithms/ou/base.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
// Copyright 2024 Ant Group Co., Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "yacl/utils/serializer.h"

#include "heu/algorithms/common/type_alias.h"
#include "heu/spi/he/sketches/common/keys.h"
#include "heu/spi/he/sketches/scalar/item_tool.h"

namespace heu::algos::ou {

namespace internal_params {
inline constexpr size_t kRandomBits1024 = 80;
// Note: Why 110, not 112?
// 110 is divisible by kExpUnitBits, which can improve performance
inline constexpr size_t kRandomBits2048 = 110;
inline constexpr size_t kRandomBits3072 = 128;
} // namespace internal_params

using Plaintext = MPInt;

class Ciphertext {
public:
Ciphertext() = default;

explicit Ciphertext(MPInt c) : c_(std::move(c)) {}

bool operator==(const Ciphertext &other) const { return c_ == other.c_; }

bool operator!=(const Ciphertext &other) const {
return !this->operator==(other);
}

MPInt c_;
};

class SecretKey : public spi::KeySketch<spi::HeKeyType::SecretKey> {
public:
MPInt p_, q_; // primes such that log2(p), log2(q) ~ n_bits / 3
MPInt t_; // a big prime factor of p - 1, i.e., p = t * u + 1.
MPInt gp_inv_; // L(g^{p-1} mod p^2))^{-1} mod p

MPInt p2_; // p^2
MPInt p_half_; // p/2
MPInt n_; // n = p^2 * q

bool operator==(const SecretKey &other) const {
return p_ == other.p_ && q_ == other.q_ && t_ == other.t_ &&
gp_inv_ == other.gp_inv_;
}

bool operator!=(const SecretKey &other) const {
return !this->operator==(other);
}

[[nodiscard]] size_t Serialize(uint8_t *buf, size_t buf_len) const {
return yacl::SerializeVarsTo(buf, buf_len, p_, q_, t_, gp_inv_, p2_,
p_half_, n_);
}

static std::shared_ptr<SecretKey> LoadFrom(yacl::ByteContainerView in) {
auto sk = std::make_shared<SecretKey>();
yacl::DeserializeVarsTo(in, &sk->p_, &sk->q_, &sk->t_, &sk->gp_inv_,
&sk->p2_, &sk->p_half_, &sk->n_);
return sk;
}

std::map<std::string, std::string> ListParams() const override {
return {{"p", p_.ToString()}, {"q", q_.ToString()}};
}
};

class PublicKey : public spi::KeySketch<spi::HeKeyType::PublicKey> {
public:
MPInt n_; // n = p^2 * q
MPInt capital_g_; // G = g^u mod n for some random g \in [0, n)
MPInt capital_h_; // H = g'^{n*u} mod n for some random g' \in [0, n)

MPInt capital_g_inv_; // G^{-1} mod n
MPInt max_plaintext_; // always power of 2, e.g. max_plaintext_ == 2^681

std::shared_ptr<MontgomerySpace> m_space_;
// Cache table of bases (底数缓存表).
// Used to speed up PowMod operations
// The cache tables are relatively large (~10+MB), so place them in heap to
// avoid copying the tables when public key is copied
std::shared_ptr<BaseTable> cg_table_; // Auxiliary array for capital_g_
std::shared_ptr<BaseTable> cgi_table_; // Auxiliary array for capital_g_inv_
std::shared_ptr<BaseTable> ch_table_; // Auxiliary array for capital_h_

void Init();

bool operator==(const PublicKey &other) const {
return n_ == other.n_ && capital_g_ == other.capital_g_ &&
capital_h_ == other.capital_h_;
}

bool operator!=(const PublicKey &other) const {
return !this->operator==(other);
}

// Valid plaintext range: [max_plaintext_, -max_plaintext_]
[[nodiscard]] const MPInt &PlaintextBound() const & { return max_plaintext_; }

[[nodiscard]] size_t Serialize(uint8_t *buf, size_t buf_len) const {
return yacl::SerializeVarsTo(buf, buf_len, n_, capital_g_, capital_h_,
max_plaintext_.BitCount() - 1);
}

static std::shared_ptr<PublicKey> LoadFrom(yacl::ByteContainerView in) {
auto pk = std::make_shared<PublicKey>();
size_t max_bits;
yacl::DeserializeVarsTo(in, &pk->n_, &pk->capital_g_, &pk->capital_h_,
&max_bits);
pk->max_plaintext_ = MPInt(1) << max_bits;
pk->Init();
return pk;
}

std::map<std::string, std::string> ListParams() const override {
return {{"n", n_.ToString()},
{"G", capital_g_.ToString()},
{"H", capital_h_.ToString()},
{"max_plaintext", max_plaintext_.ToString()}};
}
};

class ItemTool : public spi::ItemToolScalarSketch<Plaintext, Ciphertext,
SecretKey, PublicKey> {
public:
Plaintext Clone(const Plaintext &pt) const override;
Ciphertext Clone(const Ciphertext &ct) const override;

size_t Serialize(const Plaintext &pt, uint8_t *buf,
size_t buf_len) const override;
size_t Serialize(const Ciphertext &ct, uint8_t *buf,
size_t buf_len) const override;

Plaintext DeserializePT(yacl::ByteContainerView buffer) const override;
Ciphertext DeserializeCT(yacl::ByteContainerView buffer) const override;
};

} // namespace heu::algos::ou
Loading

0 comments on commit b6a08f1

Please sign in to comment.