From 500d454201539a943a1754de450b4635ed1f2464 Mon Sep 17 00:00:00 2001 From: cxiao129 Date: Tue, 29 Oct 2024 21:26:38 +0800 Subject: [PATCH 01/11] Piano: Extremely Simple, Single-server PIR with Sublinear Server Computation --- psi/piano/BUILD.bazel | 100 ++++++++++ psi/piano/README.md | 28 +++ psi/piano/client.cc | 346 +++++++++++++++++++++++++++++++++++ psi/piano/client.h | 104 +++++++++++ psi/piano/piano.proto | 30 +++ psi/piano/piano_benchmark.cc | 103 +++++++++++ psi/piano/piano_test.cc | 120 ++++++++++++ psi/piano/serialize.h | 99 ++++++++++ psi/piano/server.cc | 103 +++++++++++ psi/piano/server.h | 54 ++++++ psi/piano/util.cc | 169 +++++++++++++++++ psi/piano/util.h | 88 +++++++++ 12 files changed, 1344 insertions(+) create mode 100644 psi/piano/BUILD.bazel create mode 100644 psi/piano/README.md create mode 100644 psi/piano/client.cc create mode 100644 psi/piano/client.h create mode 100644 psi/piano/piano.proto create mode 100644 psi/piano/piano_benchmark.cc create mode 100644 psi/piano/piano_test.cc create mode 100644 psi/piano/serialize.h create mode 100644 psi/piano/server.cc create mode 100644 psi/piano/server.h create mode 100644 psi/piano/util.cc create mode 100644 psi/piano/util.h diff --git a/psi/piano/BUILD.bazel b/psi/piano/BUILD.bazel new file mode 100644 index 00000000..3336de04 --- /dev/null +++ b/psi/piano/BUILD.bazel @@ -0,0 +1,100 @@ +# Copyright 2023 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_cc//cc:defs.bzl", "cc_proto_library") +load("@rules_proto//proto:defs.bzl", "proto_library") +load("//bazel:psi.bzl", "psi_cc_binary", "psi_cc_library", "psi_cc_test") + +package(default_visibility = ["//visibility:public"]) + +proto_library( + name = "piano_proto", + srcs = ["piano.proto"], +) + +cc_proto_library( + name = "piano_cc_proto", + deps = [":piano_proto"], +) + +psi_cc_library( + name = "util", + srcs = ["util.cc"], + hdrs = ["util.h"], + deps = [ + "@yacl//yacl/crypto/aes:aes_intrinsics", + ], +) + +psi_cc_library( + name = "serialize", + srcs = ["serialize.h"], + deps = [ + ":piano_cc_proto", + ":util", + "@yacl//yacl/base:buffer", + ], +) + +psi_cc_library( + name = "server", + srcs = ["server.cc"], + hdrs = ["server.h"], + deps = [ + ":piano_cc_proto", + ":serialize", + ":util", + "@yacl//yacl/link:context", + ], +) + +psi_cc_library( + name = "client", + srcs = ["client.cc"], + hdrs = ["client.h"], + deps = [ + ":piano_cc_proto", + ":serialize", + ":util", + "@yacl//yacl/crypto/rand", + "@yacl//yacl/link:context", + ], +) + +psi_cc_test( + name = "piano_test", + timeout = "eternal", + srcs = ["piano_test.cc"], + deps = [ + ":client", + ":server", + ":util", + "@yacl//yacl/crypto/rand", + "@yacl//yacl/link:context", + "@yacl//yacl/link:test_util", + ], +) + +psi_cc_binary( + name = "piano_benchmark", + srcs = ["piano_benchmark.cc"], + deps = [ + ":client", + ":server", + ":util", + "@com_github_google_benchmark//:benchmark_main", + "@yacl//yacl/link:context", + "@yacl//yacl/link:test_util", + ], +) diff --git a/psi/piano/README.md b/psi/piano/README.md new file mode 100644 index 00000000..e05a3663 --- /dev/null +++ b/psi/piano/README.md @@ -0,0 +1,28 @@ +Piano: Extremely Simple, Single-server PIR with Sublinear Server Computation + +论文地址:https://eprint.iacr.org/2023/452 + +论文开源实现:https://github.com/wuwuz/Piano-PIR-new + +**方案概括** + +1. 采用特定于客户端的预处理模型(也称为订阅模型),让每个客户端在预处理期间下载并存储来自服务器的“提示” + +2. 实现了O(√n)的客户端存储,和O(√n)的在线通信与计算开销(平均到每个查询上) +3. 服务器是诚实且好奇的,恶意服务器也无法侵害隐私,但会导致查询结果出错 +4. 方案包括两个阶段:预处理阶段和在线查询阶段 + +**预处理阶段** + +1. 服务器将数据库划分为O(√n)个块,客户端**流式**的从服务器获取每块数据,并每次只处理当前块中的元素,包括记录部分数据和计算奇偶校验位 +2. 客户端要存储的“提示”包括三类:主表,替换条目和备份表,主表共有O(√n)个,替换条目和备份表在每个块上要存储 + O(1)个 + +**在线查询阶段** +1. 客户端查询包含x的主表,将主表中的x替换为替换条目中该块的下一个可用元素,发送序列给服务器。服务器计算序列的奇偶校验位并返回,客户端在本地通过异或操作恢复出DB[x] +2. 主表使用一次后就要丢弃,使用备份表进行替换,同时为了负载均衡,要保留(x,DB[x]) + +**具体实现** +1. 客户端不存储完整的序列,而是只存储tag,通过msk和PRF扩展出完整的序列 +2. 每个块都有O(1)个备份表,备份表中该块对应的DB[x]没有参与计算奇偶校验位,以实现快速替换 +3. 本地保存查询记录,当有重复查询时在本地查询,并发送随机序列给服务器 \ No newline at end of file diff --git a/psi/piano/client.cc b/psi/piano/client.cc new file mode 100644 index 00000000..c531e6fd --- /dev/null +++ b/psi/piano/client.cc @@ -0,0 +1,346 @@ +#include "psi/piano/client.h" + +namespace psi::piano { + +uint64_t primaryNumParam(const double q, const double chunk_size, + const double target) { + const double k = std::ceil((std::log(2) * target) + std::log(q)); + return static_cast(k) * static_cast(chunk_size); +} + +double FailProbBallIntoBins(const uint64_t ball_num, const uint64_t bin_num, + const uint64_t bin_size) { + const double mean = + static_cast(ball_num) / static_cast(bin_num); + const double c = (static_cast(bin_size) / mean) - 1; + // Chernoff bound exp(-(c^2)/(2+c) * mean) + double t = (mean * (c * c) / (2 + c)) * std::log(2); + t -= std::log2(static_cast(bin_num)); + return t; +} + +QueryServiceClient::QueryServiceClient( + const uint64_t db_size, const uint64_t thread_num, + std::shared_ptr context) + : db_size_(db_size), thread_num_(thread_num), context_(std::move(context)) { + Initialize(); + InitializeLocalSets(); +} + +void QueryServiceClient::Initialize() { + std::mt19937_64 rng(yacl::crypto::FastRandU64()); + + master_key_ = RandKey(rng); + long_key_ = GetLongKey(&master_key_); + + // Q = sqrt(n) * ln(n) + totalQueryNum = + static_cast(std::sqrt(static_cast(db_size_)) * + std::log(static_cast(db_size_))); + + std::tie(chunk_size_, set_size_) = GenParams(db_size_); + + primary_set_num_ = + primaryNumParam(static_cast(totalQueryNum), + static_cast(chunk_size_), FailureProbLog2 + 1); + // if localSetNum is not a multiple of thread_num_ then we need to add some + // padding + primary_set_num_ = + (primary_set_num_ + thread_num_ - 1) / thread_num_ * thread_num_; + + backup_set_num_per_chunk_ = + 3 * static_cast(static_cast(totalQueryNum) / + static_cast(set_size_)); + backup_set_num_per_chunk_ = + (backup_set_num_per_chunk_ + thread_num_ - 1) / thread_num_ * thread_num_; + + // set_size == chunk_number + total_backup_set_num_ = backup_set_num_per_chunk_ * set_size_; +} + +void QueryServiceClient::InitializeLocalSets() { + primary_sets_.clear(); + primary_sets_.reserve(primary_set_num_); + local_backup_sets_.clear(); + local_backup_sets_.reserve(total_backup_set_num_); + local_cache_.clear(); + local_miss_elements_.clear(); + uint32_t tagCounter = 0; + + for (uint64_t j = 0; j < primary_set_num_; j++) { + primary_sets_.emplace_back(tagCounter, ZeroEntry(), 0, false); + tagCounter += 1; + } + + local_backup_set_groups_.clear(); + local_backup_set_groups_.reserve(set_size_); + local_replacement_groups_.clear(); + local_replacement_groups_.reserve(set_size_); + + for (uint64_t i = 0; i < set_size_; i++) { + std::vector> backupSets; + for (uint64_t j = 0; j < backup_set_num_per_chunk_; j++) { + backupSets.emplace_back( + local_backup_sets_[(i * backup_set_num_per_chunk_) + j]); + } + LocalBackupSetGroup backupGroup(0, backupSets); + local_backup_set_groups_.emplace_back(std::move(backupGroup)); + + std::vector indices(backup_set_num_per_chunk_); + std::vector values(backup_set_num_per_chunk_); + LocalReplacementGroup replacementGroup(0, indices, values); + local_replacement_groups_.emplace_back(std::move(replacementGroup)); + } + + for (uint64_t j = 0; j < set_size_; j++) { + for (uint64_t k = 0; k < backup_set_num_per_chunk_; k++) { + local_backup_set_groups_[j].sets[k].get() = + LocalBackupSet{tagCounter, ZeroEntry()}; + tagCounter += 1; + } + } +} + +void QueryServiceClient::FetchFullDB() { + const auto fetchFullDBMsg = SerializeFetchFullDBMsg(1); + context_->SendAsync(context_->NextRank(), fetchFullDBMsg, "FetchFullDBMsg"); + + for (uint64_t i = 0; i < set_size_; i++) { + auto chunkBuf = context_->Recv(context_->NextRank(), "DBChunk"); + if (chunkBuf.size() == 0) { + break; + } + auto [chunkId, chunkSize, chunk] = DeserializeDBChunk(chunkBuf); + + std::vector hitMap(chunk_size_, false); + + // Use multiple threads to parallelize the computation for the chunk + std::vector threads; + std::mutex hitMapMutex; + + // make sure all sets are covered + const uint64_t perTheadSetNum = + ((primary_set_num_ + thread_num_ - 1) / thread_num_) + 1; + const uint64_t perThreadBackupNum = + ((total_backup_set_num_ + thread_num_ - 1) / thread_num_) + 1; + + for (uint64_t tid = 0; tid < thread_num_; tid++) { + uint64_t startIndex = tid * perTheadSetNum; + uint64_t endIndex = + std::min(startIndex + perTheadSetNum, primary_set_num_); + + uint64_t startIndexBackup = tid * perThreadBackupNum; + uint64_t endIndexBackup = std::min(startIndexBackup + perThreadBackupNum, + total_backup_set_num_); + + threads.emplace_back([&, startIndex, endIndex, startIndexBackup, + endIndexBackup] { + // update the parities for the primary hints + for (uint64_t j = startIndex; j < endIndex; j++) { + const auto tmp = + PRFEvalWithLongKeyAndTag(long_key_, primary_sets_[j].tag, i); + const auto offset = tmp & (chunk_size_ - 1); + { + std::lock_guard lock(hitMapMutex); + hitMap[offset] = true; + } + DBEntryXorFromRaw(&primary_sets_[j].parity, + &chunk[offset * DBEntryLength]); + } + + // update the parities for the backup hints + for (uint64_t j = startIndexBackup; j < endIndexBackup; j++) { + const auto tmp = + PRFEvalWithLongKeyAndTag(long_key_, local_backup_sets_[j].tag, i); + const auto offset = tmp & (chunk_size_ - 1); + DBEntryXorFromRaw(&local_backup_sets_[j].parityAfterPunct, + &chunk[offset * DBEntryLength]); + } + }); + } + + for (auto& thread : threads) { + if (thread.joinable()) { + thread.join(); + } + } + + // If any element is not hit, then it is a local miss. We will save it in + // the local miss cache. Most of the time, the local miss cache will be + // empty. + for (uint64_t j = 0; j < chunk_size_; j++) { + if (!hitMap[j]) { + std::array entry_slice{}; + std::memcpy(entry_slice.data(), &chunk[j * DBEntryLength], + DBEntryLength * sizeof(uint64_t)); + const auto entry = DBEntryFromSlice(entry_slice); + local_miss_elements_[j + (i * chunk_size_)] = entry; + } + } + + // For the i-th group of backups, leave the i-th chunk as blank + // To do that, we just xor the i-th chunk's value again + for (uint64_t k = 0; k < backup_set_num_per_chunk_; k++) { + const auto tag = local_backup_set_groups_[i].sets[k].get().tag; + const auto tmp = PRFEvalWithLongKeyAndTag(long_key_, tag, i); + const auto offset = tmp & (chunk_size_ - 1); + DBEntryXorFromRaw( + &local_backup_set_groups_[i].sets[k].get().parityAfterPunct, + &chunk[offset * DBEntryLength]); + } + + // store the replacement + std::mt19937_64 rng(yacl::crypto::FastRandU64()); + for (uint64_t k = 0; k < backup_set_num_per_chunk_; k++) { + // generate a random offset between 0 and ChunkSize - 1 + const auto offset = rng() & (chunk_size_ - 1); + local_replacement_groups_[i].indices[k] = offset + i * chunk_size_; + std::array entry_slice{}; + std::memcpy(entry_slice.data(), &chunk[offset * DBEntryLength], + DBEntryLength * sizeof(uint64_t)); + local_replacement_groups_[i].value[k] = DBEntryFromSlice(entry_slice); + } + } +} + +void QueryServiceClient::SendDummySet() const { + std::mt19937_64 rng(yacl::crypto::FastRandU64()); + std::vector randSet(set_size_); + for (uint64_t i = 0; i < set_size_; i++) { + randSet[i] = rng() % chunk_size_ + i * chunk_size_; + } + + // send the random dummy set to the server + const auto query_msg = SerializeSetParityQueryMsg(set_size_, randSet); + context_->SendAsync(context_->NextRank(), query_msg, "SetParityQueryMsg"); + + const auto response_buf = + context_->Recv(context_->NextRank(), "SetParityQueryResponse"); + auto [parity, server_compute_time] = + DeserializeSetParityQueryResponse(response_buf); +} + +DBEntry QueryServiceClient::OnlineSingleQuery(const uint64_t x) { + // make sure x is not in the local cache + if (local_cache_.find(x) != local_cache_.end()) { + SendDummySet(); + return local_cache_[x]; + } + + // 1. Query x: the client first finds a local set that contains x + // 2. The client expands the set, replace the chunk(x)-th element to a + // replacement + // 3. The client sends the edited set to the server and gets the parity + // 4. The client recovers the answer + uint64_t hitSetId = std::numeric_limits::max(); + + const uint64_t queryOffset = x % chunk_size_; + const uint64_t chunkId = x / chunk_size_; + + for (uint64_t i = 0; i < primary_set_num_; i++) { + const auto& set = primary_sets_[i]; + if (const bool isProgrammedMatch = + set.isProgrammed && chunkId == (set.programmedPoint / chunk_size_); + !isProgrammedMatch && + PRSetWithShortTag{set.tag}.MemberTestWithLongKeyAndTag( + long_key_, chunkId, queryOffset, chunk_size_)) { + hitSetId = i; + break; + } + } + + DBEntry xVal = ZeroEntry(); + + if (hitSetId == std::numeric_limits::max()) { + if (local_miss_elements_.find(x) == local_miss_elements_.end()) { + SPDLOG_ERROR("No hit set found for %lu", x); + } else { + xVal = local_miss_elements_[x]; + local_cache_[x] = xVal; + } + + SendDummySet(); + return xVal; + } + + // expand the set + const PRSetWithShortTag set{primary_sets_[hitSetId].tag}; + auto expandedSet = set.ExpandWithLongKey(long_key_, set_size_, chunk_size_); + + // manually program the set if the flag is set before + if (primary_sets_[hitSetId].isProgrammed) { + const uint64_t programmedChunkId = + primary_sets_[hitSetId].programmedPoint / chunk_size_; + expandedSet[programmedChunkId] = primary_sets_[hitSetId].programmedPoint; + } + + // edit the set by replacing the chunk(x)-th element with a replacement + const uint64_t nxtAvailable = local_replacement_groups_[chunkId].consumed; + if (nxtAvailable == backup_set_num_per_chunk_) { + SPDLOG_ERROR("No replacement available for %lu", x); + SendDummySet(); + return xVal; + } + + // consume one replacement + const uint64_t repIndex = + local_replacement_groups_[chunkId].indices[nxtAvailable]; + const DBEntry repVal = local_replacement_groups_[chunkId].value[nxtAvailable]; + local_replacement_groups_[chunkId].consumed++; + expandedSet[chunkId] = repIndex; + + // send the edited set to the server + const auto query_msg = SerializeSetParityQueryMsg(set_size_, expandedSet); + context_->SendAsync(context_->NextRank(), query_msg, "SetParityQueryMsg"); + + const auto response_buf = + context_->Recv(context_->NextRank(), "SetParityQueryResponse"); + auto [parity, server_compute_time] = + DeserializeSetParityQueryResponse(response_buf); + + // recover the answer + xVal = primary_sets_[hitSetId].parity; // the parity of the hit set + DBEntryXorFromRaw(&xVal, parity.data()); // xor the parity of the edited set + DBEntryXor(&xVal, &repVal); // xor the replacement value + + // update the local cache + local_cache_[x] = xVal; + + // refresh phase + if (local_backup_set_groups_[chunkId].consumed == backup_set_num_per_chunk_) { + SPDLOG_WARN("No backup set available for %lu", x); + return xVal; + } + + const DBEntry originalXVal = xVal; + const uint64_t consumed = local_backup_set_groups_[chunkId].consumed; + primary_sets_[hitSetId].tag = + local_backup_set_groups_[chunkId].sets[consumed].get().tag; + // backup set doesn't XOR the chunk(x)-th element in preparation + DBEntryXor( + &xVal, + &local_backup_set_groups_[chunkId].sets[consumed].get().parityAfterPunct); + primary_sets_[hitSetId].parity = xVal; + primary_sets_[hitSetId].isProgrammed = true; + // for load balancing, the chunk(x)-th element differs from the one expanded + // via PRFEval on the tag + primary_sets_[hitSetId].programmedPoint = x; + local_backup_set_groups_[chunkId].consumed++; + + return originalXVal; +} + +std::vector QueryServiceClient::OnlineMultipleQueries( + const std::vector& queries) { + std::vector results; + results.reserve(queries.size()); + + for (const auto& x : queries) { + DBEntry result = OnlineSingleQuery(x); + results.push_back(result); + } + + return results; +} + +} // namespace psi::piano diff --git a/psi/piano/client.h b/psi/piano/client.h new file mode 100644 index 00000000..b2089b8b --- /dev/null +++ b/psi/piano/client.h @@ -0,0 +1,104 @@ +#pragma once + +#include + +#include +#include +#include +#include + +#include "yacl/crypto/rand/rand.h" +#include "yacl/link/context.h" + +#include "psi/piano/serialize.h" +#include "psi/piano/util.h" + +namespace psi::piano { + +class LocalSet { + public: + uint32_t tag; // the tag of the set + DBEntry parity; + uint64_t + programmedPoint; // identifier for the element replaced after refresh, + // differing from those expanded by PRFEval + bool isProgrammed; + + LocalSet(const uint32_t tag, const DBEntry& parity, + const uint64_t programmed_point, const bool is_programmed) + : tag(tag), + parity(parity), + programmedPoint(programmed_point), + isProgrammed(is_programmed) {} +}; + +class LocalBackupSet { + public: + uint32_t tag; + DBEntry parityAfterPunct; + + LocalBackupSet(const uint32_t tag, const DBEntry& parity_after_punct) + : tag(tag), parityAfterPunct(parity_after_punct) {} +}; + +class LocalBackupSetGroup { + public: + uint64_t consumed; + std::vector> sets; + + LocalBackupSetGroup( + const uint64_t consumed, + const std::vector>& sets) + : consumed(consumed), sets(sets) {} +}; + +class LocalReplacementGroup { + public: + uint64_t consumed; + std::vector indices; + std::vector value; + + LocalReplacementGroup(const uint64_t consumed, + const std::vector& indices, + const std::vector& value) + : consumed(consumed), indices(indices), value(value) {} +}; + +class QueryServiceClient { + public: + static constexpr uint64_t FailureProbLog2 = 40; + uint64_t totalQueryNum{}; + + QueryServiceClient(uint64_t db_size, uint64_t thread_num, + std::shared_ptr context); + + void Initialize(); + void InitializeLocalSets(); + void FetchFullDB(); + void SendDummySet() const; + DBEntry OnlineSingleQuery(uint64_t x); + std::vector OnlineMultipleQueries( + const std::vector& queries); + + private: + uint64_t db_size_; + uint64_t thread_num_; + std::shared_ptr context_; + + uint64_t chunk_size_{}; + uint64_t set_size_{}; + uint64_t primary_set_num_{}; + uint64_t backup_set_num_per_chunk_{}; + uint64_t total_backup_set_num_{}; + PrfKey master_key_{}; + yacl::crypto::AES_KEY long_key_{}; + + std::vector primary_sets_; + std::vector local_backup_sets_; + std::map local_cache_; + std::map local_miss_elements_; + std::vector local_backup_set_groups_; + std::vector local_replacement_groups_; +}; + +} // namespace psi::piano diff --git a/psi/piano/piano.proto b/psi/piano/piano.proto new file mode 100644 index 00000000..aa9cac3f --- /dev/null +++ b/psi/piano/piano.proto @@ -0,0 +1,30 @@ +syntax = "proto3"; + +package psi.piano; + +message FetchFullDbMsg { + uint64 dummy = 1; +} + +message DbChunk { + uint64 chunk_id = 1; + uint64 chunk_size = 2; + repeated uint64 chunks = 3; +} + +message SetParityQueryMsg { + uint64 set_size = 1; + repeated uint64 indices = 2; +} + +message SetParityQueryResponse { + repeated uint64 parity = 1; + uint64 server_compute_time = 2; +} + +message QueryRequest { + oneof request { + FetchFullDbMsg fetch_full_db = 1; + SetParityQueryMsg set_parity_query = 2; + } +} diff --git a/psi/piano/piano_benchmark.cc b/psi/piano/piano_benchmark.cc new file mode 100644 index 00000000..2c16d08d --- /dev/null +++ b/psi/piano/piano_benchmark.cc @@ -0,0 +1,103 @@ +#include +#include +#include +#include +#include +#include + +#include "benchmark/benchmark.h" +#include "yacl/link/context.h" +#include "yacl/link/test_util.h" + +#include "psi/piano/client.h" +#include "psi/piano/server.h" +#include "psi/piano/util.h" + +namespace { + +std::vector GenerateQueries(const uint64_t query_num, + const uint64_t db_size) { + std::vector queries; + queries.reserve(query_num); + + std::mt19937_64 rng(yacl::crypto::FastRandU64()); + for (uint64_t q = 0; q < query_num; ++q) { + queries.push_back(rng() % db_size); + } + + return queries; +} + +std::vector CreateDatabase(const uint64_t db_size, + const uint64_t db_seed) { + const auto [ChunkSize, SetSize] = psi::piano::GenParams(db_size); + std::vector DB; + DB.assign(ChunkSize * SetSize * psi::piano::DBEntryLength, 0); + + for (uint64_t i = 0; i < DB.size() / psi::piano::DBEntryLength; ++i) { + auto entry = psi::piano::GenDBEntry(db_seed, i); + std::memcpy(&DB[i * psi::piano::DBEntryLength], entry.data(), + psi::piano::DBEntryLength * sizeof(uint64_t)); + } + + return DB; +} + +void SetupAndRunServer( + const std::shared_ptr& server_context, + const uint64_t db_size, std::promise& exit_signal, + std::vector& db) { + const auto [ChunkSize, SetSize] = psi::piano::GenParams(db_size); + psi::piano::QueryServiceServer server(db, server_context, SetSize, ChunkSize); + server.Start(exit_signal.get_future()); +} + +std::vector SetupAndRunClient( + const uint64_t db_size, const uint64_t thread_num, + const std::shared_ptr& client_context, + const std::vector& queries) { + psi::piano::QueryServiceClient client(db_size, thread_num, client_context); + client.FetchFullDB(); + return client.OnlineMultipleQueries(queries); +} + +} // namespace + +static void BM_PianoPir(benchmark::State& state) { + for (auto _ : state) { + state.PauseTiming(); + uint64_t db_size = state.range(0) / sizeof(psi::piano::DBEntry); + const uint64_t query_num = state.range(1); + constexpr uint64_t db_seed = 2315127; + uint64_t thread_num = 8; + + constexpr int kWorldSize = 2; + const auto contexts = yacl::link::test::SetupWorld(kWorldSize); + yacl::link::RecvTimeoutGuard guard(contexts[0], 1000000); + + auto db = CreateDatabase(db_size, db_seed); + auto queries = GenerateQueries(query_num, db_size); + + state.ResumeTiming(); + std::promise exitSignal; + auto server_future = + std::async(std::launch::async, SetupAndRunServer, contexts[0], db_size, + std::ref(exitSignal), std::ref(db)); + + auto client_future = + std::async(std::launch::async, SetupAndRunClient, db_size, thread_num, + contexts[1], std::cref(queries)); + auto results = client_future.get(); + + exitSignal.set_value(); + server_future.get(); + } +} + +// [1m, 16m, 64m, 128m] +BENCHMARK(BM_PianoPir) + ->Unit(benchmark::kMillisecond) + ->Args({1 << 20, 1000}) + ->Args({16 << 20, 1000}) + ->Args({64 << 20, 1000}) + ->Args({128 << 20, 1000}); diff --git a/psi/piano/piano_test.cc b/psi/piano/piano_test.cc new file mode 100644 index 00000000..9b9d404b --- /dev/null +++ b/psi/piano/piano_test.cc @@ -0,0 +1,120 @@ +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "gtest/gtest.h" +#include "yacl/link/context.h" +#include "yacl/link/test_util.h" + +#include "psi/piano/client.h" +#include "psi/piano/serialize.h" +#include "psi/piano/server.h" +#include "psi/piano/util.h" + +struct TestParams { + uint64_t db_size; + uint64_t db_seed; + uint64_t thread_num; + uint64_t query_num; + bool is_total_query_num; +}; + +namespace psi::piano { + +std::vector GenerateQueries(const uint64_t query_num, + const uint64_t db_size) { + std::vector queries; + queries.reserve(query_num); + + std::mt19937_64 rng(yacl::crypto::FastRandU64()); + for (uint64_t q = 0; q < query_num; ++q) { + queries.push_back(rng() % db_size); + } + + return queries; +} + +std::vector RunClient(QueryServiceClient& client, + const std::vector& queries) { + client.FetchFullDB(); + return client.OnlineMultipleQueries(queries); +} + +std::vector getResults(const std::vector& queries, + const TestParams& params) { + std::vector expected_results; + expected_results.reserve(queries.size()); + + for (const auto& x : queries) { + expected_results.push_back(GenDBEntry(params.db_seed, x)); + } + + return expected_results; +} + +class PianoTest : public testing::TestWithParam {}; + +TEST_P(PianoTest, Works) { + auto params = GetParam(); + constexpr int kWorldSize = 2; + const auto contexts = yacl::link::test::SetupWorld(kWorldSize); + + SPDLOG_INFO("DB N: %lu, Entry Size %lu Bytes, DB Size %lu MB\n", + params.db_size, DBEntrySize, + params.db_size * DBEntrySize / 1024 / 1024); + + auto [ChunkSize, SetSize] = GenParams(params.db_size); + SPDLOG_INFO("Chunk Size: %lu, Set Size: %lu\n", ChunkSize, SetSize); + + std::vector DB; + DB.assign(ChunkSize * SetSize * DBEntryLength, 0); + SPDLOG_INFO("DB Real N: %lu\n", DB.size()); + + for (uint64_t i = 0; i < DB.size() / DBEntryLength; ++i) { + auto entry = GenDBEntry(params.db_seed, i); + std::memcpy(&DB[i * DBEntryLength], entry.data(), + DBEntryLength * sizeof(uint64_t)); + } + + QueryServiceClient client(params.db_size, params.thread_num, contexts[1]); + + const auto actual_query_num = + params.is_total_query_num ? client.totalQueryNum : params.query_num; + auto queries = GenerateQueries(actual_query_num, DB.size()); + + yacl::link::RecvTimeoutGuard guard(contexts[0], 1000000); + QueryServiceServer server(DB, contexts[0], SetSize, ChunkSize); + + std::promise exitSignal; + std::future futureObj = exitSignal.get_future(); + auto server_future = + std::async(std::launch::async, &QueryServiceServer::Start, + std::ref(server), std::move(futureObj)); + auto client_future = std::async(std::launch::async, RunClient, + std::ref(client), std::cref(queries)); + + auto results = client_future.get(); + auto expected_results = getResults(queries, params); + + for (size_t i = 0; i < results.size(); ++i) { + EXPECT_EQ(results[i], expected_results[i]) + << "Mismatch at index " << queries[i]; + } + + exitSignal.set_value(); + server_future.get(); +} + +// [8m, 128m, 1G] +INSTANTIATE_TEST_SUITE_P( + PianoTestInstances, PianoTest, + ::testing::Values(TestParams{131072, 1211212, 8, 1000, false}, + TestParams{2097152, 6405285, 8, 1000, false}, + TestParams{16777216, 7539870, 16, 1000, false})); +} // namespace psi::piano diff --git a/psi/piano/serialize.h b/psi/piano/serialize.h new file mode 100644 index 00000000..c09d1cbd --- /dev/null +++ b/psi/piano/serialize.h @@ -0,0 +1,99 @@ +#pragma once + +#include +#include +#include + +#include "yacl/base/buffer.h" + +#include "psi/piano/util.h" + +#include "psi/piano/piano.pb.h" + +namespace psi::piano { + +inline yacl::Buffer SerializeFetchFullDBMsg(const uint64_t dummy) { + QueryRequest proto; + FetchFullDbMsg* fetch_full_db_msg = proto.mutable_fetch_full_db(); + fetch_full_db_msg->set_dummy(dummy); + + yacl::Buffer buf(proto.ByteSizeLong()); + proto.SerializeToArray(buf.data(), buf.size()); + + return buf; +} + +inline uint64_t DeserializeFetchFullDBMsg(const yacl::Buffer& buf) { + QueryRequest proto; + proto.ParseFromArray(buf.data(), buf.size()); + return proto.fetch_full_db().dummy(); +} + +inline yacl::Buffer SerializeDBChunk(const uint64_t chunk_id, + const uint64_t chunk_size, + const std::vector& chunk) { + DbChunk proto; + proto.set_chunk_id(chunk_id); + proto.set_chunk_size(chunk_size); + for (const auto& val : chunk) { + proto.add_chunks(val); + } + yacl::Buffer buf(proto.ByteSizeLong()); + proto.SerializeToArray(buf.data(), buf.size()); + return buf; +} + +inline std::tuple> DeserializeDBChunk( + const yacl::Buffer& buf) { + DbChunk proto; + proto.ParseFromArray(buf.data(), buf.size()); + std::vector chunk(proto.chunks().begin(), proto.chunks().end()); + return {proto.chunk_id(), proto.chunk_size(), chunk}; +} + +inline yacl::Buffer SerializeSetParityQueryMsg( + const uint64_t set_size, const std::vector& indices) { + QueryRequest proto; + SetParityQueryMsg* set_parity_query = proto.mutable_set_parity_query(); + set_parity_query->set_set_size(set_size); + for (const auto& index : indices) { + set_parity_query->add_indices(index); + } + + yacl::Buffer buf(proto.ByteSizeLong()); + proto.SerializeToArray(buf.data(), buf.size()); + + return buf; +} + +inline std::pair> DeserializeSetParityQueryMsg( + const yacl::Buffer& buf) { + QueryRequest proto; + proto.ParseFromArray(buf.data(), buf.size()); + const auto& set_parity_query = proto.set_parity_query(); + std::vector indices(set_parity_query.indices().begin(), + set_parity_query.indices().end()); + return {set_parity_query.set_size(), indices}; +} + +inline yacl::Buffer SerializeSetParityQueryResponse( + const std::vector& parity, const uint64_t server_compute_time) { + SetParityQueryResponse proto; + for (const auto& p : parity) { + proto.add_parity(p); + } + proto.set_server_compute_time(server_compute_time); + yacl::Buffer buf(proto.ByteSizeLong()); + proto.SerializeToArray(buf.data(), buf.size()); + return buf; +} + +inline std::pair, uint64_t> +DeserializeSetParityQueryResponse(const yacl::Buffer& buf) { + SetParityQueryResponse proto; + proto.ParseFromArray(buf.data(), buf.size()); + std::vector parity(proto.parity().begin(), proto.parity().end()); + return {parity, proto.server_compute_time()}; +} + +} // namespace psi::piano diff --git a/psi/piano/server.cc b/psi/piano/server.cc new file mode 100644 index 00000000..51926179 --- /dev/null +++ b/psi/piano/server.cc @@ -0,0 +1,103 @@ +#include "psi/piano/server.h" + +namespace psi::piano { + +QueryServiceServer::QueryServiceServer( + std::vector& db, std::shared_ptr context, + const uint64_t set_size, const uint64_t chunk_size) + : db_(std::move(db)), + context_(std::move(context)), + set_size_(set_size), + chunk_size_(chunk_size) {} + +void QueryServiceServer::Start(const std::future& stop_signal) { + while (stop_signal.wait_for(std::chrono::milliseconds(1)) == + std::future_status::timeout) { + auto request_data = context_->Recv(context_->NextRank(), "request_data"); + HandleRequest(request_data); + } +} + +void QueryServiceServer::HandleRequest(const yacl::Buffer& request_data) { + QueryRequest proto; + proto.ParseFromArray(request_data.data(), request_data.size()); + + switch (proto.request_case()) { + case QueryRequest::kFetchFullDb: { + // uint64_t dummy = DeserializeFetchFullDBMsg(request_data); + ProcessFetchFullDB(); + break; + } + case QueryRequest::kSetParityQuery: { + auto [set_size, indices] = DeserializeSetParityQueryMsg(request_data); + auto [parity, server_compute_time] = ProcessSetParityQuery(indices); + const auto response_buf = + SerializeSetParityQueryResponse(parity, server_compute_time); + context_->SendAsync(context_->NextRank(), response_buf, + "SetParityQueryResponse"); + break; + } + default: + SPDLOG_ERROR("Unknown request type."); + } +} + +void QueryServiceServer::ProcessFetchFullDB() { + for (uint64_t i = 0; i < set_size_; ++i) { + const uint64_t down = i * chunk_size_; + uint64_t up = (i + 1) * chunk_size_; + + up = std::min(up, db_.size()); + std::vector chunk(db_.begin() + down * DBEntryLength, + db_.begin() + up * DBEntryLength); + auto chunk_buf = SerializeDBChunk(i, chunk.size(), chunk); + + try { + context_->SendAsync(context_->NextRank(), chunk_buf, "FetchFullDBChunk"); + } catch (const std::exception& e) { + SPDLOG_ERROR("Failed to send a chunk."); + return; + } + } +} + +std::pair, uint64_t> +QueryServiceServer::ProcessSetParityQuery( + const std::vector& indices) { + const auto start = std::chrono::high_resolution_clock::now(); + std::vector parity = HandleSetParityQuery(indices); + const auto end = std::chrono::high_resolution_clock::now(); + const auto duration = + std::chrono::duration_cast(end - start).count(); + return {parity, duration}; +} + +DBEntry QueryServiceServer::DBAccess(const uint64_t id) { + if (id < db_.size()) { + if (id * DBEntryLength + DBEntryLength > db_.size()) { + SPDLOG_ERROR("DBAccess: id {} out of range", id); + } + std::array slice{}; + std::copy(db_.begin() + id * DBEntryLength, + db_.begin() + (id + 1) * DBEntryLength, slice.begin()); + return DBEntryFromSlice(slice); + } + DBEntry ret; + ret.fill(0); + return ret; +} + +std::vector QueryServiceServer::HandleSetParityQuery( + const std::vector& indices) { + DBEntry parity = ZeroEntry(); + for (const auto& index : indices) { + DBEntry entry = DBAccess(index); + DBEntryXor(&parity, &entry); + } + + std::vector ret(DBEntryLength); + std::copy(parity.begin(), parity.end(), ret.begin()); + return ret; +} + +} // namespace psi::piano diff --git a/psi/piano/server.h b/psi/piano/server.h new file mode 100644 index 00000000..9abe93d6 --- /dev/null +++ b/psi/piano/server.h @@ -0,0 +1,54 @@ +#pragma once + +#include + +#include +#include +#include +#include +#include +#include + +#include "yacl/link/context.h" + +#include "psi/piano/serialize.h" +#include "psi/piano/util.h" + +namespace psi::piano { + +class QueryServiceServer { + public: + // Constructor: initializes the server with a database, context, set_size, and + // chunk_size + QueryServiceServer(std::vector& db, + std::shared_ptr context, + uint64_t set_size, uint64_t chunk_size); + + // Starts the server to handle incoming requests + void Start(const std::future& stop_signal); + + // Handles the incoming request based on its type + void HandleRequest(const yacl::Buffer& request_data); + + // Processes a request to fetch the full database + void ProcessFetchFullDB(); + + // Processes a set parity query and returns the parity and server compute time + std::pair, uint64_t> ProcessSetParityQuery( + const std::vector& indices); + + private: + // Accesses the database and returns the corresponding entry + DBEntry DBAccess(uint64_t id); + + // Handles a set parity query and returns the parity + std::vector HandleSetParityQuery( + const std::vector& indices); + + std::vector db_; // The database + std::shared_ptr context_; // The communication context + uint64_t set_size_; // The size of the set + uint64_t chunk_size_; // The size of each chunk +}; + +} // namespace psi::piano diff --git a/psi/piano/util.cc b/psi/piano/util.cc new file mode 100644 index 00000000..6ed5a9c7 --- /dev/null +++ b/psi/piano/util.cc @@ -0,0 +1,169 @@ +#include "psi/piano/util.h" + +namespace psi::piano { + +uint128_t BytesToUint128(const std::string& bytes) { + if (bytes.size() != 16) { + SPDLOG_WARN("Bytes size must be 16 for uint128_t conversion."); + } + + uint128_t result = 0; + std::memcpy(&result, bytes.data(), 16); + return result; +} + +std::string Uint128ToBytes(const uint128_t value) { + std::string bytes(16, 0); + std::memcpy(bytes.data(), &value, 16); + return bytes; +} + +PrfKey128 RandKey128(std::mt19937_64& rng) { + const uint64_t lo = rng(); + const uint64_t hi = rng(); + return yacl::MakeUint128(hi, lo); +} + +PrfKey RandKey(std::mt19937_64& rng) { return RandKey128(rng); } + +uint64_t PRFEval128(const PrfKey128* key, const uint64_t x) { + yacl::crypto::AES_KEY aes_key; + AES_set_encrypt_key(*key, &aes_key); + + const auto src_block = static_cast(x); + std::vector plain_blocks(1); + plain_blocks[0] = src_block; + std::vector cipher_blocks(1); + + AES_ecb_encrypt_blks(aes_key, absl::MakeConstSpan(plain_blocks), + absl::MakeSpan(cipher_blocks)); + return static_cast(cipher_blocks[0]); +} + +uint64_t PRFEval(const PrfKey* key, const uint64_t x) { + return PRFEval128(key, x); +} + +void DBEntryXor(DBEntry* dst, const DBEntry* src) { + for (size_t i = 0; i < DBEntryLength; ++i) { + (*dst)[i] ^= (*src)[i]; + } +} + +void DBEntryXorFromRaw(DBEntry* dst, const uint64_t* src) { + for (size_t i = 0; i < DBEntryLength; ++i) { + (*dst)[i] ^= src[i]; + } +} + +bool EntryIsEqual(const DBEntry& a, const DBEntry& b) { + for (size_t i = 0; i < DBEntryLength; ++i) { + if (a[i] != b[i]) { + return false; + } + } + return true; +} + +DBEntry RandDBEntry(std::mt19937_64& rng) { + DBEntry entry; + for (size_t i = 0; i < DBEntryLength; ++i) { + entry[i] = rng(); + } + return entry; +} + +uint64_t DefaultHash(uint64_t key) { + constexpr uint64_t FNV_offset_basis = 14695981039346656037ULL; + uint64_t hash = FNV_offset_basis; + for (int i = 0; i < 8; ++i) { + constexpr uint64_t FNV_prime = 1099511628211ULL; + const auto byte = static_cast(key & 0xFF); + hash ^= static_cast(byte); + hash *= FNV_prime; + key >>= 8; + } + return hash; +} + +DBEntry GenDBEntry(const uint64_t key, const uint64_t id) { + DBEntry entry; + for (size_t i = 0; i < DBEntryLength; ++i) { + entry[i] = DefaultHash((key ^ id) + i); + } + return entry; +} + +DBEntry ZeroEntry() { + DBEntry entry = {}; + for (size_t i = 0; i < DBEntryLength; ++i) { + entry[i] = 0; + } + return entry; +} + +DBEntry DBEntryFromSlice(const std::array& s) { + DBEntry entry; + for (size_t i = 0; i < DBEntryLength; ++i) { + entry[i] = s[i]; + } + return entry; +} + +// Generate ChunkSize and SetSize +std::pair GenParams(const uint64_t db_size) { + const double targetChunkSize = 2 * std::sqrt(static_cast(db_size)); + uint64_t ChunkSize = 1; + + // Ensure ChunkSize is a power of 2 and not smaller than targetChunkSize + while (ChunkSize < static_cast(targetChunkSize)) { + ChunkSize *= 2; + } + + uint64_t SetSize = (db_size + ChunkSize - 1) / ChunkSize; + // Round up to the next multiple of 4 + SetSize = (SetSize + 3) / 4 * 4; + + return {ChunkSize, SetSize}; +} + +yacl::crypto::AES_KEY GetLongKey(const PrfKey128* key) { + yacl::crypto::AES_KEY aes_key; + AES_set_encrypt_key(*key, &aes_key); + return aes_key; +} + +uint64_t PRFEvalWithLongKeyAndTag(const yacl::crypto::AES_KEY& long_key, + const uint32_t tag, const uint64_t x) { + // Combine tag and x into a 128-bit block by shifting tag to the high 64 bits + const uint128_t src_block = (static_cast(tag) << 64) + x; + std::vector plain_blocks(1); + plain_blocks[0] = src_block; + std::vector cipher_blocks(1); + AES_ecb_encrypt_blks(long_key, absl::MakeConstSpan(plain_blocks), + absl::MakeSpan(cipher_blocks)); + return static_cast(cipher_blocks[0]); +} + +std::vector PRSetWithShortTag::ExpandWithLongKey( + const yacl::crypto::AES_KEY& long_key, const uint64_t set_size, + const uint64_t chunk_size) const { + std::vector expandedSet(set_size); + for (uint64_t i = 0; i < set_size; i++) { + const uint64_t tmp = PRFEvalWithLongKeyAndTag(long_key, Tag, i); + // Get the offset within the chunk + const uint64_t offset = tmp & (chunk_size - 1); + expandedSet[i] = i * chunk_size + offset; + } + return expandedSet; +} + +bool PRSetWithShortTag::MemberTestWithLongKeyAndTag( + const yacl::crypto::AES_KEY& long_key, const uint64_t chunk_id, + const uint64_t offset, const uint64_t chunk_size) const { + // Ensure chunk_size is a power of 2 and compare offsets + return offset == + (PRFEvalWithLongKeyAndTag(long_key, Tag, chunk_id) & (chunk_size - 1)); +} + +} // namespace psi::piano diff --git a/psi/piano/util.h b/psi/piano/util.h new file mode 100644 index 00000000..9158224c --- /dev/null +++ b/psi/piano/util.h @@ -0,0 +1,88 @@ +#pragma once + +#include + +#include +#include +#include +#include +#include + +#include "yacl/crypto/aes/aes_intrinsics.h" + +namespace psi::piano { + +constexpr size_t DBEntrySize = 8; // has to be a multiple of 8 +constexpr size_t DBEntryLength = DBEntrySize / 8; + +using PrfKey128 = uint128_t; +using DBEntry = std::array; +using PrfKey = PrfKey128; + +uint128_t BytesToUint128(const std::string& bytes); + +std::string Uint128ToBytes(uint128_t value); + +// Generates a random 128-bit key using the provided RNG +PrfKey128 RandKey128(std::mt19937_64& rng); + +// Generates a random PRF key +PrfKey RandKey(std::mt19937_64& rng); + +// Evaluates PRF using 128-bit key and returns a 64-bit result +uint64_t PRFEval128(const PrfKey128* key, uint64_t x); + +// Evaluates PRF using a general PrfKey and returns a 64-bit result +uint64_t PRFEval(const PrfKey* key, uint64_t x); + +// XOR two DBEntry structures +void DBEntryXor(DBEntry* dst, const DBEntry* src); + +// XOR a DBEntry with raw uint64_t data +void DBEntryXorFromRaw(DBEntry* dst, const uint64_t* src); + +// Compare two DBEntry structures for equality +bool EntryIsEqual(const DBEntry& a, const DBEntry& b); + +// Generate a random DBEntry using the provided RNG +DBEntry RandDBEntry(std::mt19937_64& rng); + +// Default FNV hash implementation for 64-bit keys +uint64_t DefaultHash(uint64_t key); + +// Generate a DBEntry based on a key and ID +DBEntry GenDBEntry(uint64_t key, uint64_t id); + +// Generate a zero-filled DBEntry +DBEntry ZeroEntry(); + +// Convert a slice (array) into a DBEntry structure +DBEntry DBEntryFromSlice(const std::array& s); + +// Generate parameters for ChunkSize and SetSize +std::pair GenParams(uint64_t db_size); + +// Returns a long key (AES expanded key) for PRF evaluation +yacl::crypto::AES_KEY GetLongKey(const PrfKey128* key); + +// PRF evaluation with a long key and tag, returns a 64-bit result +uint64_t PRFEvalWithLongKeyAndTag(const yacl::crypto::AES_KEY& long_key, + uint32_t tag, uint64_t x); + +class PRSetWithShortTag { + public: + uint32_t Tag; + + // Expands the set with a long key and tag + [[nodiscard]] std::vector ExpandWithLongKey( + const yacl::crypto::AES_KEY& long_key, uint64_t set_size, + uint64_t chunk_size) const; + + // Membership test with a long key and tag, to check if an ID belongs to the + // set + [[nodiscard]] bool MemberTestWithLongKeyAndTag( + const yacl::crypto::AES_KEY& long_key, uint64_t chunk_id, uint64_t offset, + uint64_t chunk_size) const; +}; + +} // namespace psi::piano From 22bab6070862e44d3fb4cc8f36394f3d08556091 Mon Sep 17 00:00:00 2001 From: cxiao129 Date: Tue, 29 Oct 2024 21:40:34 +0800 Subject: [PATCH 02/11] Modify Readme format. --- psi/piano/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/psi/piano/README.md b/psi/piano/README.md index e05a3663..65cc6d99 100644 --- a/psi/piano/README.md +++ b/psi/piano/README.md @@ -25,4 +25,4 @@ Piano: Extremely Simple, Single-server PIR with Sublinear Server Computation **具体实现** 1. 客户端不存储完整的序列,而是只存储tag,通过msk和PRF扩展出完整的序列 2. 每个块都有O(1)个备份表,备份表中该块对应的DB[x]没有参与计算奇偶校验位,以实现快速替换 -3. 本地保存查询记录,当有重复查询时在本地查询,并发送随机序列给服务器 \ No newline at end of file +3. 本地保存查询记录,当有重复查询时在本地查询,并发送随机序列给服务器 From 057d545ce7ab1159d517064b7d39f309b2f48621 Mon Sep 17 00:00:00 2001 From: cxiao129 Date: Tue, 29 Oct 2024 22:14:50 +0800 Subject: [PATCH 03/11] Modify for unittest --- psi/piano/piano_test.cc | 4 ++-- psi/piano/server.cc | 3 +-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/psi/piano/piano_test.cc b/psi/piano/piano_test.cc index 9b9d404b..9fe33643 100644 --- a/psi/piano/piano_test.cc +++ b/psi/piano/piano_test.cc @@ -111,10 +111,10 @@ TEST_P(PianoTest, Works) { server_future.get(); } -// [8m, 128m, 1G] +// [8m, 128m, 256m] INSTANTIATE_TEST_SUITE_P( PianoTestInstances, PianoTest, ::testing::Values(TestParams{131072, 1211212, 8, 1000, false}, TestParams{2097152, 6405285, 8, 1000, false}, - TestParams{16777216, 7539870, 16, 1000, false})); + TestParams{4194304, 7539870, 16, 1000, false})); } // namespace psi::piano diff --git a/psi/piano/server.cc b/psi/piano/server.cc index 51926179..12f75cb5 100644 --- a/psi/piano/server.cc +++ b/psi/piano/server.cc @@ -46,8 +46,7 @@ void QueryServiceServer::ProcessFetchFullDB() { for (uint64_t i = 0; i < set_size_; ++i) { const uint64_t down = i * chunk_size_; uint64_t up = (i + 1) * chunk_size_; - - up = std::min(up, db_.size()); + up = std::min(up, static_cast(db_.size())); std::vector chunk(db_.begin() + down * DBEntryLength, db_.begin() + up * DBEntryLength); auto chunk_buf = SerializeDBChunk(i, chunk.size(), chunk); From cd4cc90673b8ae8f68ec277ca095aec1d748c438 Mon Sep 17 00:00:00 2001 From: cxiao129 Date: Wed, 30 Oct 2024 13:20:29 +0800 Subject: [PATCH 04/11] Modify for macos_ut --- psi/piano/client.cc | 10 ++++++---- psi/piano/server.cc | 4 +++- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/psi/piano/client.cc b/psi/piano/client.cc index c531e6fd..9d8475ac 100644 --- a/psi/piano/client.cc +++ b/psi/piano/client.cc @@ -110,7 +110,8 @@ void QueryServiceClient::FetchFullDB() { if (chunkBuf.size() == 0) { break; } - auto [chunkId, chunkSize, chunk] = DeserializeDBChunk(chunkBuf); + auto dbChunk = DeserializeDBChunk(chunkBuf); + auto& chunk = std::get<2>(dbChunk); std::vector hitMap(chunk_size_, false); @@ -216,8 +217,7 @@ void QueryServiceClient::SendDummySet() const { const auto response_buf = context_->Recv(context_->NextRank(), "SetParityQueryResponse"); - auto [parity, server_compute_time] = - DeserializeSetParityQueryResponse(response_buf); + // auto parityQueryResponse = DeserializeSetParityQueryResponse(response_buf); } DBEntry QueryServiceClient::OnlineSingleQuery(const uint64_t x) { @@ -295,8 +295,10 @@ DBEntry QueryServiceClient::OnlineSingleQuery(const uint64_t x) { const auto response_buf = context_->Recv(context_->NextRank(), "SetParityQueryResponse"); - auto [parity, server_compute_time] = + + const auto parityQueryResponse = DeserializeSetParityQueryResponse(response_buf); + const auto& parity = std::get<0>(parityQueryResponse); // recover the answer xVal = primary_sets_[hitSetId].parity; // the parity of the hit set diff --git a/psi/piano/server.cc b/psi/piano/server.cc index 12f75cb5..2c835296 100644 --- a/psi/piano/server.cc +++ b/psi/piano/server.cc @@ -29,7 +29,9 @@ void QueryServiceServer::HandleRequest(const yacl::Buffer& request_data) { break; } case QueryRequest::kSetParityQuery: { - auto [set_size, indices] = DeserializeSetParityQueryMsg(request_data); + const auto parityQuery = DeserializeSetParityQueryMsg(request_data); + const auto& indices = std::get<1>(parityQuery); + auto [parity, server_compute_time] = ProcessSetParityQuery(indices); const auto response_buf = SerializeSetParityQueryResponse(parity, server_compute_time); From 0f90718559457e1527b57c3c2b38168d31d1aba0 Mon Sep 17 00:00:00 2001 From: cxiao129 Date: Mon, 4 Nov 2024 11:26:13 +0800 Subject: [PATCH 05/11] Move to experimental/pir/piano --- experimental/pir/piano/BUILD.bazel | 100 +++++++ experimental/pir/piano/README.md | 28 ++ experimental/pir/piano/client.cc | 348 ++++++++++++++++++++++ experimental/pir/piano/client.h | 104 +++++++ experimental/pir/piano/piano.proto | 30 ++ experimental/pir/piano/piano_benchmark.cc | 103 +++++++ experimental/pir/piano/piano_test.cc | 120 ++++++++ experimental/pir/piano/serialize.h | 99 ++++++ experimental/pir/piano/server.cc | 104 +++++++ experimental/pir/piano/server.h | 54 ++++ experimental/pir/piano/util.cc | 169 +++++++++++ experimental/pir/piano/util.h | 88 ++++++ 12 files changed, 1347 insertions(+) create mode 100644 experimental/pir/piano/BUILD.bazel create mode 100644 experimental/pir/piano/README.md create mode 100644 experimental/pir/piano/client.cc create mode 100644 experimental/pir/piano/client.h create mode 100644 experimental/pir/piano/piano.proto create mode 100644 experimental/pir/piano/piano_benchmark.cc create mode 100644 experimental/pir/piano/piano_test.cc create mode 100644 experimental/pir/piano/serialize.h create mode 100644 experimental/pir/piano/server.cc create mode 100644 experimental/pir/piano/server.h create mode 100644 experimental/pir/piano/util.cc create mode 100644 experimental/pir/piano/util.h diff --git a/experimental/pir/piano/BUILD.bazel b/experimental/pir/piano/BUILD.bazel new file mode 100644 index 00000000..3336de04 --- /dev/null +++ b/experimental/pir/piano/BUILD.bazel @@ -0,0 +1,100 @@ +# Copyright 2023 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_cc//cc:defs.bzl", "cc_proto_library") +load("@rules_proto//proto:defs.bzl", "proto_library") +load("//bazel:psi.bzl", "psi_cc_binary", "psi_cc_library", "psi_cc_test") + +package(default_visibility = ["//visibility:public"]) + +proto_library( + name = "piano_proto", + srcs = ["piano.proto"], +) + +cc_proto_library( + name = "piano_cc_proto", + deps = [":piano_proto"], +) + +psi_cc_library( + name = "util", + srcs = ["util.cc"], + hdrs = ["util.h"], + deps = [ + "@yacl//yacl/crypto/aes:aes_intrinsics", + ], +) + +psi_cc_library( + name = "serialize", + srcs = ["serialize.h"], + deps = [ + ":piano_cc_proto", + ":util", + "@yacl//yacl/base:buffer", + ], +) + +psi_cc_library( + name = "server", + srcs = ["server.cc"], + hdrs = ["server.h"], + deps = [ + ":piano_cc_proto", + ":serialize", + ":util", + "@yacl//yacl/link:context", + ], +) + +psi_cc_library( + name = "client", + srcs = ["client.cc"], + hdrs = ["client.h"], + deps = [ + ":piano_cc_proto", + ":serialize", + ":util", + "@yacl//yacl/crypto/rand", + "@yacl//yacl/link:context", + ], +) + +psi_cc_test( + name = "piano_test", + timeout = "eternal", + srcs = ["piano_test.cc"], + deps = [ + ":client", + ":server", + ":util", + "@yacl//yacl/crypto/rand", + "@yacl//yacl/link:context", + "@yacl//yacl/link:test_util", + ], +) + +psi_cc_binary( + name = "piano_benchmark", + srcs = ["piano_benchmark.cc"], + deps = [ + ":client", + ":server", + ":util", + "@com_github_google_benchmark//:benchmark_main", + "@yacl//yacl/link:context", + "@yacl//yacl/link:test_util", + ], +) diff --git a/experimental/pir/piano/README.md b/experimental/pir/piano/README.md new file mode 100644 index 00000000..65cc6d99 --- /dev/null +++ b/experimental/pir/piano/README.md @@ -0,0 +1,28 @@ +Piano: Extremely Simple, Single-server PIR with Sublinear Server Computation + +论文地址:https://eprint.iacr.org/2023/452 + +论文开源实现:https://github.com/wuwuz/Piano-PIR-new + +**方案概括** + +1. 采用特定于客户端的预处理模型(也称为订阅模型),让每个客户端在预处理期间下载并存储来自服务器的“提示” + +2. 实现了O(√n)的客户端存储,和O(√n)的在线通信与计算开销(平均到每个查询上) +3. 服务器是诚实且好奇的,恶意服务器也无法侵害隐私,但会导致查询结果出错 +4. 方案包括两个阶段:预处理阶段和在线查询阶段 + +**预处理阶段** + +1. 服务器将数据库划分为O(√n)个块,客户端**流式**的从服务器获取每块数据,并每次只处理当前块中的元素,包括记录部分数据和计算奇偶校验位 +2. 客户端要存储的“提示”包括三类:主表,替换条目和备份表,主表共有O(√n)个,替换条目和备份表在每个块上要存储 + O(1)个 + +**在线查询阶段** +1. 客户端查询包含x的主表,将主表中的x替换为替换条目中该块的下一个可用元素,发送序列给服务器。服务器计算序列的奇偶校验位并返回,客户端在本地通过异或操作恢复出DB[x] +2. 主表使用一次后就要丢弃,使用备份表进行替换,同时为了负载均衡,要保留(x,DB[x]) + +**具体实现** +1. 客户端不存储完整的序列,而是只存储tag,通过msk和PRF扩展出完整的序列 +2. 每个块都有O(1)个备份表,备份表中该块对应的DB[x]没有参与计算奇偶校验位,以实现快速替换 +3. 本地保存查询记录,当有重复查询时在本地查询,并发送随机序列给服务器 diff --git a/experimental/pir/piano/client.cc b/experimental/pir/piano/client.cc new file mode 100644 index 00000000..077df889 --- /dev/null +++ b/experimental/pir/piano/client.cc @@ -0,0 +1,348 @@ +#include "experimental/pir/piano/client.h" + +namespace pir::piano { + +uint64_t primaryNumParam(const double q, const double chunk_size, + const double target) { + const double k = std::ceil((std::log(2) * target) + std::log(q)); + return static_cast(k) * static_cast(chunk_size); +} + +double FailProbBallIntoBins(const uint64_t ball_num, const uint64_t bin_num, + const uint64_t bin_size) { + const double mean = + static_cast(ball_num) / static_cast(bin_num); + const double c = (static_cast(bin_size) / mean) - 1; + // Chernoff bound exp(-(c^2)/(2+c) * mean) + double t = (mean * (c * c) / (2 + c)) * std::log(2); + t -= std::log2(static_cast(bin_num)); + return t; +} + +QueryServiceClient::QueryServiceClient( + const uint64_t db_size, const uint64_t thread_num, + std::shared_ptr context) + : db_size_(db_size), thread_num_(thread_num), context_(std::move(context)) { + Initialize(); + InitializeLocalSets(); +} + +void QueryServiceClient::Initialize() { + std::mt19937_64 rng(yacl::crypto::FastRandU64()); + + master_key_ = RandKey(rng); + long_key_ = GetLongKey(&master_key_); + + // Q = sqrt(n) * ln(n) + totalQueryNum = + static_cast(std::sqrt(static_cast(db_size_)) * + std::log(static_cast(db_size_))); + + std::tie(chunk_size_, set_size_) = GenParams(db_size_); + + primary_set_num_ = + primaryNumParam(static_cast(totalQueryNum), + static_cast(chunk_size_), FailureProbLog2 + 1); + // if localSetNum is not a multiple of thread_num_ then we need to add some + // padding + primary_set_num_ = + (primary_set_num_ + thread_num_ - 1) / thread_num_ * thread_num_; + + backup_set_num_per_chunk_ = + 3 * static_cast(static_cast(totalQueryNum) / + static_cast(set_size_)); + backup_set_num_per_chunk_ = + (backup_set_num_per_chunk_ + thread_num_ - 1) / thread_num_ * thread_num_; + + // set_size == chunk_number + total_backup_set_num_ = backup_set_num_per_chunk_ * set_size_; +} + +void QueryServiceClient::InitializeLocalSets() { + primary_sets_.clear(); + primary_sets_.reserve(primary_set_num_); + local_backup_sets_.clear(); + local_backup_sets_.reserve(total_backup_set_num_); + local_cache_.clear(); + local_miss_elements_.clear(); + uint32_t tagCounter = 0; + + for (uint64_t j = 0; j < primary_set_num_; j++) { + primary_sets_.emplace_back(tagCounter, ZeroEntry(), 0, false); + tagCounter += 1; + } + + local_backup_set_groups_.clear(); + local_backup_set_groups_.reserve(set_size_); + local_replacement_groups_.clear(); + local_replacement_groups_.reserve(set_size_); + + for (uint64_t i = 0; i < set_size_; i++) { + std::vector> backupSets; + for (uint64_t j = 0; j < backup_set_num_per_chunk_; j++) { + backupSets.emplace_back( + local_backup_sets_[(i * backup_set_num_per_chunk_) + j]); + } + LocalBackupSetGroup backupGroup(0, backupSets); + local_backup_set_groups_.emplace_back(std::move(backupGroup)); + + std::vector indices(backup_set_num_per_chunk_); + std::vector values(backup_set_num_per_chunk_); + LocalReplacementGroup replacementGroup(0, indices, values); + local_replacement_groups_.emplace_back(std::move(replacementGroup)); + } + + for (uint64_t j = 0; j < set_size_; j++) { + for (uint64_t k = 0; k < backup_set_num_per_chunk_; k++) { + local_backup_set_groups_[j].sets[k].get() = + LocalBackupSet{tagCounter, ZeroEntry()}; + tagCounter += 1; + } + } +} + +void QueryServiceClient::FetchFullDB() { + const auto fetchFullDBMsg = SerializeFetchFullDBMsg(1); + context_->SendAsync(context_->NextRank(), fetchFullDBMsg, "FetchFullDBMsg"); + + for (uint64_t i = 0; i < set_size_; i++) { + auto chunkBuf = context_->Recv(context_->NextRank(), "DBChunk"); + if (chunkBuf.size() == 0) { + break; + } + auto dbChunk = DeserializeDBChunk(chunkBuf); + auto& chunk = std::get<2>(dbChunk); + + std::vector hitMap(chunk_size_, false); + + // Use multiple threads to parallelize the computation for the chunk + std::vector threads; + std::mutex hitMapMutex; + + // make sure all sets are covered + const uint64_t perTheadSetNum = + ((primary_set_num_ + thread_num_ - 1) / thread_num_) + 1; + const uint64_t perThreadBackupNum = + ((total_backup_set_num_ + thread_num_ - 1) / thread_num_) + 1; + + for (uint64_t tid = 0; tid < thread_num_; tid++) { + uint64_t startIndex = tid * perTheadSetNum; + uint64_t endIndex = + std::min(startIndex + perTheadSetNum, primary_set_num_); + + uint64_t startIndexBackup = tid * perThreadBackupNum; + uint64_t endIndexBackup = std::min(startIndexBackup + perThreadBackupNum, + total_backup_set_num_); + + threads.emplace_back([&, startIndex, endIndex, startIndexBackup, + endIndexBackup] { + // update the parities for the primary hints + for (uint64_t j = startIndex; j < endIndex; j++) { + const auto tmp = + PRFEvalWithLongKeyAndTag(long_key_, primary_sets_[j].tag, i); + const auto offset = tmp & (chunk_size_ - 1); + { + std::lock_guard lock(hitMapMutex); + hitMap[offset] = true; + } + DBEntryXorFromRaw(&primary_sets_[j].parity, + &chunk[offset * DBEntryLength]); + } + + // update the parities for the backup hints + for (uint64_t j = startIndexBackup; j < endIndexBackup; j++) { + const auto tmp = + PRFEvalWithLongKeyAndTag(long_key_, local_backup_sets_[j].tag, i); + const auto offset = tmp & (chunk_size_ - 1); + DBEntryXorFromRaw(&local_backup_sets_[j].parityAfterPunct, + &chunk[offset * DBEntryLength]); + } + }); + } + + for (auto& thread : threads) { + if (thread.joinable()) { + thread.join(); + } + } + + // If any element is not hit, then it is a local miss. We will save it in + // the local miss cache. Most of the time, the local miss cache will be + // empty. + for (uint64_t j = 0; j < chunk_size_; j++) { + if (!hitMap[j]) { + std::array entry_slice{}; + std::memcpy(entry_slice.data(), &chunk[j * DBEntryLength], + DBEntryLength * sizeof(uint64_t)); + const auto entry = DBEntryFromSlice(entry_slice); + local_miss_elements_[j + (i * chunk_size_)] = entry; + } + } + + // For the i-th group of backups, leave the i-th chunk as blank + // To do that, we just xor the i-th chunk's value again + for (uint64_t k = 0; k < backup_set_num_per_chunk_; k++) { + const auto tag = local_backup_set_groups_[i].sets[k].get().tag; + const auto tmp = PRFEvalWithLongKeyAndTag(long_key_, tag, i); + const auto offset = tmp & (chunk_size_ - 1); + DBEntryXorFromRaw( + &local_backup_set_groups_[i].sets[k].get().parityAfterPunct, + &chunk[offset * DBEntryLength]); + } + + // store the replacement + std::mt19937_64 rng(yacl::crypto::FastRandU64()); + for (uint64_t k = 0; k < backup_set_num_per_chunk_; k++) { + // generate a random offset between 0 and ChunkSize - 1 + const auto offset = rng() & (chunk_size_ - 1); + local_replacement_groups_[i].indices[k] = offset + i * chunk_size_; + std::array entry_slice{}; + std::memcpy(entry_slice.data(), &chunk[offset * DBEntryLength], + DBEntryLength * sizeof(uint64_t)); + local_replacement_groups_[i].value[k] = DBEntryFromSlice(entry_slice); + } + } +} + +void QueryServiceClient::SendDummySet() const { + std::mt19937_64 rng(yacl::crypto::FastRandU64()); + std::vector randSet(set_size_); + for (uint64_t i = 0; i < set_size_; i++) { + randSet[i] = rng() % chunk_size_ + i * chunk_size_; + } + + // send the random dummy set to the server + const auto query_msg = SerializeSetParityQueryMsg(set_size_, randSet); + context_->SendAsync(context_->NextRank(), query_msg, "SetParityQueryMsg"); + + const auto response_buf = + context_->Recv(context_->NextRank(), "SetParityQueryResponse"); + // auto parityQueryResponse = DeserializeSetParityQueryResponse(response_buf); +} + +DBEntry QueryServiceClient::OnlineSingleQuery(const uint64_t x) { + // make sure x is not in the local cache + if (local_cache_.find(x) != local_cache_.end()) { + SendDummySet(); + return local_cache_[x]; + } + + // 1. Query x: the client first finds a local set that contains x + // 2. The client expands the set, replace the chunk(x)-th element to a + // replacement + // 3. The client sends the edited set to the server and gets the parity + // 4. The client recovers the answer + uint64_t hitSetId = std::numeric_limits::max(); + + const uint64_t queryOffset = x % chunk_size_; + const uint64_t chunkId = x / chunk_size_; + + for (uint64_t i = 0; i < primary_set_num_; i++) { + const auto& set = primary_sets_[i]; + if (const bool isProgrammedMatch = + set.isProgrammed && chunkId == (set.programmedPoint / chunk_size_); + !isProgrammedMatch && + PRSetWithShortTag{set.tag}.MemberTestWithLongKeyAndTag( + long_key_, chunkId, queryOffset, chunk_size_)) { + hitSetId = i; + break; + } + } + + DBEntry xVal = ZeroEntry(); + + if (hitSetId == std::numeric_limits::max()) { + if (local_miss_elements_.find(x) == local_miss_elements_.end()) { + SPDLOG_ERROR("No hit set found for %lu", x); + } else { + xVal = local_miss_elements_[x]; + local_cache_[x] = xVal; + } + + SendDummySet(); + return xVal; + } + + // expand the set + const PRSetWithShortTag set{primary_sets_[hitSetId].tag}; + auto expandedSet = set.ExpandWithLongKey(long_key_, set_size_, chunk_size_); + + // manually program the set if the flag is set before + if (primary_sets_[hitSetId].isProgrammed) { + const uint64_t programmedChunkId = + primary_sets_[hitSetId].programmedPoint / chunk_size_; + expandedSet[programmedChunkId] = primary_sets_[hitSetId].programmedPoint; + } + + // edit the set by replacing the chunk(x)-th element with a replacement + const uint64_t nxtAvailable = local_replacement_groups_[chunkId].consumed; + if (nxtAvailable == backup_set_num_per_chunk_) { + SPDLOG_ERROR("No replacement available for %lu", x); + SendDummySet(); + return xVal; + } + + // consume one replacement + const uint64_t repIndex = + local_replacement_groups_[chunkId].indices[nxtAvailable]; + const DBEntry repVal = local_replacement_groups_[chunkId].value[nxtAvailable]; + local_replacement_groups_[chunkId].consumed++; + expandedSet[chunkId] = repIndex; + + // send the edited set to the server + const auto query_msg = SerializeSetParityQueryMsg(set_size_, expandedSet); + context_->SendAsync(context_->NextRank(), query_msg, "SetParityQueryMsg"); + + const auto response_buf = + context_->Recv(context_->NextRank(), "SetParityQueryResponse"); + + const auto parityQueryResponse = + DeserializeSetParityQueryResponse(response_buf); + const auto& parity = std::get<0>(parityQueryResponse); + + // recover the answer + xVal = primary_sets_[hitSetId].parity; // the parity of the hit set + DBEntryXorFromRaw(&xVal, parity.data()); // xor the parity of the edited set + DBEntryXor(&xVal, &repVal); // xor the replacement value + + // update the local cache + local_cache_[x] = xVal; + + // refresh phase + if (local_backup_set_groups_[chunkId].consumed == backup_set_num_per_chunk_) { + SPDLOG_WARN("No backup set available for %lu", x); + return xVal; + } + + const DBEntry originalXVal = xVal; + const uint64_t consumed = local_backup_set_groups_[chunkId].consumed; + primary_sets_[hitSetId].tag = + local_backup_set_groups_[chunkId].sets[consumed].get().tag; + // backup set doesn't XOR the chunk(x)-th element in preparation + DBEntryXor( + &xVal, + &local_backup_set_groups_[chunkId].sets[consumed].get().parityAfterPunct); + primary_sets_[hitSetId].parity = xVal; + primary_sets_[hitSetId].isProgrammed = true; + // for load balancing, the chunk(x)-th element differs from the one expanded + // via PRFEval on the tag + primary_sets_[hitSetId].programmedPoint = x; + local_backup_set_groups_[chunkId].consumed++; + + return originalXVal; +} + +std::vector QueryServiceClient::OnlineMultipleQueries( + const std::vector& queries) { + std::vector results; + results.reserve(queries.size()); + + for (const auto& x : queries) { + DBEntry result = OnlineSingleQuery(x); + results.push_back(result); + } + + return results; +} + +} diff --git a/experimental/pir/piano/client.h b/experimental/pir/piano/client.h new file mode 100644 index 00000000..9092f3b2 --- /dev/null +++ b/experimental/pir/piano/client.h @@ -0,0 +1,104 @@ +#pragma once + +#include + +#include +#include +#include +#include + +#include "yacl/crypto/rand/rand.h" +#include "yacl/link/context.h" + +#include "experimental/pir/piano/serialize.h" +#include "experimental/pir/piano/util.h" + +namespace pir::piano { + +class LocalSet { + public: + uint32_t tag; // the tag of the set + DBEntry parity; + uint64_t + programmedPoint; // identifier for the element replaced after refresh, + // differing from those expanded by PRFEval + bool isProgrammed; + + LocalSet(const uint32_t tag, const DBEntry& parity, + const uint64_t programmed_point, const bool is_programmed) + : tag(tag), + parity(parity), + programmedPoint(programmed_point), + isProgrammed(is_programmed) {} +}; + +class LocalBackupSet { + public: + uint32_t tag; + DBEntry parityAfterPunct; + + LocalBackupSet(const uint32_t tag, const DBEntry& parity_after_punct) + : tag(tag), parityAfterPunct(parity_after_punct) {} +}; + +class LocalBackupSetGroup { + public: + uint64_t consumed; + std::vector> sets; + + LocalBackupSetGroup( + const uint64_t consumed, + const std::vector>& sets) + : consumed(consumed), sets(sets) {} +}; + +class LocalReplacementGroup { + public: + uint64_t consumed; + std::vector indices; + std::vector value; + + LocalReplacementGroup(const uint64_t consumed, + const std::vector& indices, + const std::vector& value) + : consumed(consumed), indices(indices), value(value) {} +}; + +class QueryServiceClient { + public: + static constexpr uint64_t FailureProbLog2 = 40; + uint64_t totalQueryNum{}; + + QueryServiceClient(uint64_t db_size, uint64_t thread_num, + std::shared_ptr context); + + void Initialize(); + void InitializeLocalSets(); + void FetchFullDB(); + void SendDummySet() const; + DBEntry OnlineSingleQuery(uint64_t x); + std::vector OnlineMultipleQueries( + const std::vector& queries); + + private: + uint64_t db_size_; + uint64_t thread_num_; + std::shared_ptr context_; + + uint64_t chunk_size_{}; + uint64_t set_size_{}; + uint64_t primary_set_num_{}; + uint64_t backup_set_num_per_chunk_{}; + uint64_t total_backup_set_num_{}; + PrfKey master_key_{}; + yacl::crypto::AES_KEY long_key_{}; + + std::vector primary_sets_; + std::vector local_backup_sets_; + std::map local_cache_; + std::map local_miss_elements_; + std::vector local_backup_set_groups_; + std::vector local_replacement_groups_; +}; + +} diff --git a/experimental/pir/piano/piano.proto b/experimental/pir/piano/piano.proto new file mode 100644 index 00000000..514db0e6 --- /dev/null +++ b/experimental/pir/piano/piano.proto @@ -0,0 +1,30 @@ +syntax = "proto3"; + +package pir.piano; + +message FetchFullDbMsg { + uint64 dummy = 1; +} + +message DbChunk { + uint64 chunk_id = 1; + uint64 chunk_size = 2; + repeated uint64 chunks = 3; +} + +message SetParityQueryMsg { + uint64 set_size = 1; + repeated uint64 indices = 2; +} + +message SetParityQueryResponse { + repeated uint64 parity = 1; + uint64 server_compute_time = 2; +} + +message QueryRequest { + oneof request { + FetchFullDbMsg fetch_full_db = 1; + SetParityQueryMsg set_parity_query = 2; + } +} diff --git a/experimental/pir/piano/piano_benchmark.cc b/experimental/pir/piano/piano_benchmark.cc new file mode 100644 index 00000000..d82741b7 --- /dev/null +++ b/experimental/pir/piano/piano_benchmark.cc @@ -0,0 +1,103 @@ +#include +#include +#include +#include +#include +#include + +#include "benchmark/benchmark.h" +#include "yacl/link/context.h" +#include "yacl/link/test_util.h" + +#include "experimental/pir/piano/client.h" +#include "experimental/pir/piano/server.h" +#include "experimental/pir/piano/util.h" + +namespace { + +std::vector GenerateQueries(const uint64_t query_num, + const uint64_t db_size) { + std::vector queries; + queries.reserve(query_num); + + std::mt19937_64 rng(yacl::crypto::FastRandU64()); + for (uint64_t q = 0; q < query_num; ++q) { + queries.push_back(rng() % db_size); + } + + return queries; +} + +std::vector CreateDatabase(const uint64_t db_size, + const uint64_t db_seed) { + const auto [ChunkSize, SetSize] = pir::piano::GenParams(db_size); + std::vector DB; + DB.assign(ChunkSize * SetSize * pir::piano::DBEntryLength, 0); + + for (uint64_t i = 0; i < DB.size() / pir::piano::DBEntryLength; ++i) { + auto entry = pir::piano::GenDBEntry(db_seed, i); + std::memcpy(&DB[i * pir::piano::DBEntryLength], entry.data(), + pir::piano::DBEntryLength * sizeof(uint64_t)); + } + + return DB; +} + +void SetupAndRunServer( + const std::shared_ptr& server_context, + const uint64_t db_size, std::promise& exit_signal, + std::vector& db) { + const auto [ChunkSize, SetSize] = pir::piano::GenParams(db_size); + pir::piano::QueryServiceServer server(db, server_context, SetSize, ChunkSize); + server.Start(exit_signal.get_future()); +} + +std::vector SetupAndRunClient( + const uint64_t db_size, const uint64_t thread_num, + const std::shared_ptr& client_context, + const std::vector& queries) { + pir::piano::QueryServiceClient client(db_size, thread_num, client_context); + client.FetchFullDB(); + return client.OnlineMultipleQueries(queries); +} + +} // namespace + +static void BM_PianoPir(benchmark::State& state) { + for (auto _ : state) { + state.PauseTiming(); + uint64_t db_size = state.range(0) / sizeof(pir::piano::DBEntry); + const uint64_t query_num = state.range(1); + constexpr uint64_t db_seed = 2315127; + uint64_t thread_num = 8; + + constexpr int kWorldSize = 2; + const auto contexts = yacl::link::test::SetupWorld(kWorldSize); + yacl::link::RecvTimeoutGuard guard(contexts[0], 1000000); + + auto db = CreateDatabase(db_size, db_seed); + auto queries = GenerateQueries(query_num, db_size); + + state.ResumeTiming(); + std::promise exitSignal; + auto server_future = + std::async(std::launch::async, SetupAndRunServer, contexts[0], db_size, + std::ref(exitSignal), std::ref(db)); + + auto client_future = + std::async(std::launch::async, SetupAndRunClient, db_size, thread_num, + contexts[1], std::cref(queries)); + auto results = client_future.get(); + + exitSignal.set_value(); + server_future.get(); + } +} + +// [1m, 16m, 64m, 128m] +BENCHMARK(BM_PianoPir) + ->Unit(benchmark::kMillisecond) + ->Args({1 << 20, 1000}) + ->Args({16 << 20, 1000}) + ->Args({64 << 20, 1000}) + ->Args({128 << 20, 1000}); diff --git a/experimental/pir/piano/piano_test.cc b/experimental/pir/piano/piano_test.cc new file mode 100644 index 00000000..3381151f --- /dev/null +++ b/experimental/pir/piano/piano_test.cc @@ -0,0 +1,120 @@ +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "gtest/gtest.h" +#include "yacl/link/context.h" +#include "yacl/link/test_util.h" + +#include "experimental/pir/piano/client.h" +#include "experimental/pir/piano/serialize.h" +#include "experimental/pir/piano/server.h" +#include "experimental/pir/piano/util.h" + +struct TestParams { + uint64_t db_size; + uint64_t db_seed; + uint64_t thread_num; + uint64_t query_num; + bool is_total_query_num; +}; + +namespace pir::piano { + +std::vector GenerateQueries(const uint64_t query_num, + const uint64_t db_size) { + std::vector queries; + queries.reserve(query_num); + + std::mt19937_64 rng(yacl::crypto::FastRandU64()); + for (uint64_t q = 0; q < query_num; ++q) { + queries.push_back(rng() % db_size); + } + + return queries; +} + +std::vector RunClient(QueryServiceClient& client, + const std::vector& queries) { + client.FetchFullDB(); + return client.OnlineMultipleQueries(queries); +} + +std::vector getResults(const std::vector& queries, + const TestParams& params) { + std::vector expected_results; + expected_results.reserve(queries.size()); + + for (const auto& x : queries) { + expected_results.push_back(GenDBEntry(params.db_seed, x)); + } + + return expected_results; +} + +class PianoTest : public testing::TestWithParam {}; + +TEST_P(PianoTest, Works) { + auto params = GetParam(); + constexpr int kWorldSize = 2; + const auto contexts = yacl::link::test::SetupWorld(kWorldSize); + + SPDLOG_INFO("DB N: %lu, Entry Size %lu Bytes, DB Size %lu MB\n", + params.db_size, DBEntrySize, + params.db_size * DBEntrySize / 1024 / 1024); + + auto [ChunkSize, SetSize] = GenParams(params.db_size); + SPDLOG_INFO("Chunk Size: %lu, Set Size: %lu\n", ChunkSize, SetSize); + + std::vector DB; + DB.assign(ChunkSize * SetSize * DBEntryLength, 0); + SPDLOG_INFO("DB Real N: %lu\n", DB.size()); + + for (uint64_t i = 0; i < DB.size() / DBEntryLength; ++i) { + auto entry = GenDBEntry(params.db_seed, i); + std::memcpy(&DB[i * DBEntryLength], entry.data(), + DBEntryLength * sizeof(uint64_t)); + } + + QueryServiceClient client(params.db_size, params.thread_num, contexts[1]); + + const auto actual_query_num = + params.is_total_query_num ? client.totalQueryNum : params.query_num; + auto queries = GenerateQueries(actual_query_num, DB.size()); + + yacl::link::RecvTimeoutGuard guard(contexts[0], 1000000); + QueryServiceServer server(DB, contexts[0], SetSize, ChunkSize); + + std::promise exitSignal; + std::future futureObj = exitSignal.get_future(); + auto server_future = + std::async(std::launch::async, &QueryServiceServer::Start, + std::ref(server), std::move(futureObj)); + auto client_future = std::async(std::launch::async, RunClient, + std::ref(client), std::cref(queries)); + + auto results = client_future.get(); + auto expected_results = getResults(queries, params); + + for (size_t i = 0; i < results.size(); ++i) { + EXPECT_EQ(results[i], expected_results[i]) + << "Mismatch at index " << queries[i]; + } + + exitSignal.set_value(); + server_future.get(); +} + +// [8m, 128m, 256m] +INSTANTIATE_TEST_SUITE_P( + PianoTestInstances, PianoTest, + ::testing::Values(TestParams{131072, 1211212, 8, 1000, false}, + TestParams{2097152, 6405285, 8, 1000, false}, + TestParams{4194304, 7539870, 16, 1000, false})); +} diff --git a/experimental/pir/piano/serialize.h b/experimental/pir/piano/serialize.h new file mode 100644 index 00000000..451272aa --- /dev/null +++ b/experimental/pir/piano/serialize.h @@ -0,0 +1,99 @@ +#pragma once + +#include +#include +#include + +#include "yacl/base/buffer.h" + +#include "experimental/pir/piano/util.h" + +#include "experimental/pir/piano/piano.pb.h" + +namespace pir::piano { + +inline yacl::Buffer SerializeFetchFullDBMsg(const uint64_t dummy) { + QueryRequest proto; + FetchFullDbMsg* fetch_full_db_msg = proto.mutable_fetch_full_db(); + fetch_full_db_msg->set_dummy(dummy); + + yacl::Buffer buf(proto.ByteSizeLong()); + proto.SerializeToArray(buf.data(), buf.size()); + + return buf; +} + +inline uint64_t DeserializeFetchFullDBMsg(const yacl::Buffer& buf) { + QueryRequest proto; + proto.ParseFromArray(buf.data(), buf.size()); + return proto.fetch_full_db().dummy(); +} + +inline yacl::Buffer SerializeDBChunk(const uint64_t chunk_id, + const uint64_t chunk_size, + const std::vector& chunk) { + DbChunk proto; + proto.set_chunk_id(chunk_id); + proto.set_chunk_size(chunk_size); + for (const auto& val : chunk) { + proto.add_chunks(val); + } + yacl::Buffer buf(proto.ByteSizeLong()); + proto.SerializeToArray(buf.data(), buf.size()); + return buf; +} + +inline std::tuple> DeserializeDBChunk( + const yacl::Buffer& buf) { + DbChunk proto; + proto.ParseFromArray(buf.data(), buf.size()); + std::vector chunk(proto.chunks().begin(), proto.chunks().end()); + return {proto.chunk_id(), proto.chunk_size(), chunk}; +} + +inline yacl::Buffer SerializeSetParityQueryMsg( + const uint64_t set_size, const std::vector& indices) { + QueryRequest proto; + SetParityQueryMsg* set_parity_query = proto.mutable_set_parity_query(); + set_parity_query->set_set_size(set_size); + for (const auto& index : indices) { + set_parity_query->add_indices(index); + } + + yacl::Buffer buf(proto.ByteSizeLong()); + proto.SerializeToArray(buf.data(), buf.size()); + + return buf; +} + +inline std::pair> DeserializeSetParityQueryMsg( + const yacl::Buffer& buf) { + QueryRequest proto; + proto.ParseFromArray(buf.data(), buf.size()); + const auto& set_parity_query = proto.set_parity_query(); + std::vector indices(set_parity_query.indices().begin(), + set_parity_query.indices().end()); + return {set_parity_query.set_size(), indices}; +} + +inline yacl::Buffer SerializeSetParityQueryResponse( + const std::vector& parity, const uint64_t server_compute_time) { + SetParityQueryResponse proto; + for (const auto& p : parity) { + proto.add_parity(p); + } + proto.set_server_compute_time(server_compute_time); + yacl::Buffer buf(proto.ByteSizeLong()); + proto.SerializeToArray(buf.data(), buf.size()); + return buf; +} + +inline std::pair, uint64_t> +DeserializeSetParityQueryResponse(const yacl::Buffer& buf) { + SetParityQueryResponse proto; + proto.ParseFromArray(buf.data(), buf.size()); + std::vector parity(proto.parity().begin(), proto.parity().end()); + return {parity, proto.server_compute_time()}; +} + +} diff --git a/experimental/pir/piano/server.cc b/experimental/pir/piano/server.cc new file mode 100644 index 00000000..82a9a63b --- /dev/null +++ b/experimental/pir/piano/server.cc @@ -0,0 +1,104 @@ +#include "experimental/pir/piano/server.h" + +namespace pir::piano { + +QueryServiceServer::QueryServiceServer( + std::vector& db, std::shared_ptr context, + const uint64_t set_size, const uint64_t chunk_size) + : db_(std::move(db)), + context_(std::move(context)), + set_size_(set_size), + chunk_size_(chunk_size) {} + +void QueryServiceServer::Start(const std::future& stop_signal) { + while (stop_signal.wait_for(std::chrono::milliseconds(1)) == + std::future_status::timeout) { + auto request_data = context_->Recv(context_->NextRank(), "request_data"); + HandleRequest(request_data); + } +} + +void QueryServiceServer::HandleRequest(const yacl::Buffer& request_data) { + QueryRequest proto; + proto.ParseFromArray(request_data.data(), request_data.size()); + + switch (proto.request_case()) { + case QueryRequest::kFetchFullDb: { + // uint64_t dummy = DeserializeFetchFullDBMsg(request_data); + ProcessFetchFullDB(); + break; + } + case QueryRequest::kSetParityQuery: { + const auto parityQuery = DeserializeSetParityQueryMsg(request_data); + const auto& indices = std::get<1>(parityQuery); + + auto [parity, server_compute_time] = ProcessSetParityQuery(indices); + const auto response_buf = + SerializeSetParityQueryResponse(parity, server_compute_time); + context_->SendAsync(context_->NextRank(), response_buf, + "SetParityQueryResponse"); + break; + } + default: + SPDLOG_ERROR("Unknown request type."); + } +} + +void QueryServiceServer::ProcessFetchFullDB() { + for (uint64_t i = 0; i < set_size_; ++i) { + const uint64_t down = i * chunk_size_; + uint64_t up = (i + 1) * chunk_size_; + up = std::min(up, static_cast(db_.size())); + std::vector chunk(db_.begin() + down * DBEntryLength, + db_.begin() + up * DBEntryLength); + auto chunk_buf = SerializeDBChunk(i, chunk.size(), chunk); + + try { + context_->SendAsync(context_->NextRank(), chunk_buf, "FetchFullDBChunk"); + } catch (const std::exception& e) { + SPDLOG_ERROR("Failed to send a chunk."); + return; + } + } +} + +std::pair, uint64_t> +QueryServiceServer::ProcessSetParityQuery( + const std::vector& indices) { + const auto start = std::chrono::high_resolution_clock::now(); + std::vector parity = HandleSetParityQuery(indices); + const auto end = std::chrono::high_resolution_clock::now(); + const auto duration = + std::chrono::duration_cast(end - start).count(); + return {parity, duration}; +} + +DBEntry QueryServiceServer::DBAccess(const uint64_t id) { + if (id < db_.size()) { + if (id * DBEntryLength + DBEntryLength > db_.size()) { + SPDLOG_ERROR("DBAccess: id {} out of range", id); + } + std::array slice{}; + std::copy(db_.begin() + id * DBEntryLength, + db_.begin() + (id + 1) * DBEntryLength, slice.begin()); + return DBEntryFromSlice(slice); + } + DBEntry ret; + ret.fill(0); + return ret; +} + +std::vector QueryServiceServer::HandleSetParityQuery( + const std::vector& indices) { + DBEntry parity = ZeroEntry(); + for (const auto& index : indices) { + DBEntry entry = DBAccess(index); + DBEntryXor(&parity, &entry); + } + + std::vector ret(DBEntryLength); + std::copy(parity.begin(), parity.end(), ret.begin()); + return ret; +} + +} diff --git a/experimental/pir/piano/server.h b/experimental/pir/piano/server.h new file mode 100644 index 00000000..04a1d2f3 --- /dev/null +++ b/experimental/pir/piano/server.h @@ -0,0 +1,54 @@ +#pragma once + +#include + +#include +#include +#include +#include +#include +#include + +#include "yacl/link/context.h" + +#include "experimental/pir/piano/serialize.h" +#include "experimental/pir/piano/util.h" + +namespace pir::piano { + +class QueryServiceServer { + public: + // Constructor: initializes the server with a database, context, set_size, and + // chunk_size + QueryServiceServer(std::vector& db, + std::shared_ptr context, + uint64_t set_size, uint64_t chunk_size); + + // Starts the server to handle incoming requests + void Start(const std::future& stop_signal); + + // Handles the incoming request based on its type + void HandleRequest(const yacl::Buffer& request_data); + + // Processes a request to fetch the full database + void ProcessFetchFullDB(); + + // Processes a set parity query and returns the parity and server compute time + std::pair, uint64_t> ProcessSetParityQuery( + const std::vector& indices); + + private: + // Accesses the database and returns the corresponding entry + DBEntry DBAccess(uint64_t id); + + // Handles a set parity query and returns the parity + std::vector HandleSetParityQuery( + const std::vector& indices); + + std::vector db_; // The database + std::shared_ptr context_; // The communication context + uint64_t set_size_; // The size of the set + uint64_t chunk_size_; // The size of each chunk +}; + +} diff --git a/experimental/pir/piano/util.cc b/experimental/pir/piano/util.cc new file mode 100644 index 00000000..f81b9e26 --- /dev/null +++ b/experimental/pir/piano/util.cc @@ -0,0 +1,169 @@ +#include "experimental/pir/piano/util.h" + +namespace pir::piano { + +uint128_t BytesToUint128(const std::string& bytes) { + if (bytes.size() != 16) { + SPDLOG_WARN("Bytes size must be 16 for uint128_t conversion."); + } + + uint128_t result = 0; + std::memcpy(&result, bytes.data(), 16); + return result; +} + +std::string Uint128ToBytes(const uint128_t value) { + std::string bytes(16, 0); + std::memcpy(bytes.data(), &value, 16); + return bytes; +} + +PrfKey128 RandKey128(std::mt19937_64& rng) { + const uint64_t lo = rng(); + const uint64_t hi = rng(); + return yacl::MakeUint128(hi, lo); +} + +PrfKey RandKey(std::mt19937_64& rng) { return RandKey128(rng); } + +uint64_t PRFEval128(const PrfKey128* key, const uint64_t x) { + yacl::crypto::AES_KEY aes_key; + AES_set_encrypt_key(*key, &aes_key); + + const auto src_block = static_cast(x); + std::vector plain_blocks(1); + plain_blocks[0] = src_block; + std::vector cipher_blocks(1); + + AES_ecb_encrypt_blks(aes_key, absl::MakeConstSpan(plain_blocks), + absl::MakeSpan(cipher_blocks)); + return static_cast(cipher_blocks[0]); +} + +uint64_t PRFEval(const PrfKey* key, const uint64_t x) { + return PRFEval128(key, x); +} + +void DBEntryXor(DBEntry* dst, const DBEntry* src) { + for (size_t i = 0; i < DBEntryLength; ++i) { + (*dst)[i] ^= (*src)[i]; + } +} + +void DBEntryXorFromRaw(DBEntry* dst, const uint64_t* src) { + for (size_t i = 0; i < DBEntryLength; ++i) { + (*dst)[i] ^= src[i]; + } +} + +bool EntryIsEqual(const DBEntry& a, const DBEntry& b) { + for (size_t i = 0; i < DBEntryLength; ++i) { + if (a[i] != b[i]) { + return false; + } + } + return true; +} + +DBEntry RandDBEntry(std::mt19937_64& rng) { + DBEntry entry; + for (size_t i = 0; i < DBEntryLength; ++i) { + entry[i] = rng(); + } + return entry; +} + +uint64_t DefaultHash(uint64_t key) { + constexpr uint64_t FNV_offset_basis = 14695981039346656037ULL; + uint64_t hash = FNV_offset_basis; + for (int i = 0; i < 8; ++i) { + constexpr uint64_t FNV_prime = 1099511628211ULL; + const auto byte = static_cast(key & 0xFF); + hash ^= static_cast(byte); + hash *= FNV_prime; + key >>= 8; + } + return hash; +} + +DBEntry GenDBEntry(const uint64_t key, const uint64_t id) { + DBEntry entry; + for (size_t i = 0; i < DBEntryLength; ++i) { + entry[i] = DefaultHash((key ^ id) + i); + } + return entry; +} + +DBEntry ZeroEntry() { + DBEntry entry = {}; + for (size_t i = 0; i < DBEntryLength; ++i) { + entry[i] = 0; + } + return entry; +} + +DBEntry DBEntryFromSlice(const std::array& s) { + DBEntry entry; + for (size_t i = 0; i < DBEntryLength; ++i) { + entry[i] = s[i]; + } + return entry; +} + +// Generate ChunkSize and SetSize +std::pair GenParams(const uint64_t db_size) { + const double targetChunkSize = 2 * std::sqrt(static_cast(db_size)); + uint64_t ChunkSize = 1; + + // Ensure ChunkSize is a power of 2 and not smaller than targetChunkSize + while (ChunkSize < static_cast(targetChunkSize)) { + ChunkSize *= 2; + } + + uint64_t SetSize = (db_size + ChunkSize - 1) / ChunkSize; + // Round up to the next multiple of 4 + SetSize = (SetSize + 3) / 4 * 4; + + return {ChunkSize, SetSize}; +} + +yacl::crypto::AES_KEY GetLongKey(const PrfKey128* key) { + yacl::crypto::AES_KEY aes_key; + AES_set_encrypt_key(*key, &aes_key); + return aes_key; +} + +uint64_t PRFEvalWithLongKeyAndTag(const yacl::crypto::AES_KEY& long_key, + const uint32_t tag, const uint64_t x) { + // Combine tag and x into a 128-bit block by shifting tag to the high 64 bits + const uint128_t src_block = (static_cast(tag) << 64) + x; + std::vector plain_blocks(1); + plain_blocks[0] = src_block; + std::vector cipher_blocks(1); + AES_ecb_encrypt_blks(long_key, absl::MakeConstSpan(plain_blocks), + absl::MakeSpan(cipher_blocks)); + return static_cast(cipher_blocks[0]); +} + +std::vector PRSetWithShortTag::ExpandWithLongKey( + const yacl::crypto::AES_KEY& long_key, const uint64_t set_size, + const uint64_t chunk_size) const { + std::vector expandedSet(set_size); + for (uint64_t i = 0; i < set_size; i++) { + const uint64_t tmp = PRFEvalWithLongKeyAndTag(long_key, Tag, i); + // Get the offset within the chunk + const uint64_t offset = tmp & (chunk_size - 1); + expandedSet[i] = i * chunk_size + offset; + } + return expandedSet; +} + +bool PRSetWithShortTag::MemberTestWithLongKeyAndTag( + const yacl::crypto::AES_KEY& long_key, const uint64_t chunk_id, + const uint64_t offset, const uint64_t chunk_size) const { + // Ensure chunk_size is a power of 2 and compare offsets + return offset == + (PRFEvalWithLongKeyAndTag(long_key, Tag, chunk_id) & (chunk_size - 1)); +} + +} diff --git a/experimental/pir/piano/util.h b/experimental/pir/piano/util.h new file mode 100644 index 00000000..da20e381 --- /dev/null +++ b/experimental/pir/piano/util.h @@ -0,0 +1,88 @@ +#pragma once + +#include + +#include +#include +#include +#include +#include + +#include "yacl/crypto/aes/aes_intrinsics.h" + +namespace pir::piano { + +constexpr size_t DBEntrySize = 8; // has to be a multiple of 8 +constexpr size_t DBEntryLength = DBEntrySize / 8; + +using PrfKey128 = uint128_t; +using DBEntry = std::array; +using PrfKey = PrfKey128; + +uint128_t BytesToUint128(const std::string& bytes); + +std::string Uint128ToBytes(uint128_t value); + +// Generates a random 128-bit key using the provided RNG +PrfKey128 RandKey128(std::mt19937_64& rng); + +// Generates a random PRF key +PrfKey RandKey(std::mt19937_64& rng); + +// Evaluates PRF using 128-bit key and returns a 64-bit result +uint64_t PRFEval128(const PrfKey128* key, uint64_t x); + +// Evaluates PRF using a general PrfKey and returns a 64-bit result +uint64_t PRFEval(const PrfKey* key, uint64_t x); + +// XOR two DBEntry structures +void DBEntryXor(DBEntry* dst, const DBEntry* src); + +// XOR a DBEntry with raw uint64_t data +void DBEntryXorFromRaw(DBEntry* dst, const uint64_t* src); + +// Compare two DBEntry structures for equality +bool EntryIsEqual(const DBEntry& a, const DBEntry& b); + +// Generate a random DBEntry using the provided RNG +DBEntry RandDBEntry(std::mt19937_64& rng); + +// Default FNV hash implementation for 64-bit keys +uint64_t DefaultHash(uint64_t key); + +// Generate a DBEntry based on a key and ID +DBEntry GenDBEntry(uint64_t key, uint64_t id); + +// Generate a zero-filled DBEntry +DBEntry ZeroEntry(); + +// Convert a slice (array) into a DBEntry structure +DBEntry DBEntryFromSlice(const std::array& s); + +// Generate parameters for ChunkSize and SetSize +std::pair GenParams(uint64_t db_size); + +// Returns a long key (AES expanded key) for PRF evaluation +yacl::crypto::AES_KEY GetLongKey(const PrfKey128* key); + +// PRF evaluation with a long key and tag, returns a 64-bit result +uint64_t PRFEvalWithLongKeyAndTag(const yacl::crypto::AES_KEY& long_key, + uint32_t tag, uint64_t x); + +class PRSetWithShortTag { + public: + uint32_t Tag; + + // Expands the set with a long key and tag + [[nodiscard]] std::vector ExpandWithLongKey( + const yacl::crypto::AES_KEY& long_key, uint64_t set_size, + uint64_t chunk_size) const; + + // Membership test with a long key and tag, to check if an ID belongs to the + // set + [[nodiscard]] bool MemberTestWithLongKeyAndTag( + const yacl::crypto::AES_KEY& long_key, uint64_t chunk_id, uint64_t offset, + uint64_t chunk_size) const; +}; + +} From b25e91a61128b365ac421e2e254f1d89576cf6de Mon Sep 17 00:00:00 2001 From: cxiao129 Date: Mon, 4 Nov 2024 11:53:58 +0800 Subject: [PATCH 06/11] Delete original file --- experimental/pir/piano/client.cc | 2 +- experimental/pir/piano/client.h | 7 +- experimental/pir/piano/piano_benchmark.cc | 5 +- experimental/pir/piano/piano_test.cc | 9 +- experimental/pir/piano/serialize.h | 5 +- experimental/pir/piano/server.cc | 2 +- experimental/pir/piano/server.h | 5 +- experimental/pir/piano/util.cc | 2 +- experimental/pir/piano/util.h | 2 +- psi/piano/BUILD.bazel | 100 ------- psi/piano/README.md | 28 -- psi/piano/client.cc | 348 ---------------------- psi/piano/client.h | 104 ------- psi/piano/piano.proto | 30 -- psi/piano/piano_benchmark.cc | 103 ------- psi/piano/piano_test.cc | 120 -------- psi/piano/serialize.h | 99 ------ psi/piano/server.cc | 104 ------- psi/piano/server.h | 54 ---- psi/piano/util.cc | 169 ----------- psi/piano/util.h | 88 ------ 21 files changed, 17 insertions(+), 1369 deletions(-) delete mode 100644 psi/piano/BUILD.bazel delete mode 100644 psi/piano/README.md delete mode 100644 psi/piano/client.cc delete mode 100644 psi/piano/client.h delete mode 100644 psi/piano/piano.proto delete mode 100644 psi/piano/piano_benchmark.cc delete mode 100644 psi/piano/piano_test.cc delete mode 100644 psi/piano/serialize.h delete mode 100644 psi/piano/server.cc delete mode 100644 psi/piano/server.h delete mode 100644 psi/piano/util.cc delete mode 100644 psi/piano/util.h diff --git a/experimental/pir/piano/client.cc b/experimental/pir/piano/client.cc index 077df889..7875f1c6 100644 --- a/experimental/pir/piano/client.cc +++ b/experimental/pir/piano/client.cc @@ -345,4 +345,4 @@ std::vector QueryServiceClient::OnlineMultipleQueries( return results; } -} +} // namespace pir::piano diff --git a/experimental/pir/piano/client.h b/experimental/pir/piano/client.h index 9092f3b2..973a13dc 100644 --- a/experimental/pir/piano/client.h +++ b/experimental/pir/piano/client.h @@ -7,11 +7,10 @@ #include #include -#include "yacl/crypto/rand/rand.h" -#include "yacl/link/context.h" - #include "experimental/pir/piano/serialize.h" #include "experimental/pir/piano/util.h" +#include "yacl/crypto/rand/rand.h" +#include "yacl/link/context.h" namespace pir::piano { @@ -101,4 +100,4 @@ class QueryServiceClient { std::vector local_replacement_groups_; }; -} +} // namespace pir::piano diff --git a/experimental/pir/piano/piano_benchmark.cc b/experimental/pir/piano/piano_benchmark.cc index d82741b7..53526f3c 100644 --- a/experimental/pir/piano/piano_benchmark.cc +++ b/experimental/pir/piano/piano_benchmark.cc @@ -6,12 +6,11 @@ #include #include "benchmark/benchmark.h" -#include "yacl/link/context.h" -#include "yacl/link/test_util.h" - #include "experimental/pir/piano/client.h" #include "experimental/pir/piano/server.h" #include "experimental/pir/piano/util.h" +#include "yacl/link/context.h" +#include "yacl/link/test_util.h" namespace { diff --git a/experimental/pir/piano/piano_test.cc b/experimental/pir/piano/piano_test.cc index 3381151f..f57b4d26 100644 --- a/experimental/pir/piano/piano_test.cc +++ b/experimental/pir/piano/piano_test.cc @@ -8,14 +8,13 @@ #include #include -#include "gtest/gtest.h" -#include "yacl/link/context.h" -#include "yacl/link/test_util.h" - #include "experimental/pir/piano/client.h" #include "experimental/pir/piano/serialize.h" #include "experimental/pir/piano/server.h" #include "experimental/pir/piano/util.h" +#include "gtest/gtest.h" +#include "yacl/link/context.h" +#include "yacl/link/test_util.h" struct TestParams { uint64_t db_size; @@ -117,4 +116,4 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(TestParams{131072, 1211212, 8, 1000, false}, TestParams{2097152, 6405285, 8, 1000, false}, TestParams{4194304, 7539870, 16, 1000, false})); -} +} // namespace pir::piano diff --git a/experimental/pir/piano/serialize.h b/experimental/pir/piano/serialize.h index 451272aa..28a7a537 100644 --- a/experimental/pir/piano/serialize.h +++ b/experimental/pir/piano/serialize.h @@ -4,9 +4,8 @@ #include #include -#include "yacl/base/buffer.h" - #include "experimental/pir/piano/util.h" +#include "yacl/base/buffer.h" #include "experimental/pir/piano/piano.pb.h" @@ -96,4 +95,4 @@ DeserializeSetParityQueryResponse(const yacl::Buffer& buf) { return {parity, proto.server_compute_time()}; } -} +} // namespace pir::piano diff --git a/experimental/pir/piano/server.cc b/experimental/pir/piano/server.cc index 82a9a63b..9d2516cf 100644 --- a/experimental/pir/piano/server.cc +++ b/experimental/pir/piano/server.cc @@ -101,4 +101,4 @@ std::vector QueryServiceServer::HandleSetParityQuery( return ret; } -} +} // namespace pir::piano diff --git a/experimental/pir/piano/server.h b/experimental/pir/piano/server.h index 04a1d2f3..04650afa 100644 --- a/experimental/pir/piano/server.h +++ b/experimental/pir/piano/server.h @@ -9,10 +9,9 @@ #include #include -#include "yacl/link/context.h" - #include "experimental/pir/piano/serialize.h" #include "experimental/pir/piano/util.h" +#include "yacl/link/context.h" namespace pir::piano { @@ -51,4 +50,4 @@ class QueryServiceServer { uint64_t chunk_size_; // The size of each chunk }; -} +} // namespace pir::piano diff --git a/experimental/pir/piano/util.cc b/experimental/pir/piano/util.cc index f81b9e26..cab249db 100644 --- a/experimental/pir/piano/util.cc +++ b/experimental/pir/piano/util.cc @@ -166,4 +166,4 @@ bool PRSetWithShortTag::MemberTestWithLongKeyAndTag( (PRFEvalWithLongKeyAndTag(long_key, Tag, chunk_id) & (chunk_size - 1)); } -} +} // namespace pir::piano diff --git a/experimental/pir/piano/util.h b/experimental/pir/piano/util.h index da20e381..48406423 100644 --- a/experimental/pir/piano/util.h +++ b/experimental/pir/piano/util.h @@ -85,4 +85,4 @@ class PRSetWithShortTag { uint64_t chunk_size) const; }; -} +} // namespace pir::piano diff --git a/psi/piano/BUILD.bazel b/psi/piano/BUILD.bazel deleted file mode 100644 index 3336de04..00000000 --- a/psi/piano/BUILD.bazel +++ /dev/null @@ -1,100 +0,0 @@ -# Copyright 2023 Ant Group Co., Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -load("@rules_cc//cc:defs.bzl", "cc_proto_library") -load("@rules_proto//proto:defs.bzl", "proto_library") -load("//bazel:psi.bzl", "psi_cc_binary", "psi_cc_library", "psi_cc_test") - -package(default_visibility = ["//visibility:public"]) - -proto_library( - name = "piano_proto", - srcs = ["piano.proto"], -) - -cc_proto_library( - name = "piano_cc_proto", - deps = [":piano_proto"], -) - -psi_cc_library( - name = "util", - srcs = ["util.cc"], - hdrs = ["util.h"], - deps = [ - "@yacl//yacl/crypto/aes:aes_intrinsics", - ], -) - -psi_cc_library( - name = "serialize", - srcs = ["serialize.h"], - deps = [ - ":piano_cc_proto", - ":util", - "@yacl//yacl/base:buffer", - ], -) - -psi_cc_library( - name = "server", - srcs = ["server.cc"], - hdrs = ["server.h"], - deps = [ - ":piano_cc_proto", - ":serialize", - ":util", - "@yacl//yacl/link:context", - ], -) - -psi_cc_library( - name = "client", - srcs = ["client.cc"], - hdrs = ["client.h"], - deps = [ - ":piano_cc_proto", - ":serialize", - ":util", - "@yacl//yacl/crypto/rand", - "@yacl//yacl/link:context", - ], -) - -psi_cc_test( - name = "piano_test", - timeout = "eternal", - srcs = ["piano_test.cc"], - deps = [ - ":client", - ":server", - ":util", - "@yacl//yacl/crypto/rand", - "@yacl//yacl/link:context", - "@yacl//yacl/link:test_util", - ], -) - -psi_cc_binary( - name = "piano_benchmark", - srcs = ["piano_benchmark.cc"], - deps = [ - ":client", - ":server", - ":util", - "@com_github_google_benchmark//:benchmark_main", - "@yacl//yacl/link:context", - "@yacl//yacl/link:test_util", - ], -) diff --git a/psi/piano/README.md b/psi/piano/README.md deleted file mode 100644 index 65cc6d99..00000000 --- a/psi/piano/README.md +++ /dev/null @@ -1,28 +0,0 @@ -Piano: Extremely Simple, Single-server PIR with Sublinear Server Computation - -论文地址:https://eprint.iacr.org/2023/452 - -论文开源实现:https://github.com/wuwuz/Piano-PIR-new - -**方案概括** - -1. 采用特定于客户端的预处理模型(也称为订阅模型),让每个客户端在预处理期间下载并存储来自服务器的“提示” - -2. 实现了O(√n)的客户端存储,和O(√n)的在线通信与计算开销(平均到每个查询上) -3. 服务器是诚实且好奇的,恶意服务器也无法侵害隐私,但会导致查询结果出错 -4. 方案包括两个阶段:预处理阶段和在线查询阶段 - -**预处理阶段** - -1. 服务器将数据库划分为O(√n)个块,客户端**流式**的从服务器获取每块数据,并每次只处理当前块中的元素,包括记录部分数据和计算奇偶校验位 -2. 客户端要存储的“提示”包括三类:主表,替换条目和备份表,主表共有O(√n)个,替换条目和备份表在每个块上要存储 - O(1)个 - -**在线查询阶段** -1. 客户端查询包含x的主表,将主表中的x替换为替换条目中该块的下一个可用元素,发送序列给服务器。服务器计算序列的奇偶校验位并返回,客户端在本地通过异或操作恢复出DB[x] -2. 主表使用一次后就要丢弃,使用备份表进行替换,同时为了负载均衡,要保留(x,DB[x]) - -**具体实现** -1. 客户端不存储完整的序列,而是只存储tag,通过msk和PRF扩展出完整的序列 -2. 每个块都有O(1)个备份表,备份表中该块对应的DB[x]没有参与计算奇偶校验位,以实现快速替换 -3. 本地保存查询记录,当有重复查询时在本地查询,并发送随机序列给服务器 diff --git a/psi/piano/client.cc b/psi/piano/client.cc deleted file mode 100644 index 9d8475ac..00000000 --- a/psi/piano/client.cc +++ /dev/null @@ -1,348 +0,0 @@ -#include "psi/piano/client.h" - -namespace psi::piano { - -uint64_t primaryNumParam(const double q, const double chunk_size, - const double target) { - const double k = std::ceil((std::log(2) * target) + std::log(q)); - return static_cast(k) * static_cast(chunk_size); -} - -double FailProbBallIntoBins(const uint64_t ball_num, const uint64_t bin_num, - const uint64_t bin_size) { - const double mean = - static_cast(ball_num) / static_cast(bin_num); - const double c = (static_cast(bin_size) / mean) - 1; - // Chernoff bound exp(-(c^2)/(2+c) * mean) - double t = (mean * (c * c) / (2 + c)) * std::log(2); - t -= std::log2(static_cast(bin_num)); - return t; -} - -QueryServiceClient::QueryServiceClient( - const uint64_t db_size, const uint64_t thread_num, - std::shared_ptr context) - : db_size_(db_size), thread_num_(thread_num), context_(std::move(context)) { - Initialize(); - InitializeLocalSets(); -} - -void QueryServiceClient::Initialize() { - std::mt19937_64 rng(yacl::crypto::FastRandU64()); - - master_key_ = RandKey(rng); - long_key_ = GetLongKey(&master_key_); - - // Q = sqrt(n) * ln(n) - totalQueryNum = - static_cast(std::sqrt(static_cast(db_size_)) * - std::log(static_cast(db_size_))); - - std::tie(chunk_size_, set_size_) = GenParams(db_size_); - - primary_set_num_ = - primaryNumParam(static_cast(totalQueryNum), - static_cast(chunk_size_), FailureProbLog2 + 1); - // if localSetNum is not a multiple of thread_num_ then we need to add some - // padding - primary_set_num_ = - (primary_set_num_ + thread_num_ - 1) / thread_num_ * thread_num_; - - backup_set_num_per_chunk_ = - 3 * static_cast(static_cast(totalQueryNum) / - static_cast(set_size_)); - backup_set_num_per_chunk_ = - (backup_set_num_per_chunk_ + thread_num_ - 1) / thread_num_ * thread_num_; - - // set_size == chunk_number - total_backup_set_num_ = backup_set_num_per_chunk_ * set_size_; -} - -void QueryServiceClient::InitializeLocalSets() { - primary_sets_.clear(); - primary_sets_.reserve(primary_set_num_); - local_backup_sets_.clear(); - local_backup_sets_.reserve(total_backup_set_num_); - local_cache_.clear(); - local_miss_elements_.clear(); - uint32_t tagCounter = 0; - - for (uint64_t j = 0; j < primary_set_num_; j++) { - primary_sets_.emplace_back(tagCounter, ZeroEntry(), 0, false); - tagCounter += 1; - } - - local_backup_set_groups_.clear(); - local_backup_set_groups_.reserve(set_size_); - local_replacement_groups_.clear(); - local_replacement_groups_.reserve(set_size_); - - for (uint64_t i = 0; i < set_size_; i++) { - std::vector> backupSets; - for (uint64_t j = 0; j < backup_set_num_per_chunk_; j++) { - backupSets.emplace_back( - local_backup_sets_[(i * backup_set_num_per_chunk_) + j]); - } - LocalBackupSetGroup backupGroup(0, backupSets); - local_backup_set_groups_.emplace_back(std::move(backupGroup)); - - std::vector indices(backup_set_num_per_chunk_); - std::vector values(backup_set_num_per_chunk_); - LocalReplacementGroup replacementGroup(0, indices, values); - local_replacement_groups_.emplace_back(std::move(replacementGroup)); - } - - for (uint64_t j = 0; j < set_size_; j++) { - for (uint64_t k = 0; k < backup_set_num_per_chunk_; k++) { - local_backup_set_groups_[j].sets[k].get() = - LocalBackupSet{tagCounter, ZeroEntry()}; - tagCounter += 1; - } - } -} - -void QueryServiceClient::FetchFullDB() { - const auto fetchFullDBMsg = SerializeFetchFullDBMsg(1); - context_->SendAsync(context_->NextRank(), fetchFullDBMsg, "FetchFullDBMsg"); - - for (uint64_t i = 0; i < set_size_; i++) { - auto chunkBuf = context_->Recv(context_->NextRank(), "DBChunk"); - if (chunkBuf.size() == 0) { - break; - } - auto dbChunk = DeserializeDBChunk(chunkBuf); - auto& chunk = std::get<2>(dbChunk); - - std::vector hitMap(chunk_size_, false); - - // Use multiple threads to parallelize the computation for the chunk - std::vector threads; - std::mutex hitMapMutex; - - // make sure all sets are covered - const uint64_t perTheadSetNum = - ((primary_set_num_ + thread_num_ - 1) / thread_num_) + 1; - const uint64_t perThreadBackupNum = - ((total_backup_set_num_ + thread_num_ - 1) / thread_num_) + 1; - - for (uint64_t tid = 0; tid < thread_num_; tid++) { - uint64_t startIndex = tid * perTheadSetNum; - uint64_t endIndex = - std::min(startIndex + perTheadSetNum, primary_set_num_); - - uint64_t startIndexBackup = tid * perThreadBackupNum; - uint64_t endIndexBackup = std::min(startIndexBackup + perThreadBackupNum, - total_backup_set_num_); - - threads.emplace_back([&, startIndex, endIndex, startIndexBackup, - endIndexBackup] { - // update the parities for the primary hints - for (uint64_t j = startIndex; j < endIndex; j++) { - const auto tmp = - PRFEvalWithLongKeyAndTag(long_key_, primary_sets_[j].tag, i); - const auto offset = tmp & (chunk_size_ - 1); - { - std::lock_guard lock(hitMapMutex); - hitMap[offset] = true; - } - DBEntryXorFromRaw(&primary_sets_[j].parity, - &chunk[offset * DBEntryLength]); - } - - // update the parities for the backup hints - for (uint64_t j = startIndexBackup; j < endIndexBackup; j++) { - const auto tmp = - PRFEvalWithLongKeyAndTag(long_key_, local_backup_sets_[j].tag, i); - const auto offset = tmp & (chunk_size_ - 1); - DBEntryXorFromRaw(&local_backup_sets_[j].parityAfterPunct, - &chunk[offset * DBEntryLength]); - } - }); - } - - for (auto& thread : threads) { - if (thread.joinable()) { - thread.join(); - } - } - - // If any element is not hit, then it is a local miss. We will save it in - // the local miss cache. Most of the time, the local miss cache will be - // empty. - for (uint64_t j = 0; j < chunk_size_; j++) { - if (!hitMap[j]) { - std::array entry_slice{}; - std::memcpy(entry_slice.data(), &chunk[j * DBEntryLength], - DBEntryLength * sizeof(uint64_t)); - const auto entry = DBEntryFromSlice(entry_slice); - local_miss_elements_[j + (i * chunk_size_)] = entry; - } - } - - // For the i-th group of backups, leave the i-th chunk as blank - // To do that, we just xor the i-th chunk's value again - for (uint64_t k = 0; k < backup_set_num_per_chunk_; k++) { - const auto tag = local_backup_set_groups_[i].sets[k].get().tag; - const auto tmp = PRFEvalWithLongKeyAndTag(long_key_, tag, i); - const auto offset = tmp & (chunk_size_ - 1); - DBEntryXorFromRaw( - &local_backup_set_groups_[i].sets[k].get().parityAfterPunct, - &chunk[offset * DBEntryLength]); - } - - // store the replacement - std::mt19937_64 rng(yacl::crypto::FastRandU64()); - for (uint64_t k = 0; k < backup_set_num_per_chunk_; k++) { - // generate a random offset between 0 and ChunkSize - 1 - const auto offset = rng() & (chunk_size_ - 1); - local_replacement_groups_[i].indices[k] = offset + i * chunk_size_; - std::array entry_slice{}; - std::memcpy(entry_slice.data(), &chunk[offset * DBEntryLength], - DBEntryLength * sizeof(uint64_t)); - local_replacement_groups_[i].value[k] = DBEntryFromSlice(entry_slice); - } - } -} - -void QueryServiceClient::SendDummySet() const { - std::mt19937_64 rng(yacl::crypto::FastRandU64()); - std::vector randSet(set_size_); - for (uint64_t i = 0; i < set_size_; i++) { - randSet[i] = rng() % chunk_size_ + i * chunk_size_; - } - - // send the random dummy set to the server - const auto query_msg = SerializeSetParityQueryMsg(set_size_, randSet); - context_->SendAsync(context_->NextRank(), query_msg, "SetParityQueryMsg"); - - const auto response_buf = - context_->Recv(context_->NextRank(), "SetParityQueryResponse"); - // auto parityQueryResponse = DeserializeSetParityQueryResponse(response_buf); -} - -DBEntry QueryServiceClient::OnlineSingleQuery(const uint64_t x) { - // make sure x is not in the local cache - if (local_cache_.find(x) != local_cache_.end()) { - SendDummySet(); - return local_cache_[x]; - } - - // 1. Query x: the client first finds a local set that contains x - // 2. The client expands the set, replace the chunk(x)-th element to a - // replacement - // 3. The client sends the edited set to the server and gets the parity - // 4. The client recovers the answer - uint64_t hitSetId = std::numeric_limits::max(); - - const uint64_t queryOffset = x % chunk_size_; - const uint64_t chunkId = x / chunk_size_; - - for (uint64_t i = 0; i < primary_set_num_; i++) { - const auto& set = primary_sets_[i]; - if (const bool isProgrammedMatch = - set.isProgrammed && chunkId == (set.programmedPoint / chunk_size_); - !isProgrammedMatch && - PRSetWithShortTag{set.tag}.MemberTestWithLongKeyAndTag( - long_key_, chunkId, queryOffset, chunk_size_)) { - hitSetId = i; - break; - } - } - - DBEntry xVal = ZeroEntry(); - - if (hitSetId == std::numeric_limits::max()) { - if (local_miss_elements_.find(x) == local_miss_elements_.end()) { - SPDLOG_ERROR("No hit set found for %lu", x); - } else { - xVal = local_miss_elements_[x]; - local_cache_[x] = xVal; - } - - SendDummySet(); - return xVal; - } - - // expand the set - const PRSetWithShortTag set{primary_sets_[hitSetId].tag}; - auto expandedSet = set.ExpandWithLongKey(long_key_, set_size_, chunk_size_); - - // manually program the set if the flag is set before - if (primary_sets_[hitSetId].isProgrammed) { - const uint64_t programmedChunkId = - primary_sets_[hitSetId].programmedPoint / chunk_size_; - expandedSet[programmedChunkId] = primary_sets_[hitSetId].programmedPoint; - } - - // edit the set by replacing the chunk(x)-th element with a replacement - const uint64_t nxtAvailable = local_replacement_groups_[chunkId].consumed; - if (nxtAvailable == backup_set_num_per_chunk_) { - SPDLOG_ERROR("No replacement available for %lu", x); - SendDummySet(); - return xVal; - } - - // consume one replacement - const uint64_t repIndex = - local_replacement_groups_[chunkId].indices[nxtAvailable]; - const DBEntry repVal = local_replacement_groups_[chunkId].value[nxtAvailable]; - local_replacement_groups_[chunkId].consumed++; - expandedSet[chunkId] = repIndex; - - // send the edited set to the server - const auto query_msg = SerializeSetParityQueryMsg(set_size_, expandedSet); - context_->SendAsync(context_->NextRank(), query_msg, "SetParityQueryMsg"); - - const auto response_buf = - context_->Recv(context_->NextRank(), "SetParityQueryResponse"); - - const auto parityQueryResponse = - DeserializeSetParityQueryResponse(response_buf); - const auto& parity = std::get<0>(parityQueryResponse); - - // recover the answer - xVal = primary_sets_[hitSetId].parity; // the parity of the hit set - DBEntryXorFromRaw(&xVal, parity.data()); // xor the parity of the edited set - DBEntryXor(&xVal, &repVal); // xor the replacement value - - // update the local cache - local_cache_[x] = xVal; - - // refresh phase - if (local_backup_set_groups_[chunkId].consumed == backup_set_num_per_chunk_) { - SPDLOG_WARN("No backup set available for %lu", x); - return xVal; - } - - const DBEntry originalXVal = xVal; - const uint64_t consumed = local_backup_set_groups_[chunkId].consumed; - primary_sets_[hitSetId].tag = - local_backup_set_groups_[chunkId].sets[consumed].get().tag; - // backup set doesn't XOR the chunk(x)-th element in preparation - DBEntryXor( - &xVal, - &local_backup_set_groups_[chunkId].sets[consumed].get().parityAfterPunct); - primary_sets_[hitSetId].parity = xVal; - primary_sets_[hitSetId].isProgrammed = true; - // for load balancing, the chunk(x)-th element differs from the one expanded - // via PRFEval on the tag - primary_sets_[hitSetId].programmedPoint = x; - local_backup_set_groups_[chunkId].consumed++; - - return originalXVal; -} - -std::vector QueryServiceClient::OnlineMultipleQueries( - const std::vector& queries) { - std::vector results; - results.reserve(queries.size()); - - for (const auto& x : queries) { - DBEntry result = OnlineSingleQuery(x); - results.push_back(result); - } - - return results; -} - -} // namespace psi::piano diff --git a/psi/piano/client.h b/psi/piano/client.h deleted file mode 100644 index b2089b8b..00000000 --- a/psi/piano/client.h +++ /dev/null @@ -1,104 +0,0 @@ -#pragma once - -#include - -#include -#include -#include -#include - -#include "yacl/crypto/rand/rand.h" -#include "yacl/link/context.h" - -#include "psi/piano/serialize.h" -#include "psi/piano/util.h" - -namespace psi::piano { - -class LocalSet { - public: - uint32_t tag; // the tag of the set - DBEntry parity; - uint64_t - programmedPoint; // identifier for the element replaced after refresh, - // differing from those expanded by PRFEval - bool isProgrammed; - - LocalSet(const uint32_t tag, const DBEntry& parity, - const uint64_t programmed_point, const bool is_programmed) - : tag(tag), - parity(parity), - programmedPoint(programmed_point), - isProgrammed(is_programmed) {} -}; - -class LocalBackupSet { - public: - uint32_t tag; - DBEntry parityAfterPunct; - - LocalBackupSet(const uint32_t tag, const DBEntry& parity_after_punct) - : tag(tag), parityAfterPunct(parity_after_punct) {} -}; - -class LocalBackupSetGroup { - public: - uint64_t consumed; - std::vector> sets; - - LocalBackupSetGroup( - const uint64_t consumed, - const std::vector>& sets) - : consumed(consumed), sets(sets) {} -}; - -class LocalReplacementGroup { - public: - uint64_t consumed; - std::vector indices; - std::vector value; - - LocalReplacementGroup(const uint64_t consumed, - const std::vector& indices, - const std::vector& value) - : consumed(consumed), indices(indices), value(value) {} -}; - -class QueryServiceClient { - public: - static constexpr uint64_t FailureProbLog2 = 40; - uint64_t totalQueryNum{}; - - QueryServiceClient(uint64_t db_size, uint64_t thread_num, - std::shared_ptr context); - - void Initialize(); - void InitializeLocalSets(); - void FetchFullDB(); - void SendDummySet() const; - DBEntry OnlineSingleQuery(uint64_t x); - std::vector OnlineMultipleQueries( - const std::vector& queries); - - private: - uint64_t db_size_; - uint64_t thread_num_; - std::shared_ptr context_; - - uint64_t chunk_size_{}; - uint64_t set_size_{}; - uint64_t primary_set_num_{}; - uint64_t backup_set_num_per_chunk_{}; - uint64_t total_backup_set_num_{}; - PrfKey master_key_{}; - yacl::crypto::AES_KEY long_key_{}; - - std::vector primary_sets_; - std::vector local_backup_sets_; - std::map local_cache_; - std::map local_miss_elements_; - std::vector local_backup_set_groups_; - std::vector local_replacement_groups_; -}; - -} // namespace psi::piano diff --git a/psi/piano/piano.proto b/psi/piano/piano.proto deleted file mode 100644 index aa9cac3f..00000000 --- a/psi/piano/piano.proto +++ /dev/null @@ -1,30 +0,0 @@ -syntax = "proto3"; - -package psi.piano; - -message FetchFullDbMsg { - uint64 dummy = 1; -} - -message DbChunk { - uint64 chunk_id = 1; - uint64 chunk_size = 2; - repeated uint64 chunks = 3; -} - -message SetParityQueryMsg { - uint64 set_size = 1; - repeated uint64 indices = 2; -} - -message SetParityQueryResponse { - repeated uint64 parity = 1; - uint64 server_compute_time = 2; -} - -message QueryRequest { - oneof request { - FetchFullDbMsg fetch_full_db = 1; - SetParityQueryMsg set_parity_query = 2; - } -} diff --git a/psi/piano/piano_benchmark.cc b/psi/piano/piano_benchmark.cc deleted file mode 100644 index 2c16d08d..00000000 --- a/psi/piano/piano_benchmark.cc +++ /dev/null @@ -1,103 +0,0 @@ -#include -#include -#include -#include -#include -#include - -#include "benchmark/benchmark.h" -#include "yacl/link/context.h" -#include "yacl/link/test_util.h" - -#include "psi/piano/client.h" -#include "psi/piano/server.h" -#include "psi/piano/util.h" - -namespace { - -std::vector GenerateQueries(const uint64_t query_num, - const uint64_t db_size) { - std::vector queries; - queries.reserve(query_num); - - std::mt19937_64 rng(yacl::crypto::FastRandU64()); - for (uint64_t q = 0; q < query_num; ++q) { - queries.push_back(rng() % db_size); - } - - return queries; -} - -std::vector CreateDatabase(const uint64_t db_size, - const uint64_t db_seed) { - const auto [ChunkSize, SetSize] = psi::piano::GenParams(db_size); - std::vector DB; - DB.assign(ChunkSize * SetSize * psi::piano::DBEntryLength, 0); - - for (uint64_t i = 0; i < DB.size() / psi::piano::DBEntryLength; ++i) { - auto entry = psi::piano::GenDBEntry(db_seed, i); - std::memcpy(&DB[i * psi::piano::DBEntryLength], entry.data(), - psi::piano::DBEntryLength * sizeof(uint64_t)); - } - - return DB; -} - -void SetupAndRunServer( - const std::shared_ptr& server_context, - const uint64_t db_size, std::promise& exit_signal, - std::vector& db) { - const auto [ChunkSize, SetSize] = psi::piano::GenParams(db_size); - psi::piano::QueryServiceServer server(db, server_context, SetSize, ChunkSize); - server.Start(exit_signal.get_future()); -} - -std::vector SetupAndRunClient( - const uint64_t db_size, const uint64_t thread_num, - const std::shared_ptr& client_context, - const std::vector& queries) { - psi::piano::QueryServiceClient client(db_size, thread_num, client_context); - client.FetchFullDB(); - return client.OnlineMultipleQueries(queries); -} - -} // namespace - -static void BM_PianoPir(benchmark::State& state) { - for (auto _ : state) { - state.PauseTiming(); - uint64_t db_size = state.range(0) / sizeof(psi::piano::DBEntry); - const uint64_t query_num = state.range(1); - constexpr uint64_t db_seed = 2315127; - uint64_t thread_num = 8; - - constexpr int kWorldSize = 2; - const auto contexts = yacl::link::test::SetupWorld(kWorldSize); - yacl::link::RecvTimeoutGuard guard(contexts[0], 1000000); - - auto db = CreateDatabase(db_size, db_seed); - auto queries = GenerateQueries(query_num, db_size); - - state.ResumeTiming(); - std::promise exitSignal; - auto server_future = - std::async(std::launch::async, SetupAndRunServer, contexts[0], db_size, - std::ref(exitSignal), std::ref(db)); - - auto client_future = - std::async(std::launch::async, SetupAndRunClient, db_size, thread_num, - contexts[1], std::cref(queries)); - auto results = client_future.get(); - - exitSignal.set_value(); - server_future.get(); - } -} - -// [1m, 16m, 64m, 128m] -BENCHMARK(BM_PianoPir) - ->Unit(benchmark::kMillisecond) - ->Args({1 << 20, 1000}) - ->Args({16 << 20, 1000}) - ->Args({64 << 20, 1000}) - ->Args({128 << 20, 1000}); diff --git a/psi/piano/piano_test.cc b/psi/piano/piano_test.cc deleted file mode 100644 index 9fe33643..00000000 --- a/psi/piano/piano_test.cc +++ /dev/null @@ -1,120 +0,0 @@ -#include - -#include -#include -#include -#include -#include -#include -#include - -#include "gtest/gtest.h" -#include "yacl/link/context.h" -#include "yacl/link/test_util.h" - -#include "psi/piano/client.h" -#include "psi/piano/serialize.h" -#include "psi/piano/server.h" -#include "psi/piano/util.h" - -struct TestParams { - uint64_t db_size; - uint64_t db_seed; - uint64_t thread_num; - uint64_t query_num; - bool is_total_query_num; -}; - -namespace psi::piano { - -std::vector GenerateQueries(const uint64_t query_num, - const uint64_t db_size) { - std::vector queries; - queries.reserve(query_num); - - std::mt19937_64 rng(yacl::crypto::FastRandU64()); - for (uint64_t q = 0; q < query_num; ++q) { - queries.push_back(rng() % db_size); - } - - return queries; -} - -std::vector RunClient(QueryServiceClient& client, - const std::vector& queries) { - client.FetchFullDB(); - return client.OnlineMultipleQueries(queries); -} - -std::vector getResults(const std::vector& queries, - const TestParams& params) { - std::vector expected_results; - expected_results.reserve(queries.size()); - - for (const auto& x : queries) { - expected_results.push_back(GenDBEntry(params.db_seed, x)); - } - - return expected_results; -} - -class PianoTest : public testing::TestWithParam {}; - -TEST_P(PianoTest, Works) { - auto params = GetParam(); - constexpr int kWorldSize = 2; - const auto contexts = yacl::link::test::SetupWorld(kWorldSize); - - SPDLOG_INFO("DB N: %lu, Entry Size %lu Bytes, DB Size %lu MB\n", - params.db_size, DBEntrySize, - params.db_size * DBEntrySize / 1024 / 1024); - - auto [ChunkSize, SetSize] = GenParams(params.db_size); - SPDLOG_INFO("Chunk Size: %lu, Set Size: %lu\n", ChunkSize, SetSize); - - std::vector DB; - DB.assign(ChunkSize * SetSize * DBEntryLength, 0); - SPDLOG_INFO("DB Real N: %lu\n", DB.size()); - - for (uint64_t i = 0; i < DB.size() / DBEntryLength; ++i) { - auto entry = GenDBEntry(params.db_seed, i); - std::memcpy(&DB[i * DBEntryLength], entry.data(), - DBEntryLength * sizeof(uint64_t)); - } - - QueryServiceClient client(params.db_size, params.thread_num, contexts[1]); - - const auto actual_query_num = - params.is_total_query_num ? client.totalQueryNum : params.query_num; - auto queries = GenerateQueries(actual_query_num, DB.size()); - - yacl::link::RecvTimeoutGuard guard(contexts[0], 1000000); - QueryServiceServer server(DB, contexts[0], SetSize, ChunkSize); - - std::promise exitSignal; - std::future futureObj = exitSignal.get_future(); - auto server_future = - std::async(std::launch::async, &QueryServiceServer::Start, - std::ref(server), std::move(futureObj)); - auto client_future = std::async(std::launch::async, RunClient, - std::ref(client), std::cref(queries)); - - auto results = client_future.get(); - auto expected_results = getResults(queries, params); - - for (size_t i = 0; i < results.size(); ++i) { - EXPECT_EQ(results[i], expected_results[i]) - << "Mismatch at index " << queries[i]; - } - - exitSignal.set_value(); - server_future.get(); -} - -// [8m, 128m, 256m] -INSTANTIATE_TEST_SUITE_P( - PianoTestInstances, PianoTest, - ::testing::Values(TestParams{131072, 1211212, 8, 1000, false}, - TestParams{2097152, 6405285, 8, 1000, false}, - TestParams{4194304, 7539870, 16, 1000, false})); -} // namespace psi::piano diff --git a/psi/piano/serialize.h b/psi/piano/serialize.h deleted file mode 100644 index c09d1cbd..00000000 --- a/psi/piano/serialize.h +++ /dev/null @@ -1,99 +0,0 @@ -#pragma once - -#include -#include -#include - -#include "yacl/base/buffer.h" - -#include "psi/piano/util.h" - -#include "psi/piano/piano.pb.h" - -namespace psi::piano { - -inline yacl::Buffer SerializeFetchFullDBMsg(const uint64_t dummy) { - QueryRequest proto; - FetchFullDbMsg* fetch_full_db_msg = proto.mutable_fetch_full_db(); - fetch_full_db_msg->set_dummy(dummy); - - yacl::Buffer buf(proto.ByteSizeLong()); - proto.SerializeToArray(buf.data(), buf.size()); - - return buf; -} - -inline uint64_t DeserializeFetchFullDBMsg(const yacl::Buffer& buf) { - QueryRequest proto; - proto.ParseFromArray(buf.data(), buf.size()); - return proto.fetch_full_db().dummy(); -} - -inline yacl::Buffer SerializeDBChunk(const uint64_t chunk_id, - const uint64_t chunk_size, - const std::vector& chunk) { - DbChunk proto; - proto.set_chunk_id(chunk_id); - proto.set_chunk_size(chunk_size); - for (const auto& val : chunk) { - proto.add_chunks(val); - } - yacl::Buffer buf(proto.ByteSizeLong()); - proto.SerializeToArray(buf.data(), buf.size()); - return buf; -} - -inline std::tuple> DeserializeDBChunk( - const yacl::Buffer& buf) { - DbChunk proto; - proto.ParseFromArray(buf.data(), buf.size()); - std::vector chunk(proto.chunks().begin(), proto.chunks().end()); - return {proto.chunk_id(), proto.chunk_size(), chunk}; -} - -inline yacl::Buffer SerializeSetParityQueryMsg( - const uint64_t set_size, const std::vector& indices) { - QueryRequest proto; - SetParityQueryMsg* set_parity_query = proto.mutable_set_parity_query(); - set_parity_query->set_set_size(set_size); - for (const auto& index : indices) { - set_parity_query->add_indices(index); - } - - yacl::Buffer buf(proto.ByteSizeLong()); - proto.SerializeToArray(buf.data(), buf.size()); - - return buf; -} - -inline std::pair> DeserializeSetParityQueryMsg( - const yacl::Buffer& buf) { - QueryRequest proto; - proto.ParseFromArray(buf.data(), buf.size()); - const auto& set_parity_query = proto.set_parity_query(); - std::vector indices(set_parity_query.indices().begin(), - set_parity_query.indices().end()); - return {set_parity_query.set_size(), indices}; -} - -inline yacl::Buffer SerializeSetParityQueryResponse( - const std::vector& parity, const uint64_t server_compute_time) { - SetParityQueryResponse proto; - for (const auto& p : parity) { - proto.add_parity(p); - } - proto.set_server_compute_time(server_compute_time); - yacl::Buffer buf(proto.ByteSizeLong()); - proto.SerializeToArray(buf.data(), buf.size()); - return buf; -} - -inline std::pair, uint64_t> -DeserializeSetParityQueryResponse(const yacl::Buffer& buf) { - SetParityQueryResponse proto; - proto.ParseFromArray(buf.data(), buf.size()); - std::vector parity(proto.parity().begin(), proto.parity().end()); - return {parity, proto.server_compute_time()}; -} - -} // namespace psi::piano diff --git a/psi/piano/server.cc b/psi/piano/server.cc deleted file mode 100644 index 2c835296..00000000 --- a/psi/piano/server.cc +++ /dev/null @@ -1,104 +0,0 @@ -#include "psi/piano/server.h" - -namespace psi::piano { - -QueryServiceServer::QueryServiceServer( - std::vector& db, std::shared_ptr context, - const uint64_t set_size, const uint64_t chunk_size) - : db_(std::move(db)), - context_(std::move(context)), - set_size_(set_size), - chunk_size_(chunk_size) {} - -void QueryServiceServer::Start(const std::future& stop_signal) { - while (stop_signal.wait_for(std::chrono::milliseconds(1)) == - std::future_status::timeout) { - auto request_data = context_->Recv(context_->NextRank(), "request_data"); - HandleRequest(request_data); - } -} - -void QueryServiceServer::HandleRequest(const yacl::Buffer& request_data) { - QueryRequest proto; - proto.ParseFromArray(request_data.data(), request_data.size()); - - switch (proto.request_case()) { - case QueryRequest::kFetchFullDb: { - // uint64_t dummy = DeserializeFetchFullDBMsg(request_data); - ProcessFetchFullDB(); - break; - } - case QueryRequest::kSetParityQuery: { - const auto parityQuery = DeserializeSetParityQueryMsg(request_data); - const auto& indices = std::get<1>(parityQuery); - - auto [parity, server_compute_time] = ProcessSetParityQuery(indices); - const auto response_buf = - SerializeSetParityQueryResponse(parity, server_compute_time); - context_->SendAsync(context_->NextRank(), response_buf, - "SetParityQueryResponse"); - break; - } - default: - SPDLOG_ERROR("Unknown request type."); - } -} - -void QueryServiceServer::ProcessFetchFullDB() { - for (uint64_t i = 0; i < set_size_; ++i) { - const uint64_t down = i * chunk_size_; - uint64_t up = (i + 1) * chunk_size_; - up = std::min(up, static_cast(db_.size())); - std::vector chunk(db_.begin() + down * DBEntryLength, - db_.begin() + up * DBEntryLength); - auto chunk_buf = SerializeDBChunk(i, chunk.size(), chunk); - - try { - context_->SendAsync(context_->NextRank(), chunk_buf, "FetchFullDBChunk"); - } catch (const std::exception& e) { - SPDLOG_ERROR("Failed to send a chunk."); - return; - } - } -} - -std::pair, uint64_t> -QueryServiceServer::ProcessSetParityQuery( - const std::vector& indices) { - const auto start = std::chrono::high_resolution_clock::now(); - std::vector parity = HandleSetParityQuery(indices); - const auto end = std::chrono::high_resolution_clock::now(); - const auto duration = - std::chrono::duration_cast(end - start).count(); - return {parity, duration}; -} - -DBEntry QueryServiceServer::DBAccess(const uint64_t id) { - if (id < db_.size()) { - if (id * DBEntryLength + DBEntryLength > db_.size()) { - SPDLOG_ERROR("DBAccess: id {} out of range", id); - } - std::array slice{}; - std::copy(db_.begin() + id * DBEntryLength, - db_.begin() + (id + 1) * DBEntryLength, slice.begin()); - return DBEntryFromSlice(slice); - } - DBEntry ret; - ret.fill(0); - return ret; -} - -std::vector QueryServiceServer::HandleSetParityQuery( - const std::vector& indices) { - DBEntry parity = ZeroEntry(); - for (const auto& index : indices) { - DBEntry entry = DBAccess(index); - DBEntryXor(&parity, &entry); - } - - std::vector ret(DBEntryLength); - std::copy(parity.begin(), parity.end(), ret.begin()); - return ret; -} - -} // namespace psi::piano diff --git a/psi/piano/server.h b/psi/piano/server.h deleted file mode 100644 index 9abe93d6..00000000 --- a/psi/piano/server.h +++ /dev/null @@ -1,54 +0,0 @@ -#pragma once - -#include - -#include -#include -#include -#include -#include -#include - -#include "yacl/link/context.h" - -#include "psi/piano/serialize.h" -#include "psi/piano/util.h" - -namespace psi::piano { - -class QueryServiceServer { - public: - // Constructor: initializes the server with a database, context, set_size, and - // chunk_size - QueryServiceServer(std::vector& db, - std::shared_ptr context, - uint64_t set_size, uint64_t chunk_size); - - // Starts the server to handle incoming requests - void Start(const std::future& stop_signal); - - // Handles the incoming request based on its type - void HandleRequest(const yacl::Buffer& request_data); - - // Processes a request to fetch the full database - void ProcessFetchFullDB(); - - // Processes a set parity query and returns the parity and server compute time - std::pair, uint64_t> ProcessSetParityQuery( - const std::vector& indices); - - private: - // Accesses the database and returns the corresponding entry - DBEntry DBAccess(uint64_t id); - - // Handles a set parity query and returns the parity - std::vector HandleSetParityQuery( - const std::vector& indices); - - std::vector db_; // The database - std::shared_ptr context_; // The communication context - uint64_t set_size_; // The size of the set - uint64_t chunk_size_; // The size of each chunk -}; - -} // namespace psi::piano diff --git a/psi/piano/util.cc b/psi/piano/util.cc deleted file mode 100644 index 6ed5a9c7..00000000 --- a/psi/piano/util.cc +++ /dev/null @@ -1,169 +0,0 @@ -#include "psi/piano/util.h" - -namespace psi::piano { - -uint128_t BytesToUint128(const std::string& bytes) { - if (bytes.size() != 16) { - SPDLOG_WARN("Bytes size must be 16 for uint128_t conversion."); - } - - uint128_t result = 0; - std::memcpy(&result, bytes.data(), 16); - return result; -} - -std::string Uint128ToBytes(const uint128_t value) { - std::string bytes(16, 0); - std::memcpy(bytes.data(), &value, 16); - return bytes; -} - -PrfKey128 RandKey128(std::mt19937_64& rng) { - const uint64_t lo = rng(); - const uint64_t hi = rng(); - return yacl::MakeUint128(hi, lo); -} - -PrfKey RandKey(std::mt19937_64& rng) { return RandKey128(rng); } - -uint64_t PRFEval128(const PrfKey128* key, const uint64_t x) { - yacl::crypto::AES_KEY aes_key; - AES_set_encrypt_key(*key, &aes_key); - - const auto src_block = static_cast(x); - std::vector plain_blocks(1); - plain_blocks[0] = src_block; - std::vector cipher_blocks(1); - - AES_ecb_encrypt_blks(aes_key, absl::MakeConstSpan(plain_blocks), - absl::MakeSpan(cipher_blocks)); - return static_cast(cipher_blocks[0]); -} - -uint64_t PRFEval(const PrfKey* key, const uint64_t x) { - return PRFEval128(key, x); -} - -void DBEntryXor(DBEntry* dst, const DBEntry* src) { - for (size_t i = 0; i < DBEntryLength; ++i) { - (*dst)[i] ^= (*src)[i]; - } -} - -void DBEntryXorFromRaw(DBEntry* dst, const uint64_t* src) { - for (size_t i = 0; i < DBEntryLength; ++i) { - (*dst)[i] ^= src[i]; - } -} - -bool EntryIsEqual(const DBEntry& a, const DBEntry& b) { - for (size_t i = 0; i < DBEntryLength; ++i) { - if (a[i] != b[i]) { - return false; - } - } - return true; -} - -DBEntry RandDBEntry(std::mt19937_64& rng) { - DBEntry entry; - for (size_t i = 0; i < DBEntryLength; ++i) { - entry[i] = rng(); - } - return entry; -} - -uint64_t DefaultHash(uint64_t key) { - constexpr uint64_t FNV_offset_basis = 14695981039346656037ULL; - uint64_t hash = FNV_offset_basis; - for (int i = 0; i < 8; ++i) { - constexpr uint64_t FNV_prime = 1099511628211ULL; - const auto byte = static_cast(key & 0xFF); - hash ^= static_cast(byte); - hash *= FNV_prime; - key >>= 8; - } - return hash; -} - -DBEntry GenDBEntry(const uint64_t key, const uint64_t id) { - DBEntry entry; - for (size_t i = 0; i < DBEntryLength; ++i) { - entry[i] = DefaultHash((key ^ id) + i); - } - return entry; -} - -DBEntry ZeroEntry() { - DBEntry entry = {}; - for (size_t i = 0; i < DBEntryLength; ++i) { - entry[i] = 0; - } - return entry; -} - -DBEntry DBEntryFromSlice(const std::array& s) { - DBEntry entry; - for (size_t i = 0; i < DBEntryLength; ++i) { - entry[i] = s[i]; - } - return entry; -} - -// Generate ChunkSize and SetSize -std::pair GenParams(const uint64_t db_size) { - const double targetChunkSize = 2 * std::sqrt(static_cast(db_size)); - uint64_t ChunkSize = 1; - - // Ensure ChunkSize is a power of 2 and not smaller than targetChunkSize - while (ChunkSize < static_cast(targetChunkSize)) { - ChunkSize *= 2; - } - - uint64_t SetSize = (db_size + ChunkSize - 1) / ChunkSize; - // Round up to the next multiple of 4 - SetSize = (SetSize + 3) / 4 * 4; - - return {ChunkSize, SetSize}; -} - -yacl::crypto::AES_KEY GetLongKey(const PrfKey128* key) { - yacl::crypto::AES_KEY aes_key; - AES_set_encrypt_key(*key, &aes_key); - return aes_key; -} - -uint64_t PRFEvalWithLongKeyAndTag(const yacl::crypto::AES_KEY& long_key, - const uint32_t tag, const uint64_t x) { - // Combine tag and x into a 128-bit block by shifting tag to the high 64 bits - const uint128_t src_block = (static_cast(tag) << 64) + x; - std::vector plain_blocks(1); - plain_blocks[0] = src_block; - std::vector cipher_blocks(1); - AES_ecb_encrypt_blks(long_key, absl::MakeConstSpan(plain_blocks), - absl::MakeSpan(cipher_blocks)); - return static_cast(cipher_blocks[0]); -} - -std::vector PRSetWithShortTag::ExpandWithLongKey( - const yacl::crypto::AES_KEY& long_key, const uint64_t set_size, - const uint64_t chunk_size) const { - std::vector expandedSet(set_size); - for (uint64_t i = 0; i < set_size; i++) { - const uint64_t tmp = PRFEvalWithLongKeyAndTag(long_key, Tag, i); - // Get the offset within the chunk - const uint64_t offset = tmp & (chunk_size - 1); - expandedSet[i] = i * chunk_size + offset; - } - return expandedSet; -} - -bool PRSetWithShortTag::MemberTestWithLongKeyAndTag( - const yacl::crypto::AES_KEY& long_key, const uint64_t chunk_id, - const uint64_t offset, const uint64_t chunk_size) const { - // Ensure chunk_size is a power of 2 and compare offsets - return offset == - (PRFEvalWithLongKeyAndTag(long_key, Tag, chunk_id) & (chunk_size - 1)); -} - -} // namespace psi::piano diff --git a/psi/piano/util.h b/psi/piano/util.h deleted file mode 100644 index 9158224c..00000000 --- a/psi/piano/util.h +++ /dev/null @@ -1,88 +0,0 @@ -#pragma once - -#include - -#include -#include -#include -#include -#include - -#include "yacl/crypto/aes/aes_intrinsics.h" - -namespace psi::piano { - -constexpr size_t DBEntrySize = 8; // has to be a multiple of 8 -constexpr size_t DBEntryLength = DBEntrySize / 8; - -using PrfKey128 = uint128_t; -using DBEntry = std::array; -using PrfKey = PrfKey128; - -uint128_t BytesToUint128(const std::string& bytes); - -std::string Uint128ToBytes(uint128_t value); - -// Generates a random 128-bit key using the provided RNG -PrfKey128 RandKey128(std::mt19937_64& rng); - -// Generates a random PRF key -PrfKey RandKey(std::mt19937_64& rng); - -// Evaluates PRF using 128-bit key and returns a 64-bit result -uint64_t PRFEval128(const PrfKey128* key, uint64_t x); - -// Evaluates PRF using a general PrfKey and returns a 64-bit result -uint64_t PRFEval(const PrfKey* key, uint64_t x); - -// XOR two DBEntry structures -void DBEntryXor(DBEntry* dst, const DBEntry* src); - -// XOR a DBEntry with raw uint64_t data -void DBEntryXorFromRaw(DBEntry* dst, const uint64_t* src); - -// Compare two DBEntry structures for equality -bool EntryIsEqual(const DBEntry& a, const DBEntry& b); - -// Generate a random DBEntry using the provided RNG -DBEntry RandDBEntry(std::mt19937_64& rng); - -// Default FNV hash implementation for 64-bit keys -uint64_t DefaultHash(uint64_t key); - -// Generate a DBEntry based on a key and ID -DBEntry GenDBEntry(uint64_t key, uint64_t id); - -// Generate a zero-filled DBEntry -DBEntry ZeroEntry(); - -// Convert a slice (array) into a DBEntry structure -DBEntry DBEntryFromSlice(const std::array& s); - -// Generate parameters for ChunkSize and SetSize -std::pair GenParams(uint64_t db_size); - -// Returns a long key (AES expanded key) for PRF evaluation -yacl::crypto::AES_KEY GetLongKey(const PrfKey128* key); - -// PRF evaluation with a long key and tag, returns a 64-bit result -uint64_t PRFEvalWithLongKeyAndTag(const yacl::crypto::AES_KEY& long_key, - uint32_t tag, uint64_t x); - -class PRSetWithShortTag { - public: - uint32_t Tag; - - // Expands the set with a long key and tag - [[nodiscard]] std::vector ExpandWithLongKey( - const yacl::crypto::AES_KEY& long_key, uint64_t set_size, - uint64_t chunk_size) const; - - // Membership test with a long key and tag, to check if an ID belongs to the - // set - [[nodiscard]] bool MemberTestWithLongKeyAndTag( - const yacl::crypto::AES_KEY& long_key, uint64_t chunk_id, uint64_t offset, - uint64_t chunk_size) const; -}; - -} // namespace psi::piano From 14b90b40cfb675d74299a42fc1144384cd4c9607 Mon Sep 17 00:00:00 2001 From: cxiao129 Date: Tue, 26 Nov 2024 17:48:17 +0800 Subject: [PATCH 07/11] Abstract DBEntry into a class, use prg in yacl, make member variables private, etc --- experimental/pir/piano/BUILD.bazel | 3 +- experimental/pir/piano/client.cc | 101 +++++++++--------- experimental/pir/piano/client.h | 75 +++++++------- experimental/pir/piano/piano.proto | 4 +- experimental/pir/piano/piano_benchmark.cc | 87 ++++++++++------ experimental/pir/piano/piano_test.cc | 78 +++++++++----- experimental/pir/piano/serialize.h | 20 ++-- experimental/pir/piano/server.cc | 53 +++++----- experimental/pir/piano/server.h | 13 +-- experimental/pir/piano/util.cc | 108 ++------------------ experimental/pir/piano/util.h | 118 +++++++++++++--------- 11 files changed, 313 insertions(+), 347 deletions(-) diff --git a/experimental/pir/piano/BUILD.bazel b/experimental/pir/piano/BUILD.bazel index 3336de04..183e7b2f 100644 --- a/experimental/pir/piano/BUILD.bazel +++ b/experimental/pir/piano/BUILD.bazel @@ -34,6 +34,7 @@ psi_cc_library( hdrs = ["util.h"], deps = [ "@yacl//yacl/crypto/aes:aes_intrinsics", + "@yacl//yacl/crypto/rand", ], ) @@ -67,7 +68,6 @@ psi_cc_library( ":piano_cc_proto", ":serialize", ":util", - "@yacl//yacl/crypto/rand", "@yacl//yacl/link:context", ], ) @@ -80,7 +80,6 @@ psi_cc_test( ":client", ":server", ":util", - "@yacl//yacl/crypto/rand", "@yacl//yacl/link:context", "@yacl//yacl/link:test_util", ], diff --git a/experimental/pir/piano/client.cc b/experimental/pir/piano/client.cc index 7875f1c6..d542d9e7 100644 --- a/experimental/pir/piano/client.cc +++ b/experimental/pir/piano/client.cc @@ -20,36 +20,37 @@ double FailProbBallIntoBins(const uint64_t ball_num, const uint64_t bin_num, } QueryServiceClient::QueryServiceClient( - const uint64_t db_size, const uint64_t thread_num, - std::shared_ptr context) - : db_size_(db_size), thread_num_(thread_num), context_(std::move(context)) { + const uint64_t entry_num, const uint64_t thread_num, + const uint64_t entry_size, std::shared_ptr context) + : entry_num_(entry_num), + thread_num_(thread_num), + context_(std::move(context)), + entry_size_(entry_size) { Initialize(); InitializeLocalSets(); } void QueryServiceClient::Initialize() { - std::mt19937_64 rng(yacl::crypto::FastRandU64()); - - master_key_ = RandKey(rng); - long_key_ = GetLongKey(&master_key_); + master_key_ = SecureRandKey(); + long_key_ = GetLongKey(master_key_); // Q = sqrt(n) * ln(n) - totalQueryNum = - static_cast(std::sqrt(static_cast(db_size_)) * - std::log(static_cast(db_size_))); + total_query_num_ = + static_cast(std::sqrt(static_cast(entry_num_)) * + std::log(static_cast(entry_num_))); - std::tie(chunk_size_, set_size_) = GenParams(db_size_); + std::tie(chunk_size_, set_size_) = GenParams(entry_num_); primary_set_num_ = - primaryNumParam(static_cast(totalQueryNum), - static_cast(chunk_size_), FailureProbLog2 + 1); + primaryNumParam(static_cast(total_query_num_), + static_cast(chunk_size_), kFailureProbLog2 + 1); // if localSetNum is not a multiple of thread_num_ then we need to add some // padding primary_set_num_ = (primary_set_num_ + thread_num_ - 1) / thread_num_ * thread_num_; backup_set_num_per_chunk_ = - 3 * static_cast(static_cast(totalQueryNum) / + 3 * static_cast(static_cast(total_query_num_) / static_cast(set_size_)); backup_set_num_per_chunk_ = (backup_set_num_per_chunk_ + thread_num_ - 1) / thread_num_ * thread_num_; @@ -67,8 +68,17 @@ void QueryServiceClient::InitializeLocalSets() { local_miss_elements_.clear(); uint32_t tagCounter = 0; + // Initialize primary_sets_ for (uint64_t j = 0; j < primary_set_num_; j++) { - primary_sets_.emplace_back(tagCounter, ZeroEntry(), 0, false); + primary_sets_.emplace_back(tagCounter, DBEntry::ZeroEntry(entry_size_), 0, + false); + tagCounter += 1; + } + + // Initialize local_backup_sets_ + for (uint64_t i = 0; i < total_backup_set_num_; ++i) { + local_backup_sets_.emplace_back(tagCounter, + DBEntry::ZeroEntry(entry_size_)); tagCounter += 1; } @@ -77,6 +87,7 @@ void QueryServiceClient::InitializeLocalSets() { local_replacement_groups_.clear(); local_replacement_groups_.reserve(set_size_); + // Initialize local_backup_set_groups_ and local_replacement_groups_ for (uint64_t i = 0; i < set_size_; i++) { std::vector> backupSets; for (uint64_t j = 0; j < backup_set_num_per_chunk_; j++) { @@ -91,14 +102,6 @@ void QueryServiceClient::InitializeLocalSets() { LocalReplacementGroup replacementGroup(0, indices, values); local_replacement_groups_.emplace_back(std::move(replacementGroup)); } - - for (uint64_t j = 0; j < set_size_; j++) { - for (uint64_t k = 0; k < backup_set_num_per_chunk_; k++) { - local_backup_set_groups_[j].sets[k].get() = - LocalBackupSet{tagCounter, ZeroEntry()}; - tagCounter += 1; - } - } } void QueryServiceClient::FetchFullDB() { @@ -145,8 +148,7 @@ void QueryServiceClient::FetchFullDB() { std::lock_guard lock(hitMapMutex); hitMap[offset] = true; } - DBEntryXorFromRaw(&primary_sets_[j].parity, - &chunk[offset * DBEntryLength]); + primary_sets_[j].parity.XorFromRaw(&chunk[offset * entry_size_]); } // update the parities for the backup hints @@ -154,8 +156,8 @@ void QueryServiceClient::FetchFullDB() { const auto tmp = PRFEvalWithLongKeyAndTag(long_key_, local_backup_sets_[j].tag, i); const auto offset = tmp & (chunk_size_ - 1); - DBEntryXorFromRaw(&local_backup_sets_[j].parityAfterPunct, - &chunk[offset * DBEntryLength]); + local_backup_sets_[j].parityAfterPuncture.XorFromRaw( + &chunk[offset * entry_size_]); } }); } @@ -171,10 +173,10 @@ void QueryServiceClient::FetchFullDB() { // empty. for (uint64_t j = 0; j < chunk_size_; j++) { if (!hitMap[j]) { - std::array entry_slice{}; - std::memcpy(entry_slice.data(), &chunk[j * DBEntryLength], - DBEntryLength * sizeof(uint64_t)); - const auto entry = DBEntryFromSlice(entry_slice); + std::vector entry_slice(entry_size_); + std::memcpy(entry_slice.data(), &chunk[j * entry_size_], + entry_size_ * sizeof(uint8_t)); + const auto entry = DBEntry::DBEntryFromSlice(entry_slice); local_miss_elements_[j + (i * chunk_size_)] = entry; } } @@ -185,30 +187,30 @@ void QueryServiceClient::FetchFullDB() { const auto tag = local_backup_set_groups_[i].sets[k].get().tag; const auto tmp = PRFEvalWithLongKeyAndTag(long_key_, tag, i); const auto offset = tmp & (chunk_size_ - 1); - DBEntryXorFromRaw( - &local_backup_set_groups_[i].sets[k].get().parityAfterPunct, - &chunk[offset * DBEntryLength]); + local_backup_set_groups_[i].sets[k].get().parityAfterPuncture.XorFromRaw( + &chunk[offset * entry_size_]); } // store the replacement - std::mt19937_64 rng(yacl::crypto::FastRandU64()); + yacl::crypto::Prg prg(yacl::crypto::SecureRandU64()); for (uint64_t k = 0; k < backup_set_num_per_chunk_; k++) { // generate a random offset between 0 and ChunkSize - 1 - const auto offset = rng() & (chunk_size_ - 1); + const auto offset = prg() & (chunk_size_ - 1); local_replacement_groups_[i].indices[k] = offset + i * chunk_size_; - std::array entry_slice{}; - std::memcpy(entry_slice.data(), &chunk[offset * DBEntryLength], - DBEntryLength * sizeof(uint64_t)); - local_replacement_groups_[i].value[k] = DBEntryFromSlice(entry_slice); + std::vector entry_slice(entry_size_); + std::memcpy(entry_slice.data(), &chunk[offset * entry_size_], + entry_size_ * sizeof(uint8_t)); + local_replacement_groups_[i].value[k] = + DBEntry::DBEntryFromSlice(entry_slice); } } } void QueryServiceClient::SendDummySet() const { - std::mt19937_64 rng(yacl::crypto::FastRandU64()); + yacl::crypto::Prg prg(yacl::crypto::SecureRandU64()); std::vector randSet(set_size_); for (uint64_t i = 0; i < set_size_; i++) { - randSet[i] = rng() % chunk_size_ + i * chunk_size_; + randSet[i] = prg() % chunk_size_ + i * chunk_size_; } // send the random dummy set to the server @@ -249,7 +251,7 @@ DBEntry QueryServiceClient::OnlineSingleQuery(const uint64_t x) { } } - DBEntry xVal = ZeroEntry(); + DBEntry xVal = DBEntry::ZeroEntry(entry_size_); if (hitSetId == std::numeric_limits::max()) { if (local_miss_elements_.find(x) == local_miss_elements_.end()) { @@ -301,9 +303,9 @@ DBEntry QueryServiceClient::OnlineSingleQuery(const uint64_t x) { const auto& parity = std::get<0>(parityQueryResponse); // recover the answer - xVal = primary_sets_[hitSetId].parity; // the parity of the hit set - DBEntryXorFromRaw(&xVal, parity.data()); // xor the parity of the edited set - DBEntryXor(&xVal, &repVal); // xor the replacement value + xVal = primary_sets_[hitSetId].parity; // the parity of the hit set + xVal.XorFromRaw(parity.data()); // xor the parity of the edited set + xVal.Xor(repVal); // xor the replacement value // update the local cache local_cache_[x] = xVal; @@ -319,9 +321,10 @@ DBEntry QueryServiceClient::OnlineSingleQuery(const uint64_t x) { primary_sets_[hitSetId].tag = local_backup_set_groups_[chunkId].sets[consumed].get().tag; // backup set doesn't XOR the chunk(x)-th element in preparation - DBEntryXor( - &xVal, - &local_backup_set_groups_[chunkId].sets[consumed].get().parityAfterPunct); + xVal.Xor(local_backup_set_groups_[chunkId] + .sets[consumed] + .get() + .parityAfterPuncture); primary_sets_[hitSetId].parity = xVal; primary_sets_[hitSetId].isProgrammed = true; // for load balancing, the chunk(x)-th element differs from the one expanded diff --git a/experimental/pir/piano/client.h b/experimental/pir/piano/client.h index 973a13dc..d53fdf11 100644 --- a/experimental/pir/piano/client.h +++ b/experimental/pir/piano/client.h @@ -2,73 +2,64 @@ #include -#include #include -#include -#include +#include #include "experimental/pir/piano/serialize.h" #include "experimental/pir/piano/util.h" -#include "yacl/crypto/rand/rand.h" +#include "yacl/crypto/tools/prg.h" #include "yacl/link/context.h" namespace pir::piano { -class LocalSet { - public: - uint32_t tag; // the tag of the set - DBEntry parity; - uint64_t - programmedPoint; // identifier for the element replaced after refresh, - // differing from those expanded by PRFEval - bool isProgrammed; - - LocalSet(const uint32_t tag, const DBEntry& parity, - const uint64_t programmed_point, const bool is_programmed) +struct LocalSet { + LocalSet(const uint32_t tag, DBEntry parity, const uint64_t programmed_point, + const bool is_programmed) : tag(tag), - parity(parity), + parity(std::move(parity)), programmedPoint(programmed_point), isProgrammed(is_programmed) {} + + uint32_t tag; // the tag of the set + DBEntry parity; + uint64_t programmedPoint; // identifier for the element replaced after + // refresh differing from those expanded by PRFEval + bool isProgrammed; }; -class LocalBackupSet { - public: - uint32_t tag; - DBEntry parityAfterPunct; +struct LocalBackupSet { + LocalBackupSet(const uint32_t tag, DBEntry parity_after_puncture) + : tag(tag), parityAfterPuncture(std::move(parity_after_puncture)) {} - LocalBackupSet(const uint32_t tag, const DBEntry& parity_after_punct) - : tag(tag), parityAfterPunct(parity_after_punct) {} + uint32_t tag; + DBEntry parityAfterPuncture; }; -class LocalBackupSetGroup { - public: - uint64_t consumed; - std::vector> sets; - +struct LocalBackupSetGroup { LocalBackupSetGroup( const uint64_t consumed, const std::vector>& sets) : consumed(consumed), sets(sets) {} -}; -class LocalReplacementGroup { - public: uint64_t consumed; - std::vector indices; - std::vector value; + std::vector> sets; +}; +struct LocalReplacementGroup { LocalReplacementGroup(const uint64_t consumed, const std::vector& indices, const std::vector& value) : consumed(consumed), indices(indices), value(value) {} + + uint64_t consumed; + std::vector indices; + std::vector value; }; class QueryServiceClient { public: - static constexpr uint64_t FailureProbLog2 = 40; - uint64_t totalQueryNum{}; - - QueryServiceClient(uint64_t db_size, uint64_t thread_num, + QueryServiceClient(uint64_t entry_num, uint64_t thread_num, + uint64_t entry_size, std::shared_ptr context); void Initialize(); @@ -78,24 +69,28 @@ class QueryServiceClient { DBEntry OnlineSingleQuery(uint64_t x); std::vector OnlineMultipleQueries( const std::vector& queries); + uint64_t getTotalQueryNumber() const { return total_query_num_; }; private: - uint64_t db_size_; + static constexpr uint64_t kFailureProbLog2 = 40; + uint64_t total_query_num_{}; + uint64_t entry_num_; uint64_t thread_num_; std::shared_ptr context_; uint64_t chunk_size_{}; uint64_t set_size_{}; + uint64_t entry_size_{}; uint64_t primary_set_num_{}; uint64_t backup_set_num_per_chunk_{}; uint64_t total_backup_set_num_{}; - PrfKey master_key_{}; + uint128_t master_key_{}; yacl::crypto::AES_KEY long_key_{}; std::vector primary_sets_; std::vector local_backup_sets_; - std::map local_cache_; - std::map local_miss_elements_; + std::unordered_map local_cache_; + std::unordered_map local_miss_elements_; std::vector local_backup_set_groups_; std::vector local_replacement_groups_; }; diff --git a/experimental/pir/piano/piano.proto b/experimental/pir/piano/piano.proto index 514db0e6..c7603f9b 100644 --- a/experimental/pir/piano/piano.proto +++ b/experimental/pir/piano/piano.proto @@ -9,7 +9,7 @@ message FetchFullDbMsg { message DbChunk { uint64 chunk_id = 1; uint64 chunk_size = 2; - repeated uint64 chunks = 3; + bytes chunks = 3; } message SetParityQueryMsg { @@ -18,7 +18,7 @@ message SetParityQueryMsg { } message SetParityQueryResponse { - repeated uint64 parity = 1; + bytes parity = 1; uint64 server_compute_time = 2; } diff --git a/experimental/pir/piano/piano_benchmark.cc b/experimental/pir/piano/piano_benchmark.cc index 53526f3c..9ae168c3 100644 --- a/experimental/pir/piano/piano_benchmark.cc +++ b/experimental/pir/piano/piano_benchmark.cc @@ -15,28 +15,50 @@ namespace { std::vector GenerateQueries(const uint64_t query_num, - const uint64_t db_size) { + const uint64_t entry_num) { std::vector queries; queries.reserve(query_num); - std::mt19937_64 rng(yacl::crypto::FastRandU64()); + yacl::crypto::Prg prg(yacl::crypto::SecureRandU64()); for (uint64_t q = 0; q < query_num; ++q) { - queries.push_back(rng() % db_size); + queries.push_back(prg() % entry_num); } return queries; } -std::vector CreateDatabase(const uint64_t db_size, - const uint64_t db_seed) { - const auto [ChunkSize, SetSize] = pir::piano::GenParams(db_size); - std::vector DB; - DB.assign(ChunkSize * SetSize * pir::piano::DBEntryLength, 0); +std::vector FNVHash(uint64_t key) { + constexpr uint64_t FNV_offset_basis = 14695981039346656037ULL; + uint64_t hash = FNV_offset_basis; - for (uint64_t i = 0; i < DB.size() / pir::piano::DBEntryLength; ++i) { - auto entry = pir::piano::GenDBEntry(db_seed, i); - std::memcpy(&DB[i * pir::piano::DBEntryLength], entry.data(), - pir::piano::DBEntryLength * sizeof(uint64_t)); + for (int i = 0; i < 8; ++i) { + constexpr uint64_t FNV_prime = 1099511628211ULL; + const auto byte = static_cast(key & 0xFF); + hash ^= static_cast(byte); + hash *= FNV_prime; + key >>= 8; + } + + std::vector hash_bytes(8); + for (size_t i = 0; i < 8; ++i) { + hash_bytes[i] = static_cast((hash >> (i * 8)) & 0xFF); + } + + return hash_bytes; +} + +std::vector CreateDatabase(const uint64_t entry_size, + const uint64_t entry_num, + const uint64_t db_seed) { + const auto [ChunkSize, SetSize] = pir::piano::GenParams(entry_num); + std::vector DB; + DB.assign(ChunkSize * SetSize * entry_size, 0); + + for (uint64_t i = 0; i < DB.size() / entry_size; ++i) { + auto entry = + pir::piano::DBEntry::GenDBEntry(entry_size, db_seed, i, FNVHash); + std::memcpy(&DB[i * entry_size], entry.data().data(), + entry_size * sizeof(uint8_t)); } return DB; @@ -44,18 +66,21 @@ std::vector CreateDatabase(const uint64_t db_size, void SetupAndRunServer( const std::shared_ptr& server_context, - const uint64_t db_size, std::promise& exit_signal, - std::vector& db) { - const auto [ChunkSize, SetSize] = pir::piano::GenParams(db_size); - pir::piano::QueryServiceServer server(db, server_context, SetSize, ChunkSize); + const uint64_t entry_size, const uint64_t entry_num, + std::promise& exit_signal, std::vector& db) { + const auto [ChunkSize, SetSize] = pir::piano::GenParams(entry_num); + pir::piano::QueryServiceServer server(db, server_context, SetSize, ChunkSize, + entry_size); server.Start(exit_signal.get_future()); } std::vector SetupAndRunClient( - const uint64_t db_size, const uint64_t thread_num, + const uint64_t entry_num, const uint64_t thread_num, + const uint64_t entry_size, const std::shared_ptr& client_context, const std::vector& queries) { - pir::piano::QueryServiceClient client(db_size, thread_num, client_context); + pir::piano::QueryServiceClient client(entry_num, thread_num, entry_size, + client_context); client.FetchFullDB(); return client.OnlineMultipleQueries(queries); } @@ -65,8 +90,9 @@ std::vector SetupAndRunClient( static void BM_PianoPir(benchmark::State& state) { for (auto _ : state) { state.PauseTiming(); - uint64_t db_size = state.range(0) / sizeof(pir::piano::DBEntry); - const uint64_t query_num = state.range(1); + const uint64_t entry_size = state.range(0); + const uint64_t entry_num = state.range(1) / entry_size / CHAR_BIT; + const uint64_t query_num = state.range(2); constexpr uint64_t db_seed = 2315127; uint64_t thread_num = 8; @@ -74,18 +100,18 @@ static void BM_PianoPir(benchmark::State& state) { const auto contexts = yacl::link::test::SetupWorld(kWorldSize); yacl::link::RecvTimeoutGuard guard(contexts[0], 1000000); - auto db = CreateDatabase(db_size, db_seed); - auto queries = GenerateQueries(query_num, db_size); + auto db = CreateDatabase(entry_size, entry_num, db_seed); + auto queries = GenerateQueries(query_num, entry_num); state.ResumeTiming(); std::promise exitSignal; auto server_future = - std::async(std::launch::async, SetupAndRunServer, contexts[0], db_size, - std::ref(exitSignal), std::ref(db)); + std::async(std::launch::async, SetupAndRunServer, contexts[0], + entry_size, entry_num, std::ref(exitSignal), std::ref(db)); auto client_future = - std::async(std::launch::async, SetupAndRunClient, db_size, thread_num, - contexts[1], std::cref(queries)); + std::async(std::launch::async, SetupAndRunClient, entry_num, thread_num, + entry_size, contexts[1], std::cref(queries)); auto results = client_future.get(); exitSignal.set_value(); @@ -93,10 +119,9 @@ static void BM_PianoPir(benchmark::State& state) { } } -// [1m, 16m, 64m, 128m] BENCHMARK(BM_PianoPir) ->Unit(benchmark::kMillisecond) - ->Args({1 << 20, 1000}) - ->Args({16 << 20, 1000}) - ->Args({64 << 20, 1000}) - ->Args({128 << 20, 1000}); + ->Args({4, 1 << 20, 1000}) + ->Args({4, 2 << 20, 1000}) + ->Args({8, 1 << 20, 1000}) + ->Args({8, 2 << 20, 1000}); diff --git a/experimental/pir/piano/piano_test.cc b/experimental/pir/piano/piano_test.cc index f57b4d26..50a72273 100644 --- a/experimental/pir/piano/piano_test.cc +++ b/experimental/pir/piano/piano_test.cc @@ -17,6 +17,7 @@ #include "yacl/link/test_util.h" struct TestParams { + uint64_t entry_size; uint64_t db_size; uint64_t db_seed; uint64_t thread_num; @@ -27,13 +28,13 @@ struct TestParams { namespace pir::piano { std::vector GenerateQueries(const uint64_t query_num, - const uint64_t db_size) { + const uint64_t entry_num) { std::vector queries; queries.reserve(query_num); - std::mt19937_64 rng(yacl::crypto::FastRandU64()); + yacl::crypto::Prg prg(yacl::crypto::SecureRandU64()); for (uint64_t q = 0; q < query_num; ++q) { - queries.push_back(rng() % db_size); + queries.push_back(prg() % entry_num); } return queries; @@ -45,13 +46,34 @@ std::vector RunClient(QueryServiceClient& client, return client.OnlineMultipleQueries(queries); } +std::vector FNVHash(uint64_t key) { + constexpr uint64_t FNV_offset_basis = 14695981039346656037ULL; + uint64_t hash = FNV_offset_basis; + + for (int i = 0; i < 8; ++i) { + constexpr uint64_t FNV_prime = 1099511628211ULL; + const auto byte = static_cast(key & 0xFF); + hash ^= static_cast(byte); + hash *= FNV_prime; + key >>= 8; + } + + std::vector hash_bytes(8); + for (size_t i = 0; i < 8; ++i) { + hash_bytes[i] = static_cast((hash >> (i * 8)) & 0xFF); + } + + return hash_bytes; +} + std::vector getResults(const std::vector& queries, const TestParams& params) { std::vector expected_results; expected_results.reserve(queries.size()); for (const auto& x : queries) { - expected_results.push_back(GenDBEntry(params.db_seed, x)); + expected_results.push_back( + DBEntry::GenDBEntry(params.entry_size, params.db_seed, x, FNVHash)); } return expected_results; @@ -62,33 +84,37 @@ class PianoTest : public testing::TestWithParam {}; TEST_P(PianoTest, Works) { auto params = GetParam(); constexpr int kWorldSize = 2; + uint64_t entry_num = params.db_size / params.entry_size / CHAR_BIT; const auto contexts = yacl::link::test::SetupWorld(kWorldSize); - SPDLOG_INFO("DB N: %lu, Entry Size %lu Bytes, DB Size %lu MB\n", - params.db_size, DBEntrySize, - params.db_size * DBEntrySize / 1024 / 1024); + SPDLOG_INFO("DB N: %lu, Entry Size %lu Bytes, DB Size %lu MB\n", entry_num, + params.entry_size, entry_num * params.entry_size / 1024 / 1024); - auto [ChunkSize, SetSize] = GenParams(params.db_size); + auto [ChunkSize, SetSize] = GenParams(entry_num); SPDLOG_INFO("Chunk Size: %lu, Set Size: %lu\n", ChunkSize, SetSize); - std::vector DB; - DB.assign(ChunkSize * SetSize * DBEntryLength, 0); + std::vector DB; + DB.assign(ChunkSize * SetSize * params.entry_size, 0); SPDLOG_INFO("DB Real N: %lu\n", DB.size()); - for (uint64_t i = 0; i < DB.size() / DBEntryLength; ++i) { - auto entry = GenDBEntry(params.db_seed, i); - std::memcpy(&DB[i * DBEntryLength], entry.data(), - DBEntryLength * sizeof(uint64_t)); + for (uint64_t i = 0; i < DB.size() / params.entry_size; ++i) { + auto entry = + DBEntry::GenDBEntry(params.entry_size, params.db_seed, i, FNVHash); + std::memcpy(&DB[i * params.entry_size], entry.data().data(), + params.entry_size * sizeof(uint8_t)); } - QueryServiceClient client(params.db_size, params.thread_num, contexts[1]); + QueryServiceClient client(entry_num, params.thread_num, params.entry_size, + contexts[1]); - const auto actual_query_num = - params.is_total_query_num ? client.totalQueryNum : params.query_num; - auto queries = GenerateQueries(actual_query_num, DB.size()); + const auto actual_query_num = params.is_total_query_num + ? client.getTotalQueryNumber() + : params.query_num; + const auto queries = GenerateQueries(actual_query_num, entry_num); yacl::link::RecvTimeoutGuard guard(contexts[0], 1000000); - QueryServiceServer server(DB, contexts[0], SetSize, ChunkSize); + QueryServiceServer server(DB, contexts[0], SetSize, ChunkSize, + params.entry_size); std::promise exitSignal; std::future futureObj = exitSignal.get_future(); @@ -98,11 +124,11 @@ TEST_P(PianoTest, Works) { auto client_future = std::async(std::launch::async, RunClient, std::ref(client), std::cref(queries)); - auto results = client_future.get(); - auto expected_results = getResults(queries, params); + const auto results = client_future.get(); + const auto expected_results = getResults(queries, params); for (size_t i = 0; i < results.size(); ++i) { - EXPECT_EQ(results[i], expected_results[i]) + EXPECT_EQ(results[i].data(), expected_results[i].data()) << "Mismatch at index " << queries[i]; } @@ -110,10 +136,10 @@ TEST_P(PianoTest, Works) { server_future.get(); } -// [8m, 128m, 256m] +// [8m, 128m, 256m] units are in bits INSTANTIATE_TEST_SUITE_P( PianoTestInstances, PianoTest, - ::testing::Values(TestParams{131072, 1211212, 8, 1000, false}, - TestParams{2097152, 6405285, 8, 1000, false}, - TestParams{4194304, 7539870, 16, 1000, false})); + ::testing::Values(TestParams{8, 8 << 20, 1211212, 8, 1000, false}, + TestParams{8, 128 << 20, 6405285, 8, 1000, false}, + TestParams{8, 256 << 20, 7539870, 16, 1000, false})); } // namespace pir::piano diff --git a/experimental/pir/piano/serialize.h b/experimental/pir/piano/serialize.h index 28a7a537..0e59c52d 100644 --- a/experimental/pir/piano/serialize.h +++ b/experimental/pir/piano/serialize.h @@ -30,23 +30,21 @@ inline uint64_t DeserializeFetchFullDBMsg(const yacl::Buffer& buf) { inline yacl::Buffer SerializeDBChunk(const uint64_t chunk_id, const uint64_t chunk_size, - const std::vector& chunk) { + const std::vector& chunk) { DbChunk proto; proto.set_chunk_id(chunk_id); proto.set_chunk_size(chunk_size); - for (const auto& val : chunk) { - proto.add_chunks(val); - } + proto.set_chunks(chunk.data(), chunk.size()); yacl::Buffer buf(proto.ByteSizeLong()); proto.SerializeToArray(buf.data(), buf.size()); return buf; } -inline std::tuple> DeserializeDBChunk( +inline std::tuple> DeserializeDBChunk( const yacl::Buffer& buf) { DbChunk proto; proto.ParseFromArray(buf.data(), buf.size()); - std::vector chunk(proto.chunks().begin(), proto.chunks().end()); + std::vector chunk(proto.chunks().begin(), proto.chunks().end()); return {proto.chunk_id(), proto.chunk_size(), chunk}; } @@ -76,22 +74,20 @@ inline std::pair> DeserializeSetParityQueryMsg( } inline yacl::Buffer SerializeSetParityQueryResponse( - const std::vector& parity, const uint64_t server_compute_time) { + const std::vector& parity, const uint64_t server_compute_time) { SetParityQueryResponse proto; - for (const auto& p : parity) { - proto.add_parity(p); - } + proto.set_parity(parity.data(), parity.size()); proto.set_server_compute_time(server_compute_time); yacl::Buffer buf(proto.ByteSizeLong()); proto.SerializeToArray(buf.data(), buf.size()); return buf; } -inline std::pair, uint64_t> +inline std::pair, uint64_t> DeserializeSetParityQueryResponse(const yacl::Buffer& buf) { SetParityQueryResponse proto; proto.ParseFromArray(buf.data(), buf.size()); - std::vector parity(proto.parity().begin(), proto.parity().end()); + std::vector parity(proto.parity().begin(), proto.parity().end()); return {parity, proto.server_compute_time()}; } diff --git a/experimental/pir/piano/server.cc b/experimental/pir/piano/server.cc index 9d2516cf..3bfc7e1a 100644 --- a/experimental/pir/piano/server.cc +++ b/experimental/pir/piano/server.cc @@ -3,12 +3,14 @@ namespace pir::piano { QueryServiceServer::QueryServiceServer( - std::vector& db, std::shared_ptr context, - const uint64_t set_size, const uint64_t chunk_size) + std::vector& db, std::shared_ptr context, + const uint64_t set_size, const uint64_t chunk_size, + const uint64_t entry_size) : db_(std::move(db)), context_(std::move(context)), set_size_(set_size), - chunk_size_(chunk_size) {} + chunk_size_(chunk_size), + entry_size_(entry_size) {} void QueryServiceServer::Start(const std::future& stop_signal) { while (stop_signal.wait_for(std::chrono::milliseconds(1)) == @@ -48,9 +50,9 @@ void QueryServiceServer::ProcessFetchFullDB() { for (uint64_t i = 0; i < set_size_; ++i) { const uint64_t down = i * chunk_size_; uint64_t up = (i + 1) * chunk_size_; - up = std::min(up, static_cast(db_.size())); - std::vector chunk(db_.begin() + down * DBEntryLength, - db_.begin() + up * DBEntryLength); + up = std::min(up, db_.size() / entry_size_); + std::vector chunk(db_.begin() + down * entry_size_, + db_.begin() + up * entry_size_); auto chunk_buf = SerializeDBChunk(i, chunk.size(), chunk); try { @@ -62,43 +64,36 @@ void QueryServiceServer::ProcessFetchFullDB() { } } -std::pair, uint64_t> +std::pair, uint64_t> QueryServiceServer::ProcessSetParityQuery( const std::vector& indices) { const auto start = std::chrono::high_resolution_clock::now(); - std::vector parity = HandleSetParityQuery(indices); + std::vector parity = HandleSetParityQuery(indices); const auto end = std::chrono::high_resolution_clock::now(); const auto duration = std::chrono::duration_cast(end - start).count(); return {parity, duration}; } -DBEntry QueryServiceServer::DBAccess(const uint64_t id) { - if (id < db_.size()) { - if (id * DBEntryLength + DBEntryLength > db_.size()) { - SPDLOG_ERROR("DBAccess: id {} out of range", id); - } - std::array slice{}; - std::copy(db_.begin() + id * DBEntryLength, - db_.begin() + (id + 1) * DBEntryLength, slice.begin()); - return DBEntryFromSlice(slice); - } - DBEntry ret; - ret.fill(0); - return ret; -} - -std::vector QueryServiceServer::HandleSetParityQuery( +std::vector QueryServiceServer::HandleSetParityQuery( const std::vector& indices) { - DBEntry parity = ZeroEntry(); + DBEntry parity = DBEntry::ZeroEntry(entry_size_); for (const auto& index : indices) { DBEntry entry = DBAccess(index); - DBEntryXor(&parity, &entry); + parity.Xor(entry); } + return parity.data(); +} - std::vector ret(DBEntryLength); - std::copy(parity.begin(), parity.end(), ret.begin()); - return ret; +DBEntry QueryServiceServer::DBAccess(const uint64_t id) { + if (const size_t num_entries = db_.size() / entry_size_; id < num_entries) { + std::vector slice(entry_size_); + std::copy(db_.begin() + id * entry_size_, + db_.begin() + (id + 1) * entry_size_, slice.begin()); + return DBEntry::DBEntryFromSlice(slice); + } + SPDLOG_ERROR("DBAccess: id {} out of range", id); + return DBEntry::ZeroEntry(entry_size_); } } // namespace pir::piano diff --git a/experimental/pir/piano/server.h b/experimental/pir/piano/server.h index 04650afa..40096215 100644 --- a/experimental/pir/piano/server.h +++ b/experimental/pir/piano/server.h @@ -2,7 +2,6 @@ #include -#include #include #include #include @@ -19,9 +18,10 @@ class QueryServiceServer { public: // Constructor: initializes the server with a database, context, set_size, and // chunk_size - QueryServiceServer(std::vector& db, + QueryServiceServer(std::vector& db, std::shared_ptr context, - uint64_t set_size, uint64_t chunk_size); + uint64_t set_size, uint64_t chunk_size, + uint64_t entry_size); // Starts the server to handle incoming requests void Start(const std::future& stop_signal); @@ -33,7 +33,7 @@ class QueryServiceServer { void ProcessFetchFullDB(); // Processes a set parity query and returns the parity and server compute time - std::pair, uint64_t> ProcessSetParityQuery( + std::pair, uint64_t> ProcessSetParityQuery( const std::vector& indices); private: @@ -41,13 +41,14 @@ class QueryServiceServer { DBEntry DBAccess(uint64_t id); // Handles a set parity query and returns the parity - std::vector HandleSetParityQuery( + std::vector HandleSetParityQuery( const std::vector& indices); - std::vector db_; // The database + std::vector db_; // The database std::shared_ptr context_; // The communication context uint64_t set_size_; // The size of the set uint64_t chunk_size_; // The size of each chunk + uint64_t entry_size_; // The size of database entry }; } // namespace pir::piano diff --git a/experimental/pir/piano/util.cc b/experimental/pir/piano/util.cc index cab249db..42e955f3 100644 --- a/experimental/pir/piano/util.cc +++ b/experimental/pir/piano/util.cc @@ -2,33 +2,11 @@ namespace pir::piano { -uint128_t BytesToUint128(const std::string& bytes) { - if (bytes.size() != 16) { - SPDLOG_WARN("Bytes size must be 16 for uint128_t conversion."); - } - - uint128_t result = 0; - std::memcpy(&result, bytes.data(), 16); - return result; -} - -std::string Uint128ToBytes(const uint128_t value) { - std::string bytes(16, 0); - std::memcpy(bytes.data(), &value, 16); - return bytes; -} +uint128_t SecureRandKey() { return yacl::crypto::SecureRandU128(); } -PrfKey128 RandKey128(std::mt19937_64& rng) { - const uint64_t lo = rng(); - const uint64_t hi = rng(); - return yacl::MakeUint128(hi, lo); -} - -PrfKey RandKey(std::mt19937_64& rng) { return RandKey128(rng); } - -uint64_t PRFEval128(const PrfKey128* key, const uint64_t x) { +uint64_t PRFEval(const uint128_t key, const uint64_t x) { yacl::crypto::AES_KEY aes_key; - AES_set_encrypt_key(*key, &aes_key); + AES_set_encrypt_key(key, &aes_key); const auto src_block = static_cast(x); std::vector plain_blocks(1); @@ -40,79 +18,9 @@ uint64_t PRFEval128(const PrfKey128* key, const uint64_t x) { return static_cast(cipher_blocks[0]); } -uint64_t PRFEval(const PrfKey* key, const uint64_t x) { - return PRFEval128(key, x); -} - -void DBEntryXor(DBEntry* dst, const DBEntry* src) { - for (size_t i = 0; i < DBEntryLength; ++i) { - (*dst)[i] ^= (*src)[i]; - } -} - -void DBEntryXorFromRaw(DBEntry* dst, const uint64_t* src) { - for (size_t i = 0; i < DBEntryLength; ++i) { - (*dst)[i] ^= src[i]; - } -} - -bool EntryIsEqual(const DBEntry& a, const DBEntry& b) { - for (size_t i = 0; i < DBEntryLength; ++i) { - if (a[i] != b[i]) { - return false; - } - } - return true; -} - -DBEntry RandDBEntry(std::mt19937_64& rng) { - DBEntry entry; - for (size_t i = 0; i < DBEntryLength; ++i) { - entry[i] = rng(); - } - return entry; -} - -uint64_t DefaultHash(uint64_t key) { - constexpr uint64_t FNV_offset_basis = 14695981039346656037ULL; - uint64_t hash = FNV_offset_basis; - for (int i = 0; i < 8; ++i) { - constexpr uint64_t FNV_prime = 1099511628211ULL; - const auto byte = static_cast(key & 0xFF); - hash ^= static_cast(byte); - hash *= FNV_prime; - key >>= 8; - } - return hash; -} - -DBEntry GenDBEntry(const uint64_t key, const uint64_t id) { - DBEntry entry; - for (size_t i = 0; i < DBEntryLength; ++i) { - entry[i] = DefaultHash((key ^ id) + i); - } - return entry; -} - -DBEntry ZeroEntry() { - DBEntry entry = {}; - for (size_t i = 0; i < DBEntryLength; ++i) { - entry[i] = 0; - } - return entry; -} - -DBEntry DBEntryFromSlice(const std::array& s) { - DBEntry entry; - for (size_t i = 0; i < DBEntryLength; ++i) { - entry[i] = s[i]; - } - return entry; -} - // Generate ChunkSize and SetSize -std::pair GenParams(const uint64_t db_size) { - const double targetChunkSize = 2 * std::sqrt(static_cast(db_size)); +std::pair GenParams(const uint64_t entry_num) { + const double targetChunkSize = 2 * std::sqrt(static_cast(entry_num)); uint64_t ChunkSize = 1; // Ensure ChunkSize is a power of 2 and not smaller than targetChunkSize @@ -120,16 +28,16 @@ std::pair GenParams(const uint64_t db_size) { ChunkSize *= 2; } - uint64_t SetSize = (db_size + ChunkSize - 1) / ChunkSize; + uint64_t SetSize = (entry_num + ChunkSize - 1) / ChunkSize; // Round up to the next multiple of 4 SetSize = (SetSize + 3) / 4 * 4; return {ChunkSize, SetSize}; } -yacl::crypto::AES_KEY GetLongKey(const PrfKey128* key) { +yacl::crypto::AES_KEY GetLongKey(const uint128_t key) { yacl::crypto::AES_KEY aes_key; - AES_set_encrypt_key(*key, &aes_key); + AES_set_encrypt_key(key, &aes_key); return aes_key; } diff --git a/experimental/pir/piano/util.h b/experimental/pir/piano/util.h index 48406423..ba13813d 100644 --- a/experimental/pir/piano/util.h +++ b/experimental/pir/piano/util.h @@ -2,75 +2,93 @@ #include -#include #include -#include +#include #include #include +#include "yacl/base/int128.h" #include "yacl/crypto/aes/aes_intrinsics.h" +#include "yacl/crypto/rand/rand.h" namespace pir::piano { -constexpr size_t DBEntrySize = 8; // has to be a multiple of 8 -constexpr size_t DBEntryLength = DBEntrySize / 8; - -using PrfKey128 = uint128_t; -using DBEntry = std::array; -using PrfKey = PrfKey128; - -uint128_t BytesToUint128(const std::string& bytes); - -std::string Uint128ToBytes(uint128_t value); - -// Generates a random 128-bit key using the provided RNG -PrfKey128 RandKey128(std::mt19937_64& rng); - -// Generates a random PRF key -PrfKey RandKey(std::mt19937_64& rng); - -// Evaluates PRF using 128-bit key and returns a 64-bit result -uint64_t PRFEval128(const PrfKey128* key, uint64_t x); - -// Evaluates PRF using a general PrfKey and returns a 64-bit result -uint64_t PRFEval(const PrfKey* key, uint64_t x); - -// XOR two DBEntry structures -void DBEntryXor(DBEntry* dst, const DBEntry* src); - -// XOR a DBEntry with raw uint64_t data -void DBEntryXorFromRaw(DBEntry* dst, const uint64_t* src); - -// Compare two DBEntry structures for equality -bool EntryIsEqual(const DBEntry& a, const DBEntry& b); - -// Generate a random DBEntry using the provided RNG -DBEntry RandDBEntry(std::mt19937_64& rng); - -// Default FNV hash implementation for 64-bit keys -uint64_t DefaultHash(uint64_t key); - -// Generate a DBEntry based on a key and ID -DBEntry GenDBEntry(uint64_t key, uint64_t id); +class DBEntry { + public: + DBEntry() = default; + + // entry_size represents the number of bytes in the DBEntry + explicit DBEntry(const size_t entry_size) + : k_length_(entry_size), data_(entry_size, 0) {} + + explicit DBEntry(const std::vector& data) + : k_length_(data.size()), data_(data) {} + + // Accessor for the underlying data + std::vector& data() { return data_; } + [[nodiscard]] const std::vector& data() const { return data_; } + + // XOR operations + void Xor(const DBEntry& other) { + for (size_t i = 0; i < k_length_; ++i) { + data_[i] ^= other.data_[i]; + } + } + + void XorFromRaw(const uint8_t* src) { + for (size_t i = 0; i < k_length_; ++i) { + data_[i] ^= src[i]; + } + } + + // Static method to generate a zero-filled DBEntry + static DBEntry ZeroEntry(const size_t entry_size) { + return DBEntry(entry_size); + } + + // Generate a DBEntry based on a key and ID using a custom hash function + static DBEntry GenDBEntry( + const size_t entry_size, const uint64_t key, const uint64_t id, + const std::function(uint64_t)>& hash_func) { + DBEntry entry(entry_size); + const std::vector hash = hash_func(key ^ id); + for (size_t i = 0; i < entry_size; ++i) { + if (i < hash.size()) { + entry.data_[i] = hash[i]; + } else { + entry.data_[i] = 0; + } + } + return entry; + } + + // Convert a slice (vector) into a DBEntry structure + static DBEntry DBEntryFromSlice(const std::vector& s) { + return DBEntry(s); + } + + private: + size_t k_length_{}; + std::vector data_; +}; -// Generate a zero-filled DBEntry -DBEntry ZeroEntry(); +// Generate secure master key +uint128_t SecureRandKey(); -// Convert a slice (array) into a DBEntry structure -DBEntry DBEntryFromSlice(const std::array& s); +// Evaluates PRF using a 128-bit key and returns a 64-bit result +uint64_t PRFEval(uint128_t key, uint64_t x); // Generate parameters for ChunkSize and SetSize -std::pair GenParams(uint64_t db_size); +std::pair GenParams(uint64_t entry_num); // Returns a long key (AES expanded key) for PRF evaluation -yacl::crypto::AES_KEY GetLongKey(const PrfKey128* key); +yacl::crypto::AES_KEY GetLongKey(uint128_t key); // PRF evaluation with a long key and tag, returns a 64-bit result uint64_t PRFEvalWithLongKeyAndTag(const yacl::crypto::AES_KEY& long_key, uint32_t tag, uint64_t x); -class PRSetWithShortTag { - public: +struct PRSetWithShortTag { uint32_t Tag; // Expands the set with a long key and tag From 53970100fbcf31bbd8bfa68f4c456beb2dae64ac Mon Sep 17 00:00:00 2001 From: cxiao129 Date: Thu, 5 Dec 2024 10:56:48 +0800 Subject: [PATCH 08/11] Format piano.proto with clang-format --- experimental/pir/piano/piano.proto | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/experimental/pir/piano/piano.proto b/experimental/pir/piano/piano.proto index c7603f9b..058de506 100644 --- a/experimental/pir/piano/piano.proto +++ b/experimental/pir/piano/piano.proto @@ -3,28 +3,28 @@ syntax = "proto3"; package pir.piano; message FetchFullDbMsg { - uint64 dummy = 1; + uint64 dummy = 1; } message DbChunk { - uint64 chunk_id = 1; - uint64 chunk_size = 2; - bytes chunks = 3; + uint64 chunk_id = 1; + uint64 chunk_size = 2; + bytes chunks = 3; } message SetParityQueryMsg { - uint64 set_size = 1; - repeated uint64 indices = 2; + uint64 set_size = 1; + repeated uint64 indices = 2; } message SetParityQueryResponse { - bytes parity = 1; - uint64 server_compute_time = 2; + bytes parity = 1; + uint64 server_compute_time = 2; } message QueryRequest { - oneof request { - FetchFullDbMsg fetch_full_db = 1; - SetParityQueryMsg set_parity_query = 2; - } + oneof request { + FetchFullDbMsg fetch_full_db = 1; + SetParityQueryMsg set_parity_query = 2; + } } From 9f900cb563ec61f295419bf70887d76454fbcd35 Mon Sep 17 00:00:00 2001 From: cxiao129 Date: Wed, 18 Dec 2024 18:04:01 +0800 Subject: [PATCH 09/11] Unified naming conventions, added meaningful comments, and verified code correctness with test cases --- experimental/pir/piano/README.md | 69 +++-- experimental/pir/piano/client.cc | 326 +++++++++++----------- experimental/pir/piano/client.h | 79 ++++-- experimental/pir/piano/piano.proto | 23 +- experimental/pir/piano/piano_benchmark.cc | 129 ++++----- experimental/pir/piano/piano_test.cc | 125 ++++----- experimental/pir/piano/serialize.h | 69 ++--- experimental/pir/piano/server.cc | 104 +++---- experimental/pir/piano/server.h | 56 ++-- experimental/pir/piano/util.cc | 73 ++--- experimental/pir/piano/util.h | 64 +++-- 11 files changed, 566 insertions(+), 551 deletions(-) diff --git a/experimental/pir/piano/README.md b/experimental/pir/piano/README.md index 65cc6d99..ea315a8c 100644 --- a/experimental/pir/piano/README.md +++ b/experimental/pir/piano/README.md @@ -1,28 +1,59 @@ -Piano: Extremely Simple, Single-server PIR with Sublinear Server Computation +# Piano PIR: Single-Server Private Information Retrieval with Sublinear Computation -论文地址:https://eprint.iacr.org/2023/452 +## Scheme Parameters and Notation -论文开源实现:https://github.com/wuwuz/Piano-PIR-new +- $\kappa$: Statistical security parameter +- $\lambda$: Computational security parameter +- $\alpha(\kappa)$: Arbitrarily small super-constant function +- $n$: Database size +- $Q = \sqrt{n} \log \kappa \cdot \alpha(\kappa)$: Total number of queries -**方案概括** +## Preprocessing Phase -1. 采用特定于客户端的预处理模型(也称为订阅模型),让每个客户端在预处理期间下载并存储来自服务器的“提示” +### Client-Side Initialization -2. 实现了O(√n)的客户端存储,和O(√n)的在线通信与计算开销(平均到每个查询上) -3. 服务器是诚实且好奇的,恶意服务器也无法侵害隐私,但会导致查询结果出错 -4. 方案包括两个阶段:预处理阶段和在线查询阶段 +1. **Primary Table Generation** + - Sample $M_1 = \sqrt{n} \log \kappa \cdot \alpha(\kappa)$ PRF keys, $\{sk_1, \ldots, sk_{M_1}\} \in \{0,1\}^{\lambda}$ + - Initialize parities $\{p_1, \ldots, p_{M_1}\}$ to zeros -**预处理阶段** +2. **Backup Table Generation** + - For each chunk $j \in \{0, 1, \ldots, \sqrt{n} - 1\}$: + - Sample $M_2 = \log \kappa \cdot \alpha(\kappa)$ PRF keys $\{sk_{j,1}, \ldots, sk_{j,M_2}\}$ + - Initialize chunk-specific parities $\{p_{j,1}, \ldots, p_{j,M_2}\}$ to zeros -1. 服务器将数据库划分为O(√n)个块,客户端**流式**的从服务器获取每块数据,并每次只处理当前块中的元素,包括记录部分数据和计算奇偶校验位 -2. 客户端要存储的“提示”包括三类:主表,替换条目和备份表,主表共有O(√n)个,替换条目和备份表在每个块上要存储 - O(1)个 +### Streaming Database Preprocessing -**在线查询阶段** -1. 客户端查询包含x的主表,将主表中的x替换为替换条目中该块的下一个可用元素,发送序列给服务器。服务器计算序列的奇偶校验位并返回,客户端在本地通过异或操作恢复出DB[x] -2. 主表使用一次后就要丢弃,使用备份表进行替换,同时为了负载均衡,要保留(x,DB[x]) +For each database chunk $DB[j \sqrt{n} : (j+1) \sqrt{n}]$: +- Update primary table parity: for $i \in [M_1]$, $p_i \leftarrow p_i \oplus DB[\text{Set}(sk_i)[j]]$ +- Store replacement entries: sample $M_2$ tuples $(r, DB[r])$ where $r$ is a random index from the current chunk +- Update backup table parity: for $i \in \{0, 1, \ldots, \sqrt{n} - 1\} \setminus \{j\}$ and $k \in [M_2]$, $p_{i,k} \leftarrow p_{i,k} \oplus DB[\text{Set}(sk_{i,k})[j]]$ +- Delete current chunk from local storage -**具体实现** -1. 客户端不存储完整的序列,而是只存储tag,通过msk和PRF扩展出完整的序列 -2. 每个块都有O(1)个备份表,备份表中该块对应的DB[x]没有参与计算奇偶校验位,以实现快速替换 -3. 本地保存查询记录,当有重复查询时在本地查询,并发送随机序列给服务器 +## Online Query Phase + +Query Protocol for Index $x \in \{0, 1, \ldots, n-1\}$ + +1. **Query Execution** + - Find primary table hint $T_i = ((sk_i, x^{\prime}), p_i)$ where $x \in \text{Set}(sk_i, x^{\prime})$ + - Locate chunk $j^* = \text{chunk}(x)$ + - Find first unused replacement entry $(r, DB[r])$ + - Send set $S' = S \setminus \{j^* \to r\}$ to server + - Server returns $q = \bigoplus_{k \in S'} DB[k]$ + - Compute answer $\beta = q \oplus p_i \oplus DB[r]$ + +2. **Table Refresh Mechanism** + - Locate next unused backup entry $(sk_{j^*,k}, p_{j^*,k})$ + - If no entry exists, generate random $sk_{j^*,k}$ with $p_{j^*,k} = 0$ + - Update primary table with new entry: $((sk_{j^*,k}, x), p_{j^*,k} \oplus \beta)$ + +## Theoretical Guarantees + +- **Client Storage**: $O(\sqrt{n})$ +- **Server Computation**: $O(\sqrt{n})$ +- **Communication Overhead**: $O(\sqrt{n})$ +- **Query Complexity**: $O(\sqrt{n})$ + +## References + +- **Paper**: [Piano PIR](https://eprint.iacr.org/2023/452) +- **Implementation**: [GitHub Repository](https://github.com/wuwuz/Piano-PIR-new) diff --git a/experimental/pir/piano/client.cc b/experimental/pir/piano/client.cc index d542d9e7..65f20c74 100644 --- a/experimental/pir/piano/client.cc +++ b/experimental/pir/piano/client.cc @@ -2,56 +2,48 @@ namespace pir::piano { -uint64_t primaryNumParam(const double q, const double chunk_size, - const double target) { - const double k = std::ceil((std::log(2) * target) + std::log(q)); - return static_cast(k) * static_cast(chunk_size); -} - -double FailProbBallIntoBins(const uint64_t ball_num, const uint64_t bin_num, - const uint64_t bin_size) { - const double mean = - static_cast(ball_num) / static_cast(bin_num); - const double c = (static_cast(bin_size) / mean) - 1; - // Chernoff bound exp(-(c^2)/(2+c) * mean) - double t = (mean * (c * c) / (2 + c)) * std::log(2); - t -= std::log2(static_cast(bin_num)); - return t; -} - QueryServiceClient::QueryServiceClient( - const uint64_t entry_num, const uint64_t thread_num, - const uint64_t entry_size, std::shared_ptr context) - : entry_num_(entry_num), + std::shared_ptr context, const uint64_t entry_num, + const uint64_t thread_num, const uint64_t entry_size) + : context_(std::move(context)), + entry_num_(entry_num), thread_num_(thread_num), - context_(std::move(context)), entry_size_(entry_size) { Initialize(); InitializeLocalSets(); } void QueryServiceClient::Initialize() { + // Set the computational security parameter to 128 master_key_ = SecureRandKey(); long_key_ = GetLongKey(master_key_); - // Q = sqrt(n) * ln(n) + // Q = sqrt(n) * log(k) * α(κ) + // Maximum number of queries supported by a single preprocessing + // Let α(κ) be any super-constant function, i.e., α(κ) = w(1) + // Chosen log(log(κ)): grows slowly but surely > any constant as κ → ∞ total_query_num_ = static_cast(std::sqrt(static_cast(entry_num_)) * - std::log(static_cast(entry_num_))); + natural_log_k_ * std::log(natural_log_k_)); - std::tie(chunk_size_, set_size_) = GenParams(entry_num_); + std::tie(chunk_size_, set_size_) = GenChunkParams(entry_num_); - primary_set_num_ = - primaryNumParam(static_cast(total_query_num_), - static_cast(chunk_size_), kFailureProbLog2 + 1); - // if localSetNum is not a multiple of thread_num_ then we need to add some - // padding + // M1 = sqrt(n) * log(k) * α(κ) + // The probability that the client cannot find a set that contains the online + // query index is negligible in κ + primary_set_num_ = total_query_num_; + + // if primary_set_num_ is not a multiple of thread_num_ then we need to add + // some padding primary_set_num_ = (primary_set_num_ + thread_num_ - 1) / thread_num_ * thread_num_; - backup_set_num_per_chunk_ = - 3 * static_cast(static_cast(total_query_num_) / - static_cast(set_size_)); + // M2 = log2(k) * log(k) * α(κ) + // The probability that the client runs out of hints in a backup group is + // negligible in κ + backup_set_num_per_chunk_ = static_cast( + static_cast(log2_k_) * natural_log_k_ * std::log(natural_log_k_)); + backup_set_num_per_chunk_ = (backup_set_num_per_chunk_ + thread_num_ - 1) / thread_num_ * thread_num_; @@ -66,20 +58,20 @@ void QueryServiceClient::InitializeLocalSets() { local_backup_sets_.reserve(total_backup_set_num_); local_cache_.clear(); local_miss_elements_.clear(); - uint32_t tagCounter = 0; + uint32_t tag_counter = 0; // Initialize primary_sets_ for (uint64_t j = 0; j < primary_set_num_; j++) { - primary_sets_.emplace_back(tagCounter, DBEntry::ZeroEntry(entry_size_), 0, + primary_sets_.emplace_back(tag_counter, DBEntry::ZeroEntry(entry_size_), 0, false); - tagCounter += 1; + tag_counter += 1; } // Initialize local_backup_sets_ for (uint64_t i = 0; i < total_backup_set_num_; ++i) { - local_backup_sets_.emplace_back(tagCounter, + local_backup_sets_.emplace_back(tag_counter, DBEntry::ZeroEntry(entry_size_)); - tagCounter += 1; + tag_counter += 1; } local_backup_set_groups_.clear(); @@ -89,75 +81,71 @@ void QueryServiceClient::InitializeLocalSets() { // Initialize local_backup_set_groups_ and local_replacement_groups_ for (uint64_t i = 0; i < set_size_; i++) { - std::vector> backupSets; + std::vector> backup_sets; for (uint64_t j = 0; j < backup_set_num_per_chunk_; j++) { - backupSets.emplace_back( + backup_sets.emplace_back( local_backup_sets_[(i * backup_set_num_per_chunk_) + j]); } - LocalBackupSetGroup backupGroup(0, backupSets); - local_backup_set_groups_.emplace_back(std::move(backupGroup)); + LocalBackupSetGroup backup_group(0, backup_sets); + local_backup_set_groups_.emplace_back(std::move(backup_group)); std::vector indices(backup_set_num_per_chunk_); std::vector values(backup_set_num_per_chunk_); - LocalReplacementGroup replacementGroup(0, indices, values); - local_replacement_groups_.emplace_back(std::move(replacementGroup)); + LocalReplacementGroup replacement_group(0, indices, values); + local_replacement_groups_.emplace_back(std::move(replacement_group)); } } void QueryServiceClient::FetchFullDB() { - const auto fetchFullDBMsg = SerializeFetchFullDBMsg(1); - context_->SendAsync(context_->NextRank(), fetchFullDBMsg, "FetchFullDBMsg"); + context_->SendAsync(context_->NextRank(), SerializeFetchFullDB(1), + "FetchFullDB"); for (uint64_t i = 0; i < set_size_; i++) { - auto chunkBuf = context_->Recv(context_->NextRank(), "DBChunk"); - if (chunkBuf.size() == 0) { - break; - } - auto dbChunk = DeserializeDBChunk(chunkBuf); - auto& chunk = std::get<2>(dbChunk); + auto db_chunk = + DeserializeDBChunk(context_->Recv(context_->NextRank(), "DBChunk")); - std::vector hitMap(chunk_size_, false); + std::vector hit_map(chunk_size_, false); // Use multiple threads to parallelize the computation for the chunk std::vector threads; - std::mutex hitMapMutex; + std::mutex hit_map_mutex; - // make sure all sets are covered - const uint64_t perTheadSetNum = + // Make sure all sets are covered + const uint64_t primary_set_per_thread = ((primary_set_num_ + thread_num_ - 1) / thread_num_) + 1; - const uint64_t perThreadBackupNum = + const uint64_t backup_set_per_thread = ((total_backup_set_num_ + thread_num_ - 1) / thread_num_) + 1; for (uint64_t tid = 0; tid < thread_num_; tid++) { - uint64_t startIndex = tid * perTheadSetNum; - uint64_t endIndex = - std::min(startIndex + perTheadSetNum, primary_set_num_); - - uint64_t startIndexBackup = tid * perThreadBackupNum; - uint64_t endIndexBackup = std::min(startIndexBackup + perThreadBackupNum, - total_backup_set_num_); - - threads.emplace_back([&, startIndex, endIndex, startIndexBackup, - endIndexBackup] { - // update the parities for the primary hints - for (uint64_t j = startIndex; j < endIndex; j++) { + uint64_t start_index = tid * primary_set_per_thread; + uint64_t end_index = + std::min(start_index + primary_set_per_thread, primary_set_num_); + + uint64_t start_index_backup = tid * backup_set_per_thread; + uint64_t end_index_backup = std::min( + start_index_backup + backup_set_per_thread, total_backup_set_num_); + + threads.emplace_back([&, start_index, end_index, start_index_backup, + end_index_backup] { + // Update the parities for the primary hints + for (uint64_t j = start_index; j < end_index; j++) { const auto tmp = PRFEvalWithLongKeyAndTag(long_key_, primary_sets_[j].tag, i); const auto offset = tmp & (chunk_size_ - 1); { - std::lock_guard lock(hitMapMutex); - hitMap[offset] = true; + std::lock_guard lock(hit_map_mutex); + hit_map[offset] = true; } - primary_sets_[j].parity.XorFromRaw(&chunk[offset * entry_size_]); + primary_sets_[j].parity.XorFromRaw(&db_chunk[offset * entry_size_]); } - // update the parities for the backup hints - for (uint64_t j = startIndexBackup; j < endIndexBackup; j++) { + // Update the parities for the backup hints + for (uint64_t j = start_index_backup; j < end_index_backup; j++) { const auto tmp = PRFEvalWithLongKeyAndTag(long_key_, local_backup_sets_[j].tag, i); const auto offset = tmp & (chunk_size_ - 1); - local_backup_sets_[j].parityAfterPuncture.XorFromRaw( - &chunk[offset * entry_size_]); + local_backup_sets_[j].parity_after_puncture.XorFromRaw( + &db_chunk[offset * entry_size_]); } }); } @@ -172,9 +160,9 @@ void QueryServiceClient::FetchFullDB() { // the local miss cache. Most of the time, the local miss cache will be // empty. for (uint64_t j = 0; j < chunk_size_; j++) { - if (!hitMap[j]) { + if (!hit_map[j]) { std::vector entry_slice(entry_size_); - std::memcpy(entry_slice.data(), &chunk[j * entry_size_], + std::memcpy(entry_slice.data(), &db_chunk[j * entry_size_], entry_size_ * sizeof(uint8_t)); const auto entry = DBEntry::DBEntryFromSlice(entry_slice); local_miss_elements_[j + (i * chunk_size_)] = entry; @@ -182,48 +170,51 @@ void QueryServiceClient::FetchFullDB() { } // For the i-th group of backups, leave the i-th chunk as blank - // To do that, we just xor the i-th chunk's value again + // To do that, we just XOR the i-th chunk's value again for (uint64_t k = 0; k < backup_set_num_per_chunk_; k++) { const auto tag = local_backup_set_groups_[i].sets[k].get().tag; const auto tmp = PRFEvalWithLongKeyAndTag(long_key_, tag, i); const auto offset = tmp & (chunk_size_ - 1); - local_backup_set_groups_[i].sets[k].get().parityAfterPuncture.XorFromRaw( - &chunk[offset * entry_size_]); + local_backup_set_groups_[i] + .sets[k] + .get() + .parity_after_puncture.XorFromRaw(&db_chunk[offset * entry_size_]); } - // store the replacement + // Store the replacement yacl::crypto::Prg prg(yacl::crypto::SecureRandU64()); for (uint64_t k = 0; k < backup_set_num_per_chunk_; k++) { - // generate a random offset between 0 and ChunkSize - 1 + // Generate a random offset between 0 and chunk_size_ - 1 const auto offset = prg() & (chunk_size_ - 1); local_replacement_groups_[i].indices[k] = offset + i * chunk_size_; std::vector entry_slice(entry_size_); - std::memcpy(entry_slice.data(), &chunk[offset * entry_size_], + std::memcpy(entry_slice.data(), &db_chunk[offset * entry_size_], entry_size_ * sizeof(uint8_t)); - local_replacement_groups_[i].value[k] = + local_replacement_groups_[i].values[k] = DBEntry::DBEntryFromSlice(entry_slice); } } } +// Store results of sqrt(n) recent queries, serve duplicates locally while +// masking with a random distinct query void QueryServiceClient::SendDummySet() const { yacl::crypto::Prg prg(yacl::crypto::SecureRandU64()); - std::vector randSet(set_size_); + std::vector rand_set(set_size_); for (uint64_t i = 0; i < set_size_; i++) { - randSet[i] = prg() % chunk_size_ + i * chunk_size_; + rand_set[i] = prg() % chunk_size_ + i * chunk_size_; } - // send the random dummy set to the server - const auto query_msg = SerializeSetParityQueryMsg(set_size_, randSet); - context_->SendAsync(context_->NextRank(), query_msg, "SetParityQueryMsg"); + // Send the random dummy set to the server + context_->SendAsync(context_->NextRank(), SerializeSetParityQuery(rand_set), + "SetParityQuery"); - const auto response_buf = - context_->Recv(context_->NextRank(), "SetParityQueryResponse"); - // auto parityQueryResponse = DeserializeSetParityQueryResponse(response_buf); + auto parity_query_response = DeserializeSetParityResponse( + context_->Recv(context_->NextRank(), "SetParityResponse")); } DBEntry QueryServiceClient::OnlineSingleQuery(const uint64_t x) { - // make sure x is not in the local cache + // Make sure x is not in the local cache if (local_cache_.find(x) != local_cache_.end()) { SendDummySet(); return local_cache_[x]; @@ -234,105 +225,104 @@ DBEntry QueryServiceClient::OnlineSingleQuery(const uint64_t x) { // replacement // 3. The client sends the edited set to the server and gets the parity // 4. The client recovers the answer - uint64_t hitSetId = std::numeric_limits::max(); + uint64_t hit_set_id = std::numeric_limits::max(); - const uint64_t queryOffset = x % chunk_size_; - const uint64_t chunkId = x / chunk_size_; + const uint64_t query_offset = x % chunk_size_; + const uint64_t chunk_id = x / chunk_size_; for (uint64_t i = 0; i < primary_set_num_; i++) { const auto& set = primary_sets_[i]; - if (const bool isProgrammedMatch = - set.isProgrammed && chunkId == (set.programmedPoint / chunk_size_); - !isProgrammedMatch && - PRSetWithShortTag{set.tag}.MemberTestWithLongKeyAndTag( - long_key_, chunkId, queryOffset, chunk_size_)) { - hitSetId = i; + if (const bool is_programmed_match = + set.is_programmed && + chunk_id == (set.programmed_point / chunk_size_); + !is_programmed_match && + PRFSetWithShortTag{set.tag}.MemberTestWithLongKey( + long_key_, chunk_id, query_offset, chunk_size_)) { + hit_set_id = i; break; } } - DBEntry xVal = DBEntry::ZeroEntry(entry_size_); + DBEntry val = DBEntry::ZeroEntry(entry_size_); - if (hitSetId == std::numeric_limits::max()) { + if (hit_set_id == std::numeric_limits::max()) { if (local_miss_elements_.find(x) == local_miss_elements_.end()) { - SPDLOG_ERROR("No hit set found for %lu", x); + SPDLOG_ERROR("No hit set found for {}", x); } else { - xVal = local_miss_elements_[x]; - local_cache_[x] = xVal; + val = local_miss_elements_[x]; + local_cache_[x] = val; } SendDummySet(); - return xVal; + return val; } - // expand the set - const PRSetWithShortTag set{primary_sets_[hitSetId].tag}; - auto expandedSet = set.ExpandWithLongKey(long_key_, set_size_, chunk_size_); + // Expand the set + const PRFSetWithShortTag set{primary_sets_[hit_set_id].tag}; + auto expanded_set = set.ExpandWithLongKey(long_key_, set_size_, chunk_size_); - // manually program the set if the flag is set before - if (primary_sets_[hitSetId].isProgrammed) { - const uint64_t programmedChunkId = - primary_sets_[hitSetId].programmedPoint / chunk_size_; - expandedSet[programmedChunkId] = primary_sets_[hitSetId].programmedPoint; + // Manually program the set if the flag is set before + if (primary_sets_[hit_set_id].is_programmed) { + const uint64_t programmed_chunk_id = + primary_sets_[hit_set_id].programmed_point / chunk_size_; + expanded_set[programmed_chunk_id] = + primary_sets_[hit_set_id].programmed_point; } - // edit the set by replacing the chunk(x)-th element with a replacement - const uint64_t nxtAvailable = local_replacement_groups_[chunkId].consumed; - if (nxtAvailable == backup_set_num_per_chunk_) { - SPDLOG_ERROR("No replacement available for %lu", x); + // Edit the set by replacing the chunk(x)-th element with a replacement + const uint64_t next_available = local_replacement_groups_[chunk_id].consumed; + if (next_available == backup_set_num_per_chunk_) { + SPDLOG_ERROR("No replacement available for {}", x); SendDummySet(); - return xVal; + return val; } - // consume one replacement - const uint64_t repIndex = - local_replacement_groups_[chunkId].indices[nxtAvailable]; - const DBEntry repVal = local_replacement_groups_[chunkId].value[nxtAvailable]; - local_replacement_groups_[chunkId].consumed++; - expandedSet[chunkId] = repIndex; - - // send the edited set to the server - const auto query_msg = SerializeSetParityQueryMsg(set_size_, expandedSet); - context_->SendAsync(context_->NextRank(), query_msg, "SetParityQueryMsg"); - - const auto response_buf = - context_->Recv(context_->NextRank(), "SetParityQueryResponse"); - - const auto parityQueryResponse = - DeserializeSetParityQueryResponse(response_buf); - const auto& parity = std::get<0>(parityQueryResponse); - - // recover the answer - xVal = primary_sets_[hitSetId].parity; // the parity of the hit set - xVal.XorFromRaw(parity.data()); // xor the parity of the edited set - xVal.Xor(repVal); // xor the replacement value - - // update the local cache - local_cache_[x] = xVal; - - // refresh phase - if (local_backup_set_groups_[chunkId].consumed == backup_set_num_per_chunk_) { - SPDLOG_WARN("No backup set available for %lu", x); - return xVal; + // Consume one replacement + const uint64_t replace_index = + local_replacement_groups_[chunk_id].indices[next_available]; + const DBEntry replace_value = + local_replacement_groups_[chunk_id].values[next_available]; + local_replacement_groups_[chunk_id].consumed++; + expanded_set[chunk_id] = replace_index; + + // Send the edited set to the server + context_->SendAsync(context_->NextRank(), + SerializeSetParityQuery(expanded_set), "SetParityQuery"); + + const auto parity = DeserializeSetParityResponse( + context_->Recv(context_->NextRank(), "SetParityResponse")); + + // Recover the answer + val = primary_sets_[hit_set_id].parity; // The parity of the hit set + val.XorFromRaw(parity.data()); // XOR the parity of the edited set + val.Xor(replace_value); // XOR the replacement value + + // Update the local cache + local_cache_[x] = val; + + // Refresh phase + if (local_backup_set_groups_[chunk_id].consumed == + backup_set_num_per_chunk_) { + SPDLOG_WARN("No backup set available for {}", x); + return val; } - const DBEntry originalXVal = xVal; - const uint64_t consumed = local_backup_set_groups_[chunkId].consumed; - primary_sets_[hitSetId].tag = - local_backup_set_groups_[chunkId].sets[consumed].get().tag; - // backup set doesn't XOR the chunk(x)-th element in preparation - xVal.Xor(local_backup_set_groups_[chunkId] - .sets[consumed] - .get() - .parityAfterPuncture); - primary_sets_[hitSetId].parity = xVal; - primary_sets_[hitSetId].isProgrammed = true; - // for load balancing, the chunk(x)-th element differs from the one expanded - // via PRFEval on the tag - primary_sets_[hitSetId].programmedPoint = x; - local_backup_set_groups_[chunkId].consumed++; - - return originalXVal; + const DBEntry original_value = val; + const uint64_t consumed = local_backup_set_groups_[chunk_id].consumed; + primary_sets_[hit_set_id].tag = + local_backup_set_groups_[chunk_id].sets[consumed].get().tag; + // Backup set doesn't XOR the chunk(x)-th element in preprocessing + val.Xor(local_backup_set_groups_[chunk_id] + .sets[consumed] + .get() + .parity_after_puncture); + primary_sets_[hit_set_id].parity = val; + primary_sets_[hit_set_id].is_programmed = true; + // For load balancing, the chunk(x)-th element needs to be preserved + primary_sets_[hit_set_id].programmed_point = x; + local_backup_set_groups_[chunk_id].consumed++; + + return original_value; } std::vector QueryServiceClient::OnlineMultipleQueries( diff --git a/experimental/pir/piano/client.h b/experimental/pir/piano/client.h index d53fdf11..83688192 100644 --- a/experimental/pir/piano/client.h +++ b/experimental/pir/piano/client.h @@ -13,29 +13,57 @@ namespace pir::piano { struct LocalSet { + /** + * @brief Represents a compressed set in the primary table. + * + * @param tag Unique identifier for generating set elements via PRF. The j-th + * offset is calculated as PRF(msk, tag||j). + * @param parity XOR of all expanded set elements. + * @param programmed_point Indicates an element requiring manual replacement. + * Ensures balanced distribution by preserving query index in specific chunk. + * @param is_programmed Signals whether manual modification is needed. + */ LocalSet(const uint32_t tag, DBEntry parity, const uint64_t programmed_point, const bool is_programmed) : tag(tag), parity(std::move(parity)), - programmedPoint(programmed_point), - isProgrammed(is_programmed) {} + programmed_point(programmed_point), + is_programmed(is_programmed) {} - uint32_t tag; // the tag of the set + uint32_t tag; DBEntry parity; - uint64_t programmedPoint; // identifier for the element replaced after - // refresh differing from those expanded by PRFEval - bool isProgrammed; + // Identifier for the element replaced after refresh, differing from those + // expanded by PRFEval + uint64_t programmed_point; + bool is_programmed; }; struct LocalBackupSet { + /** + * @brief Represents a compressed set in the backup table. + * + * @param tag Functions similarly to the tag in the primary table, serving as + * a unique identifier. + * @param parity_after_puncture The XOR result of all elements in the set, + * excluding an element in a specific chunk. This is designed to reduce the + * computation needed when refreshing sets in the primary table. + */ LocalBackupSet(const uint32_t tag, DBEntry parity_after_puncture) - : tag(tag), parityAfterPuncture(std::move(parity_after_puncture)) {} + : tag(tag), parity_after_puncture(std::move(parity_after_puncture)) {} uint32_t tag; - DBEntry parityAfterPuncture; + DBEntry parity_after_puncture; }; struct LocalBackupSetGroup { + /** + * @brief Organizes backup sets into predefined chunks. + * + * @param consumed Indicates the current position of consumed sets within this + * group. + * @param sets Contains all backup sets related to a specific chunk, where + * their parity_after_puncture values exclude elements from this chunk. + */ LocalBackupSetGroup( const uint64_t consumed, const std::vector>& sets) @@ -46,21 +74,29 @@ struct LocalBackupSetGroup { }; struct LocalReplacementGroup { + /** + * @brief Stores replacement entries for each chunk. + * + * @param consumed Indicates the current position of consumed replacement + * entries within this chunk. + * @param indices Randomly sampled indices generated from the current chunk. + * @param values Values corresponding to the sampled indices. + */ LocalReplacementGroup(const uint64_t consumed, const std::vector& indices, - const std::vector& value) - : consumed(consumed), indices(indices), value(value) {} + const std::vector& values) + : consumed(consumed), indices(indices), values(values) {} uint64_t consumed; std::vector indices; - std::vector value; + std::vector values; }; class QueryServiceClient { public: - QueryServiceClient(uint64_t entry_num, uint64_t thread_num, - uint64_t entry_size, - std::shared_ptr context); + QueryServiceClient(std::shared_ptr context, + uint64_t entry_num, uint64_t thread_num, + uint64_t entry_size); void Initialize(); void InitializeLocalSets(); @@ -69,14 +105,19 @@ class QueryServiceClient { DBEntry OnlineSingleQuery(uint64_t x); std::vector OnlineMultipleQueries( const std::vector& queries); - uint64_t getTotalQueryNumber() const { return total_query_num_; }; + uint64_t GetTotalQueryNumber() const { return total_query_num_; }; private: - static constexpr uint64_t kFailureProbLog2 = 40; - uint64_t total_query_num_{}; - uint64_t entry_num_; - uint64_t thread_num_; + // Statistical security parameter, representing log base 2 + const uint64_t log2_k_ = 40; + // Converts log2_k_ from base-2 to natural logarithm using the change of base + // formula + const double natural_log_k_ = std::log(2) * static_cast(log2_k_); + std::shared_ptr context_; + uint64_t total_query_num_{}; + uint64_t entry_num_{}; + uint64_t thread_num_{}; uint64_t chunk_size_{}; uint64_t set_size_{}; diff --git a/experimental/pir/piano/piano.proto b/experimental/pir/piano/piano.proto index 058de506..8b256bc0 100644 --- a/experimental/pir/piano/piano.proto +++ b/experimental/pir/piano/piano.proto @@ -2,29 +2,18 @@ syntax = "proto3"; package pir.piano; -message FetchFullDbMsg { +message FetchFullDbProto { uint64 dummy = 1; } -message DbChunk { - uint64 chunk_id = 1; - uint64 chunk_size = 2; - bytes chunks = 3; +message DbChunkProto { + bytes chunks = 1; } -message SetParityQueryMsg { - uint64 set_size = 1; - repeated uint64 indices = 2; +message SetParityQueryProto { + repeated uint64 indices = 1; } -message SetParityQueryResponse { +message SetParityResponseProto { bytes parity = 1; - uint64 server_compute_time = 2; -} - -message QueryRequest { - oneof request { - FetchFullDbMsg fetch_full_db = 1; - SetParityQueryMsg set_parity_query = 2; - } } diff --git a/experimental/pir/piano/piano_benchmark.cc b/experimental/pir/piano/piano_benchmark.cc index 9ae168c3..40a10063 100644 --- a/experimental/pir/piano/piano_benchmark.cc +++ b/experimental/pir/piano/piano_benchmark.cc @@ -14,75 +14,29 @@ namespace { -std::vector GenerateQueries(const uint64_t query_num, - const uint64_t entry_num) { +std::vector GenTestQueries(const uint64_t query_num, + const uint64_t entry_num) { std::vector queries; queries.reserve(query_num); - yacl::crypto::Prg prg(yacl::crypto::SecureRandU64()); for (uint64_t q = 0; q < query_num; ++q) { queries.push_back(prg() % entry_num); } - return queries; } -std::vector FNVHash(uint64_t key) { - constexpr uint64_t FNV_offset_basis = 14695981039346656037ULL; - uint64_t hash = FNV_offset_basis; - - for (int i = 0; i < 8; ++i) { - constexpr uint64_t FNV_prime = 1099511628211ULL; - const auto byte = static_cast(key & 0xFF); - hash ^= static_cast(byte); - hash *= FNV_prime; - key >>= 8; - } - - std::vector hash_bytes(8); - for (size_t i = 0; i < 8; ++i) { - hash_bytes[i] = static_cast((hash >> (i * 8)) & 0xFF); - } - - return hash_bytes; -} - std::vector CreateDatabase(const uint64_t entry_size, const uint64_t entry_num, const uint64_t db_seed) { - const auto [ChunkSize, SetSize] = pir::piano::GenParams(entry_num); - std::vector DB; - DB.assign(ChunkSize * SetSize * entry_size, 0); - - for (uint64_t i = 0; i < DB.size() / entry_size; ++i) { - auto entry = - pir::piano::DBEntry::GenDBEntry(entry_size, db_seed, i, FNVHash); - std::memcpy(&DB[i * entry_size], entry.data().data(), + std::vector database; + database.assign(entry_num * entry_size, 0); + for (uint64_t i = 0; i < entry_num; ++i) { + auto entry = pir::piano::DBEntry::GenDBEntry(entry_size, db_seed, i, + pir::piano::FNVHash); + std::memcpy(&database[i * entry_size], entry.GetData().data(), entry_size * sizeof(uint8_t)); } - - return DB; -} - -void SetupAndRunServer( - const std::shared_ptr& server_context, - const uint64_t entry_size, const uint64_t entry_num, - std::promise& exit_signal, std::vector& db) { - const auto [ChunkSize, SetSize] = pir::piano::GenParams(entry_num); - pir::piano::QueryServiceServer server(db, server_context, SetSize, ChunkSize, - entry_size); - server.Start(exit_signal.get_future()); -} - -std::vector SetupAndRunClient( - const uint64_t entry_num, const uint64_t thread_num, - const uint64_t entry_size, - const std::shared_ptr& client_context, - const std::vector& queries) { - pir::piano::QueryServiceClient client(entry_num, thread_num, entry_size, - client_context); - client.FetchFullDB(); - return client.OnlineMultipleQueries(queries); + return database; } } // namespace @@ -93,35 +47,54 @@ static void BM_PianoPir(benchmark::State& state) { const uint64_t entry_size = state.range(0); const uint64_t entry_num = state.range(1) / entry_size / CHAR_BIT; const uint64_t query_num = state.range(2); - constexpr uint64_t db_seed = 2315127; - uint64_t thread_num = 8; + const uint64_t db_seed = yacl::crypto::FastRandU64(); + const uint64_t thread_num = 8; - constexpr int kWorldSize = 2; - const auto contexts = yacl::link::test::SetupWorld(kWorldSize); - yacl::link::RecvTimeoutGuard guard(contexts[0], 1000000); + const int world_size = 2; + const auto contexts = yacl::link::test::SetupWorld(world_size); - auto db = CreateDatabase(entry_size, entry_num, db_seed); - auto queries = GenerateQueries(query_num, entry_num); + auto database = CreateDatabase(entry_size, entry_num, db_seed); + auto queries = GenTestQueries(query_num, entry_num); state.ResumeTiming(); - std::promise exitSignal; - auto server_future = - std::async(std::launch::async, SetupAndRunServer, contexts[0], - entry_size, entry_num, std::ref(exitSignal), std::ref(db)); - - auto client_future = - std::async(std::launch::async, SetupAndRunClient, entry_num, thread_num, - entry_size, contexts[1], std::cref(queries)); - auto results = client_future.get(); - - exitSignal.set_value(); - server_future.get(); + pir::piano::QueryServiceServer server(contexts[0], database, entry_num, + entry_size); + pir::piano::QueryServiceClient client(contexts[1], entry_num, thread_num, + entry_size); + + auto client_preprocess_future = + std::async(std::launch::async, [&client]() { client.FetchFullDB(); }); + + auto server_preprocess_future = std::async( + std::launch::async, [&server]() { server.HandleFetchFullDB(); }); + + client_preprocess_future.get(); + server_preprocess_future.get(); + + std::promise stop_signal; + std::future stop_future = stop_signal.get_future(); + + auto client_query_future = + std::async(std::launch::async, [&client, &queries]() { + return client.OnlineMultipleQueries(queries); + }); + + auto server_query_future = + std::async(std::launch::async, [&server, &stop_future]() { + server.HandleMultipleQueries(stop_future); + }); + + const auto pir_results = client_query_future.get(); + stop_signal.set_value(); + server_query_future.get(); } } BENCHMARK(BM_PianoPir) ->Unit(benchmark::kMillisecond) - ->Args({4, 1 << 20, 1000}) - ->Args({4, 2 << 20, 1000}) - ->Args({8, 1 << 20, 1000}) - ->Args({8, 2 << 20, 1000}); + ->Args({4, 32 << 20, 1000}) + ->Args({4, 64 << 20, 1000}) + ->Args({4, 128 << 20, 1000}) + ->Args({8, 64 << 20, 1000}) + ->Args({8, 128 << 20, 1000}) + ->Args({8, 256 << 20, 1000}); diff --git a/experimental/pir/piano/piano_test.cc b/experimental/pir/piano/piano_test.cc index 50a72273..fcefffe8 100644 --- a/experimental/pir/piano/piano_test.cc +++ b/experimental/pir/piano/piano_test.cc @@ -9,7 +9,6 @@ #include #include "experimental/pir/piano/client.h" -#include "experimental/pir/piano/serialize.h" #include "experimental/pir/piano/server.h" #include "experimental/pir/piano/util.h" #include "gtest/gtest.h" @@ -27,55 +26,29 @@ struct TestParams { namespace pir::piano { -std::vector GenerateQueries(const uint64_t query_num, - const uint64_t entry_num) { +// Generate a set of uniformly distributed random query indices within the +// database entry range for testing purposes +std::vector GenTestQueries(const uint64_t query_num, + const uint64_t entry_num) { std::vector queries; queries.reserve(query_num); - yacl::crypto::Prg prg(yacl::crypto::SecureRandU64()); for (uint64_t q = 0; q < query_num; ++q) { queries.push_back(prg() % entry_num); } - return queries; } -std::vector RunClient(QueryServiceClient& client, - const std::vector& queries) { - client.FetchFullDB(); - return client.OnlineMultipleQueries(queries); -} - -std::vector FNVHash(uint64_t key) { - constexpr uint64_t FNV_offset_basis = 14695981039346656037ULL; - uint64_t hash = FNV_offset_basis; - - for (int i = 0; i < 8; ++i) { - constexpr uint64_t FNV_prime = 1099511628211ULL; - const auto byte = static_cast(key & 0xFF); - hash ^= static_cast(byte); - hash *= FNV_prime; - key >>= 8; - } - - std::vector hash_bytes(8); - for (size_t i = 0; i < 8; ++i) { - hash_bytes[i] = static_cast((hash >> (i * 8)) & 0xFF); - } - - return hash_bytes; -} - -std::vector getResults(const std::vector& queries, - const TestParams& params) { +// Simulate direct database lookup using plain-text indices to verify the +// correctness of PIR scheme +std::vector GetPlainResults(const std::vector& queries, + const TestParams& params) { std::vector expected_results; expected_results.reserve(queries.size()); - for (const auto& x : queries) { expected_results.push_back( DBEntry::GenDBEntry(params.entry_size, params.db_seed, x, FNVHash)); } - return expected_results; } @@ -83,57 +56,75 @@ class PianoTest : public testing::TestWithParam {}; TEST_P(PianoTest, Works) { auto params = GetParam(); - constexpr int kWorldSize = 2; + const int world_size = 2; uint64_t entry_num = params.db_size / params.entry_size / CHAR_BIT; - const auto contexts = yacl::link::test::SetupWorld(kWorldSize); + const auto contexts = yacl::link::test::SetupWorld(world_size); - SPDLOG_INFO("DB N: %lu, Entry Size %lu Bytes, DB Size %lu MB\n", entry_num, - params.entry_size, entry_num * params.entry_size / 1024 / 1024); + SPDLOG_INFO( + "Database summary: total entries: {}, each entry size: {} bytes, total " + "database size: {:.2f} MB", + entry_num, params.entry_size, + static_cast(entry_num * params.entry_size) / (1024 * 1024)); - auto [ChunkSize, SetSize] = GenParams(entry_num); - SPDLOG_INFO("Chunk Size: %lu, Set Size: %lu\n", ChunkSize, SetSize); + auto [chunk_size, set_size] = GenChunkParams(entry_num); + SPDLOG_INFO("Generated parameters: chunk_size: {}, set_size: {}", chunk_size, + set_size); - std::vector DB; - DB.assign(ChunkSize * SetSize * params.entry_size, 0); - SPDLOG_INFO("DB Real N: %lu\n", DB.size()); + SPDLOG_INFO("Generating database with seed: {}", params.db_seed); + std::vector database; + database.assign(entry_num * params.entry_size, 0); - for (uint64_t i = 0; i < DB.size() / params.entry_size; ++i) { + for (uint64_t i = 0; i < entry_num; ++i) { auto entry = DBEntry::GenDBEntry(params.entry_size, params.db_seed, i, FNVHash); - std::memcpy(&DB[i * params.entry_size], entry.data().data(), + std::memcpy(&database[i * params.entry_size], entry.GetData().data(), params.entry_size * sizeof(uint8_t)); } - QueryServiceClient client(entry_num, params.thread_num, params.entry_size, - contexts[1]); + SPDLOG_INFO("Initializing query service: server and client"); + QueryServiceServer server(contexts[0], database, entry_num, + params.entry_size); + QueryServiceClient client(contexts[1], entry_num, params.thread_num, + params.entry_size); const auto actual_query_num = params.is_total_query_num - ? client.getTotalQueryNumber() + ? client.GetTotalQueryNumber() : params.query_num; - const auto queries = GenerateQueries(actual_query_num, entry_num); + SPDLOG_INFO("Generating {} test queries", actual_query_num); + const auto queries = GenTestQueries(actual_query_num, entry_num); - yacl::link::RecvTimeoutGuard guard(contexts[0], 1000000); - QueryServiceServer server(DB, contexts[0], SetSize, ChunkSize, - params.entry_size); + SPDLOG_INFO("Starting preprocess phase"); + auto client_preprocess_future = + std::async(std::launch::async, [&client]() { client.FetchFullDB(); }); - std::promise exitSignal; - std::future futureObj = exitSignal.get_future(); - auto server_future = - std::async(std::launch::async, &QueryServiceServer::Start, - std::ref(server), std::move(futureObj)); - auto client_future = std::async(std::launch::async, RunClient, - std::ref(client), std::cref(queries)); + auto server_preprocess_future = std::async( + std::launch::async, [&server]() { server.HandleFetchFullDB(); }); - const auto results = client_future.get(); - const auto expected_results = getResults(queries, params); + client_preprocess_future.get(); + server_preprocess_future.get(); - for (size_t i = 0; i < results.size(); ++i) { - EXPECT_EQ(results[i].data(), expected_results[i].data()) + SPDLOG_INFO("Starting online query phase"); + std::promise stop_signal; + std::future stop_future = stop_signal.get_future(); + + auto client_query_future = std::async( + std::launch::async, + [&client, &queries]() { return client.OnlineMultipleQueries(queries); }); + + auto server_query_future = std::async( + std::launch::async, + [&server, &stop_future]() { server.HandleMultipleQueries(stop_future); }); + + const auto pir_results = client_query_future.get(); + const auto expected_results = GetPlainResults(queries, params); + stop_signal.set_value(); + server_query_future.get(); + + SPDLOG_INFO("Verifying {} query results", pir_results.size()); + for (size_t i = 0; i < pir_results.size(); ++i) { + EXPECT_EQ(pir_results[i].GetData(), expected_results[i].GetData()) << "Mismatch at index " << queries[i]; } - - exitSignal.set_value(); - server_future.get(); } // [8m, 128m, 256m] units are in bits diff --git a/experimental/pir/piano/serialize.h b/experimental/pir/piano/serialize.h index 0e59c52d..ea0ce2c9 100644 --- a/experimental/pir/piano/serialize.h +++ b/experimental/pir/piano/serialize.h @@ -1,7 +1,5 @@ #pragma once -#include -#include #include #include "experimental/pir/piano/util.h" @@ -11,84 +9,69 @@ namespace pir::piano { -inline yacl::Buffer SerializeFetchFullDBMsg(const uint64_t dummy) { - QueryRequest proto; - FetchFullDbMsg* fetch_full_db_msg = proto.mutable_fetch_full_db(); - fetch_full_db_msg->set_dummy(dummy); - +inline yacl::Buffer SerializeFetchFullDB(const uint64_t dummy) { + FetchFullDbProto proto; + proto.set_dummy(dummy); yacl::Buffer buf(proto.ByteSizeLong()); proto.SerializeToArray(buf.data(), buf.size()); - return buf; } -inline uint64_t DeserializeFetchFullDBMsg(const yacl::Buffer& buf) { - QueryRequest proto; +inline uint64_t DeserializeFetchFullDB(const yacl::Buffer& buf) { + FetchFullDbProto proto; proto.ParseFromArray(buf.data(), buf.size()); - return proto.fetch_full_db().dummy(); + return proto.dummy(); } -inline yacl::Buffer SerializeDBChunk(const uint64_t chunk_id, - const uint64_t chunk_size, - const std::vector& chunk) { - DbChunk proto; - proto.set_chunk_id(chunk_id); - proto.set_chunk_size(chunk_size); +inline yacl::Buffer SerializeDBChunk(const std::vector& chunk) { + DbChunkProto proto; proto.set_chunks(chunk.data(), chunk.size()); yacl::Buffer buf(proto.ByteSizeLong()); proto.SerializeToArray(buf.data(), buf.size()); return buf; } -inline std::tuple> DeserializeDBChunk( - const yacl::Buffer& buf) { - DbChunk proto; +inline std::vector DeserializeDBChunk(const yacl::Buffer& buf) { + DbChunkProto proto; proto.ParseFromArray(buf.data(), buf.size()); std::vector chunk(proto.chunks().begin(), proto.chunks().end()); - return {proto.chunk_id(), proto.chunk_size(), chunk}; + return chunk; } -inline yacl::Buffer SerializeSetParityQueryMsg( - const uint64_t set_size, const std::vector& indices) { - QueryRequest proto; - SetParityQueryMsg* set_parity_query = proto.mutable_set_parity_query(); - set_parity_query->set_set_size(set_size); +inline yacl::Buffer SerializeSetParityQuery( + const std::vector& indices) { + SetParityQueryProto proto; for (const auto& index : indices) { - set_parity_query->add_indices(index); + proto.add_indices(index); } - yacl::Buffer buf(proto.ByteSizeLong()); proto.SerializeToArray(buf.data(), buf.size()); - return buf; } -inline std::pair> DeserializeSetParityQueryMsg( +inline std::vector DeserializeSetParityQuery( const yacl::Buffer& buf) { - QueryRequest proto; + SetParityQueryProto proto; proto.ParseFromArray(buf.data(), buf.size()); - const auto& set_parity_query = proto.set_parity_query(); - std::vector indices(set_parity_query.indices().begin(), - set_parity_query.indices().end()); - return {set_parity_query.set_size(), indices}; + std::vector indices(proto.indices().begin(), proto.indices().end()); + return indices; } -inline yacl::Buffer SerializeSetParityQueryResponse( - const std::vector& parity, const uint64_t server_compute_time) { - SetParityQueryResponse proto; +inline yacl::Buffer SerializeSetParityResponse( + const std::vector& parity) { + SetParityResponseProto proto; proto.set_parity(parity.data(), parity.size()); - proto.set_server_compute_time(server_compute_time); yacl::Buffer buf(proto.ByteSizeLong()); proto.SerializeToArray(buf.data(), buf.size()); return buf; } -inline std::pair, uint64_t> -DeserializeSetParityQueryResponse(const yacl::Buffer& buf) { - SetParityQueryResponse proto; +inline std::vector DeserializeSetParityResponse( + const yacl::Buffer& buf) { + SetParityResponseProto proto; proto.ParseFromArray(buf.data(), buf.size()); std::vector parity(proto.parity().begin(), proto.parity().end()); - return {parity, proto.server_compute_time()}; + return parity; } } // namespace pir::piano diff --git a/experimental/pir/piano/server.cc b/experimental/pir/piano/server.cc index 3bfc7e1a..6a10d300 100644 --- a/experimental/pir/piano/server.cc +++ b/experimental/pir/piano/server.cc @@ -3,96 +3,84 @@ namespace pir::piano { QueryServiceServer::QueryServiceServer( - std::vector& db, std::shared_ptr context, - const uint64_t set_size, const uint64_t chunk_size, - const uint64_t entry_size) - : db_(std::move(db)), - context_(std::move(context)), - set_size_(set_size), - chunk_size_(chunk_size), - entry_size_(entry_size) {} - -void QueryServiceServer::Start(const std::future& stop_signal) { - while (stop_signal.wait_for(std::chrono::milliseconds(1)) == - std::future_status::timeout) { - auto request_data = context_->Recv(context_->NextRank(), "request_data"); - HandleRequest(request_data); - } + std::shared_ptr context, std::vector& db, + const uint64_t entry_num, const uint64_t entry_size) + : context_(std::move(context)), + db_(std::move(db)), + entry_num_(entry_num), + entry_size_(entry_size) { + std::tie(chunk_size_, set_size_) = GenChunkParams(entry_num_); + AlignDBToChunkBoundary(); } -void QueryServiceServer::HandleRequest(const yacl::Buffer& request_data) { - QueryRequest proto; - proto.ParseFromArray(request_data.data(), request_data.size()); +void QueryServiceServer::AlignDBToChunkBoundary() { + if (entry_num_ < chunk_size_ * set_size_) { + const uint64_t padding_num = (chunk_size_ * set_size_) - entry_num_; + const uint64_t seed = yacl::crypto::FastRandU64(); - switch (proto.request_case()) { - case QueryRequest::kFetchFullDb: { - // uint64_t dummy = DeserializeFetchFullDBMsg(request_data); - ProcessFetchFullDB(); - break; + db_.reserve(db_.size() + (padding_num * entry_size_)); + for (uint64_t i = 0; i < padding_num; ++i) { + auto padding_entry = DBEntry::GenDBEntry(entry_size_, seed, i, FNVHash); + db_.insert(db_.end(), padding_entry.GetData().begin(), + padding_entry.GetData().end()); } - case QueryRequest::kSetParityQuery: { - const auto parityQuery = DeserializeSetParityQueryMsg(request_data); - const auto& indices = std::get<1>(parityQuery); - - auto [parity, server_compute_time] = ProcessSetParityQuery(indices); - const auto response_buf = - SerializeSetParityQueryResponse(parity, server_compute_time); - context_->SendAsync(context_->NextRank(), response_buf, - "SetParityQueryResponse"); - break; - } - default: - SPDLOG_ERROR("Unknown request type."); + entry_num_ += padding_num; } } -void QueryServiceServer::ProcessFetchFullDB() { +void QueryServiceServer::HandleFetchFullDB() { + DeserializeFetchFullDB(context_->Recv(context_->NextRank(), "FetchFullDB")); for (uint64_t i = 0; i < set_size_; ++i) { const uint64_t down = i * chunk_size_; - uint64_t up = (i + 1) * chunk_size_; - up = std::min(up, db_.size() / entry_size_); + const uint64_t up = (i + 1) * chunk_size_; std::vector chunk(db_.begin() + down * entry_size_, db_.begin() + up * entry_size_); - auto chunk_buf = SerializeDBChunk(i, chunk.size(), chunk); try { - context_->SendAsync(context_->NextRank(), chunk_buf, "FetchFullDBChunk"); + context_->SendAsync(context_->NextRank(), SerializeDBChunk(chunk), + "DBChunk"); } catch (const std::exception& e) { - SPDLOG_ERROR("Failed to send a chunk."); + SPDLOG_ERROR("Failed to send a chunk: {}", e.what()); return; } } } -std::pair, uint64_t> -QueryServiceServer::ProcessSetParityQuery( - const std::vector& indices) { - const auto start = std::chrono::high_resolution_clock::now(); - std::vector parity = HandleSetParityQuery(indices); - const auto end = std::chrono::high_resolution_clock::now(); - const auto duration = - std::chrono::duration_cast(end - start).count(); - return {parity, duration}; +void QueryServiceServer::HandleMultipleQueries( + const std::future& stop_signal) { + while (stop_signal.wait_for(std::chrono::milliseconds(5)) == + std::future_status::timeout) { + HandleQueryRequest(); + } +} + +void QueryServiceServer::HandleQueryRequest() { + const auto indices = DeserializeSetParityQuery( + context_->Recv(context_->NextRank(), "SetParityQuery")); + + const std::vector parity = ProcessSetParityQuery(indices); + context_->SendAsync(context_->NextRank(), SerializeSetParityResponse(parity), + "SetParityResponse"); } -std::vector QueryServiceServer::HandleSetParityQuery( +std::vector QueryServiceServer::ProcessSetParityQuery( const std::vector& indices) { DBEntry parity = DBEntry::ZeroEntry(entry_size_); for (const auto& index : indices) { DBEntry entry = DBAccess(index); parity.Xor(entry); } - return parity.data(); + return parity.GetData(); } -DBEntry QueryServiceServer::DBAccess(const uint64_t id) { - if (const size_t num_entries = db_.size() / entry_size_; id < num_entries) { +DBEntry QueryServiceServer::DBAccess(const uint64_t idx) { + if (idx < entry_num_) { std::vector slice(entry_size_); - std::copy(db_.begin() + id * entry_size_, - db_.begin() + (id + 1) * entry_size_, slice.begin()); + std::copy(db_.begin() + idx * entry_size_, + db_.begin() + (idx + 1) * entry_size_, slice.begin()); return DBEntry::DBEntryFromSlice(slice); } - SPDLOG_ERROR("DBAccess: id {} out of range", id); + SPDLOG_ERROR("DBAccess: idx {} out of range", idx); return DBEntry::ZeroEntry(entry_size_); } diff --git a/experimental/pir/piano/server.h b/experimental/pir/piano/server.h index 40096215..dc3e1fa6 100644 --- a/experimental/pir/piano/server.h +++ b/experimental/pir/piano/server.h @@ -5,7 +5,6 @@ #include #include #include -#include #include #include "experimental/pir/piano/serialize.h" @@ -16,39 +15,44 @@ namespace pir::piano { class QueryServiceServer { public: - // Constructor: initializes the server with a database, context, set_size, and - // chunk_size - QueryServiceServer(std::vector& db, - std::shared_ptr context, - uint64_t set_size, uint64_t chunk_size, + QueryServiceServer(std::shared_ptr context, + std::vector& db, uint64_t entry_num, uint64_t entry_size); - // Starts the server to handle incoming requests - void Start(const std::future& stop_signal); - - // Handles the incoming request based on its type - void HandleRequest(const yacl::Buffer& request_data); - - // Processes a request to fetch the full database - void ProcessFetchFullDB(); - - // Processes a set parity query and returns the parity and server compute time - std::pair, uint64_t> ProcessSetParityQuery( - const std::vector& indices); + /** + * @brief Align the database to ensure uniformity and independence of query + * distribution. + * + * Pads the database with additional entries to complete the last chunk, + * guaranteeing consistent query distribution across all chunks. + */ + void AlignDBToChunkBoundary(); + + /** + * @brief Handles server-side full database retrieval request. + * + * Transfers database in chunks using a pipelined approach, sending one chunk + * at a time to minimize client-side storage requirements. + */ + void HandleFetchFullDB(); + void HandleQueryRequest(); + void HandleMultipleQueries(const std::future& stop_signal); private: - // Accesses the database and returns the corresponding entry - DBEntry DBAccess(uint64_t id); + // Access the database and return the entry corresponding to the index + DBEntry DBAccess(uint64_t idx); - // Handles a set parity query and returns the parity - std::vector HandleSetParityQuery( + // Process a set parity query by computing the XOR of all elements in the + // query set + std::vector ProcessSetParityQuery( const std::vector& indices); - std::vector db_; // The database std::shared_ptr context_; // The communication context - uint64_t set_size_; // The size of the set - uint64_t chunk_size_; // The size of each chunk - uint64_t entry_size_; // The size of database entry + std::vector db_; // The database + uint64_t set_size_{}; // The size of the set + uint64_t chunk_size_{}; // The size of each chunk + uint64_t entry_num_{}; // The number of database entry + uint64_t entry_size_{}; // The size of database entry }; } // namespace pir::piano diff --git a/experimental/pir/piano/util.cc b/experimental/pir/piano/util.cc index 42e955f3..c9dc302f 100644 --- a/experimental/pir/piano/util.cc +++ b/experimental/pir/piano/util.cc @@ -2,39 +2,22 @@ namespace pir::piano { -uint128_t SecureRandKey() { return yacl::crypto::SecureRandU128(); } - -uint64_t PRFEval(const uint128_t key, const uint64_t x) { - yacl::crypto::AES_KEY aes_key; - AES_set_encrypt_key(key, &aes_key); - - const auto src_block = static_cast(x); - std::vector plain_blocks(1); - plain_blocks[0] = src_block; - std::vector cipher_blocks(1); - - AES_ecb_encrypt_blks(aes_key, absl::MakeConstSpan(plain_blocks), - absl::MakeSpan(cipher_blocks)); - return static_cast(cipher_blocks[0]); -} - -// Generate ChunkSize and SetSize -std::pair GenParams(const uint64_t entry_num) { - const double targetChunkSize = 2 * std::sqrt(static_cast(entry_num)); - uint64_t ChunkSize = 1; +std::pair GenChunkParams(const uint64_t entry_num) { + const double target_chunk_size = + 2 * std::sqrt(static_cast(entry_num)); + uint64_t chunk_size = 1; - // Ensure ChunkSize is a power of 2 and not smaller than targetChunkSize - while (ChunkSize < static_cast(targetChunkSize)) { - ChunkSize *= 2; + // Ensure chunk_size is a power of 2 and not smaller than target_chunk_size + while (chunk_size < static_cast(target_chunk_size)) { + chunk_size *= 2; } - uint64_t SetSize = (entry_num + ChunkSize - 1) / ChunkSize; - // Round up to the next multiple of 4 - SetSize = (SetSize + 3) / 4 * 4; - - return {ChunkSize, SetSize}; + uint64_t set_size = (entry_num + chunk_size - 1) / chunk_size; + return {chunk_size, set_size}; } +uint128_t SecureRandKey() { return yacl::crypto::SecureRandU128(); } + yacl::crypto::AES_KEY GetLongKey(const uint128_t key) { yacl::crypto::AES_KEY aes_key; AES_set_encrypt_key(key, &aes_key); @@ -43,7 +26,6 @@ yacl::crypto::AES_KEY GetLongKey(const uint128_t key) { uint64_t PRFEvalWithLongKeyAndTag(const yacl::crypto::AES_KEY& long_key, const uint32_t tag, const uint64_t x) { - // Combine tag and x into a 128-bit block by shifting tag to the high 64 bits const uint128_t src_block = (static_cast(tag) << 64) + x; std::vector plain_blocks(1); plain_blocks[0] = src_block; @@ -53,25 +35,44 @@ uint64_t PRFEvalWithLongKeyAndTag(const yacl::crypto::AES_KEY& long_key, return static_cast(cipher_blocks[0]); } -std::vector PRSetWithShortTag::ExpandWithLongKey( +std::vector PRFSetWithShortTag::ExpandWithLongKey( const yacl::crypto::AES_KEY& long_key, const uint64_t set_size, const uint64_t chunk_size) const { - std::vector expandedSet(set_size); + std::vector expanded_set(set_size); for (uint64_t i = 0; i < set_size; i++) { - const uint64_t tmp = PRFEvalWithLongKeyAndTag(long_key, Tag, i); + const uint64_t tmp = PRFEvalWithLongKeyAndTag(long_key, tag, i); // Get the offset within the chunk const uint64_t offset = tmp & (chunk_size - 1); - expandedSet[i] = i * chunk_size + offset; + expanded_set[i] = i * chunk_size + offset; } - return expandedSet; + return expanded_set; } -bool PRSetWithShortTag::MemberTestWithLongKeyAndTag( +bool PRFSetWithShortTag::MemberTestWithLongKey( const yacl::crypto::AES_KEY& long_key, const uint64_t chunk_id, const uint64_t offset, const uint64_t chunk_size) const { // Ensure chunk_size is a power of 2 and compare offsets return offset == - (PRFEvalWithLongKeyAndTag(long_key, Tag, chunk_id) & (chunk_size - 1)); + (PRFEvalWithLongKeyAndTag(long_key, tag, chunk_id) & (chunk_size - 1)); +} + +std::vector FNVHash(uint64_t key) { + const uint64_t fnv_offset_basis = 14695981039346656037ULL; + uint64_t hash = fnv_offset_basis; + + for (int i = 0; i < 8; ++i) { + const uint64_t fnv_prime = 1099511628211ULL; + const auto byte = static_cast(key & 0xFF); + hash ^= static_cast(byte); + hash *= fnv_prime; + key >>= 8; + } + + std::vector hash_bytes(8); + for (size_t i = 0; i < 8; ++i) { + hash_bytes[i] = static_cast((hash >> (i * 8)) & 0xFF); + } + return hash_bytes; } } // namespace pir::piano diff --git a/experimental/pir/piano/util.h b/experimental/pir/piano/util.h index ba13813d..1963ae39 100644 --- a/experimental/pir/piano/util.h +++ b/experimental/pir/piano/util.h @@ -1,7 +1,5 @@ #pragma once -#include - #include #include #include @@ -17,7 +15,7 @@ class DBEntry { public: DBEntry() = default; - // entry_size represents the number of bytes in the DBEntry + // Total byte size of the database entry, initializing all bytes to zero explicit DBEntry(const size_t entry_size) : k_length_(entry_size), data_(entry_size, 0) {} @@ -25,8 +23,8 @@ class DBEntry { : k_length_(data.size()), data_(data) {} // Accessor for the underlying data - std::vector& data() { return data_; } - [[nodiscard]] const std::vector& data() const { return data_; } + std::vector& GetData() { return data_; } + [[nodiscard]] const std::vector& GetData() const { return data_; } // XOR operations void Xor(const DBEntry& other) { @@ -46,7 +44,7 @@ class DBEntry { return DBEntry(entry_size); } - // Generate a DBEntry based on a key and ID using a custom hash function + // Generate a DBEntry based on a seed and id using a custom hash function static DBEntry GenDBEntry( const size_t entry_size, const uint64_t key, const uint64_t id, const std::function(uint64_t)>& hash_func) { @@ -72,35 +70,61 @@ class DBEntry { std::vector data_; }; +/** + * @brief Generate optimal chunk and set size for PIR parameters. + * + * Calculate chunk size and set size based on total number of entries. + * The chunk size is set to the smallest power of 2 that is >= 2*sqrt(n), + * which optimizes modulo operations and overall scheme performance. + * + * @param entry_num Total number of entries in the database. + * @return A pair of {chunk_size, set_size}. + */ +std::pair GenChunkParams(uint64_t entry_num); + // Generate secure master key uint128_t SecureRandKey(); -// Evaluates PRF using a 128-bit key and returns a 64-bit result -uint64_t PRFEval(uint128_t key, uint64_t x); - -// Generate parameters for ChunkSize and SetSize -std::pair GenParams(uint64_t entry_num); - -// Returns a long key (AES expanded key) for PRF evaluation +// Return a long key (AES expanded key) for PRF evaluation yacl::crypto::AES_KEY GetLongKey(uint128_t key); -// PRF evaluation with a long key and tag, returns a 64-bit result +/** + * @brief Evaluate a Pseudo-Random Function (PRF) using AES-ECB encryption. + * + * Combine a 32-bit tag and 64-bit input into a 128-bit block, encrypt using a + * long key, and return the lower 64 bits of the encrypted block. + * + * @param long_key AES encryption key. + * @param tag 32-bit tag for domain separation. + * @param x 64-bit input value. + * @return 64-bit pseudo-random output. + */ uint64_t PRFEvalWithLongKeyAndTag(const yacl::crypto::AES_KEY& long_key, uint32_t tag, uint64_t x); -struct PRSetWithShortTag { - uint32_t Tag; +struct PRFSetWithShortTag { + uint32_t tag; - // Expands the set with a long key and tag + // Expand a short-tag set to a full set using the PRFEval [[nodiscard]] std::vector ExpandWithLongKey( const yacl::crypto::AES_KEY& long_key, uint64_t set_size, uint64_t chunk_size) const; - // Membership test with a long key and tag, to check if an ID belongs to the - // set - [[nodiscard]] bool MemberTestWithLongKeyAndTag( + // Check if an element belongs to the expanded set + [[nodiscard]] bool MemberTestWithLongKey( const yacl::crypto::AES_KEY& long_key, uint64_t chunk_id, uint64_t offset, uint64_t chunk_size) const; }; +/** + * @brief Convert a 64-bit key to an 8-byte hash using FNV-1a algorithm. + * + * Used for generating test data or random padding, independent of specific PIR + * scheme. + * + * @param key Input 64-bit value to be hashed. + * @return 8-byte hash representation. + */ +std::vector FNVHash(uint64_t key); + } // namespace pir::piano From 2b114c785ecc8d6b2137bbc279cc65c97ecfacba Mon Sep 17 00:00:00 2001 From: cxiao129 Date: Wed, 18 Dec 2024 19:54:17 +0800 Subject: [PATCH 10/11] Updated the README for correct rendering of math formulas on github --- experimental/pir/piano/README.md | 31 +++++++++++++++++-------------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/experimental/pir/piano/README.md b/experimental/pir/piano/README.md index ea315a8c..194942bd 100644 --- a/experimental/pir/piano/README.md +++ b/experimental/pir/piano/README.md @@ -13,44 +13,47 @@ ### Client-Side Initialization 1. **Primary Table Generation** - - Sample $M_1 = \sqrt{n} \log \kappa \cdot \alpha(\kappa)$ PRF keys, $\{sk_1, \ldots, sk_{M_1}\} \in \{0,1\}^{\lambda}$ - - Initialize parities $\{p_1, \ldots, p_{M_1}\}$ to zeros + - Sample $M_1 = \sqrt{n} \log \kappa \cdot \alpha(\kappa)$ PRF keys, $\\{sk_1, \ldots, sk_{M_1}\\} \in \\{0,1\\}^{\lambda}$ + - Initialize parities $\\{p_1, \ldots, p_{M_1}\\}$ to zeros 2. **Backup Table Generation** - - For each chunk $j \in \{0, 1, \ldots, \sqrt{n} - 1\}$: - - Sample $M_2 = \log \kappa \cdot \alpha(\kappa)$ PRF keys $\{sk_{j,1}, \ldots, sk_{j,M_2}\}$ - - Initialize chunk-specific parities $\{p_{j,1}, \ldots, p_{j,M_2}\}$ to zeros + + - For each chunk $j \in \\{0, 1, \ldots, \sqrt{n} - 1\\}$: + - Sample $M_2 = \log \kappa \cdot \alpha(\kappa)$ PRF keys $\\{sk_{j,1}, \ldots, sk_{j,M_2}\\}$ + - Initialize chunk-specific parities $\\{p_{j,1}, \ldots, p_{j,M_2}\\}$ to zeros ### Streaming Database Preprocessing For each database chunk $DB[j \sqrt{n} : (j+1) \sqrt{n}]$: + - Update primary table parity: for $i \in [M_1]$, $p_i \leftarrow p_i \oplus DB[\text{Set}(sk_i)[j]]$ - Store replacement entries: sample $M_2$ tuples $(r, DB[r])$ where $r$ is a random index from the current chunk -- Update backup table parity: for $i \in \{0, 1, \ldots, \sqrt{n} - 1\} \setminus \{j\}$ and $k \in [M_2]$, $p_{i,k} \leftarrow p_{i,k} \oplus DB[\text{Set}(sk_{i,k})[j]]$ +- Update backup table parity: for $i \in \\{0, 1, \ldots, \sqrt{n} - 1\\} \setminus \\{j\\}$ and $k \in [M_2]$, $p_{i,k} \leftarrow p_{i,k} \oplus DB[\text{Set}(sk_{i,k})[j]]$ - Delete current chunk from local storage ## Online Query Phase -Query Protocol for Index $x \in \{0, 1, \ldots, n-1\}$ +Query Protocol for Index $x \in \\{0, 1, \ldots, n-1\\}$ 1. **Query Execution** + - Find primary table hint $T_i = ((sk_i, x^{\prime}), p_i)$ where $x \in \text{Set}(sk_i, x^{\prime})$ - Locate chunk $j^* = \text{chunk}(x)$ - Find first unused replacement entry $(r, DB[r])$ - - Send set $S' = S \setminus \{j^* \to r\}$ to server + - Send set $S' = S \setminus \\{j^* \to r\\}$ to server - Server returns $q = \bigoplus_{k \in S'} DB[k]$ - Compute answer $\beta = q \oplus p_i \oplus DB[r]$ - 2. **Table Refresh Mechanism** - - Locate next unused backup entry $(sk_{j^*,k}, p_{j^*,k})$ - - If no entry exists, generate random $sk_{j^*,k}$ with $p_{j^*,k} = 0$ - - Update primary table with new entry: $((sk_{j^*,k}, x), p_{j^*,k} \oplus \beta)$ + + - Locate next unused backup entry $(sk_{j^\*,k}, p_{j^\*,k})$ + - If no entry exists, generate random $sk_{j^\*,k}$ with $p_{j^\*,k} = 0$ + - Update primary table with new entry: $((sk_{j^\*,k}, x), p_{j^\*,k} \oplus \beta)$ ## Theoretical Guarantees - **Client Storage**: $O(\sqrt{n})$ -- **Server Computation**: $O(\sqrt{n})$ -- **Communication Overhead**: $O(\sqrt{n})$ +- **Amortized Server Computation**: $O(\sqrt{n})$ +- **Amortized Communication Overhead**: $O(\sqrt{n})$ - **Query Complexity**: $O(\sqrt{n})$ ## References From d04c226e142dd7986ef572f3452cd2e5b1e95a7a Mon Sep 17 00:00:00 2001 From: cxiao129 Date: Fri, 20 Dec 2024 14:44:21 +0800 Subject: [PATCH 11/11] Move to specified directory, optimize global variables, and modify backup hint generation logic --- .../pir/piano/BUILD.bazel | 0 .../pir/piano/README.md | 0 .../pir/piano/client.cc | 41 ++++++++----------- .../pir/piano/client.h | 17 ++++---- .../pir/piano/piano.proto | 0 .../pir/piano/piano_benchmark.cc | 6 +-- .../pir/piano/piano_test.cc | 6 +-- .../pir/piano/serialize.h | 4 +- .../pir/piano/server.cc | 2 +- .../pir/piano/server.h | 4 +- .../pir/piano/util.cc | 2 +- {experimental => experiment}/pir/piano/util.h | 0 12 files changed, 38 insertions(+), 44 deletions(-) rename {experimental => experiment}/pir/piano/BUILD.bazel (100%) rename {experimental => experiment}/pir/piano/README.md (100%) rename {experimental => experiment}/pir/piano/client.cc (90%) rename {experimental => experiment}/pir/piano/client.h (92%) rename {experimental => experiment}/pir/piano/piano.proto (100%) rename {experimental => experiment}/pir/piano/piano_benchmark.cc (96%) rename {experimental => experiment}/pir/piano/piano_test.cc (97%) rename {experimental => experiment}/pir/piano/serialize.h (96%) rename {experimental => experiment}/pir/piano/server.cc (98%) rename {experimental => experiment}/pir/piano/server.h (95%) rename {experimental => experiment}/pir/piano/util.cc (98%) rename {experimental => experiment}/pir/piano/util.h (100%) diff --git a/experimental/pir/piano/BUILD.bazel b/experiment/pir/piano/BUILD.bazel similarity index 100% rename from experimental/pir/piano/BUILD.bazel rename to experiment/pir/piano/BUILD.bazel diff --git a/experimental/pir/piano/README.md b/experiment/pir/piano/README.md similarity index 100% rename from experimental/pir/piano/README.md rename to experiment/pir/piano/README.md diff --git a/experimental/pir/piano/client.cc b/experiment/pir/piano/client.cc similarity index 90% rename from experimental/pir/piano/client.cc rename to experiment/pir/piano/client.cc index 65f20c74..bcc175c6 100644 --- a/experimental/pir/piano/client.cc +++ b/experiment/pir/piano/client.cc @@ -1,4 +1,4 @@ -#include "experimental/pir/piano/client.h" +#include "experiment/pir/piano/client.h" namespace pir::piano { @@ -22,9 +22,9 @@ void QueryServiceClient::Initialize() { // Maximum number of queries supported by a single preprocessing // Let α(κ) be any super-constant function, i.e., α(κ) = w(1) // Chosen log(log(κ)): grows slowly but surely > any constant as κ → ∞ - total_query_num_ = - static_cast(std::sqrt(static_cast(entry_num_)) * - natural_log_k_ * std::log(natural_log_k_)); + total_query_num_ = static_cast( + std::sqrt(static_cast(entry_num_)) * kStatisticalSecurityLn * + std::log(kStatisticalSecurityLn)); std::tie(chunk_size_, set_size_) = GenChunkParams(entry_num_); @@ -42,7 +42,8 @@ void QueryServiceClient::Initialize() { // The probability that the client runs out of hints in a backup group is // negligible in κ backup_set_num_per_chunk_ = static_cast( - static_cast(log2_k_) * natural_log_k_ * std::log(natural_log_k_)); + static_cast(kStatisticalSecurityLog2) * kStatisticalSecurityLn * + std::log(kStatisticalSecurityLn)); backup_set_num_per_chunk_ = (backup_set_num_per_chunk_ + thread_num_ - 1) / thread_num_ * thread_num_; @@ -112,9 +113,9 @@ void QueryServiceClient::FetchFullDB() { // Make sure all sets are covered const uint64_t primary_set_per_thread = - ((primary_set_num_ + thread_num_ - 1) / thread_num_) + 1; + (primary_set_num_ + thread_num_ - 1) / thread_num_; const uint64_t backup_set_per_thread = - ((total_backup_set_num_ + thread_num_ - 1) / thread_num_) + 1; + (total_backup_set_num_ + thread_num_ - 1) / thread_num_; for (uint64_t tid = 0; tid < thread_num_; tid++) { uint64_t start_index = tid * primary_set_per_thread; @@ -141,11 +142,15 @@ void QueryServiceClient::FetchFullDB() { // Update the parities for the backup hints for (uint64_t j = start_index_backup; j < end_index_backup; j++) { - const auto tmp = - PRFEvalWithLongKeyAndTag(long_key_, local_backup_sets_[j].tag, i); - const auto offset = tmp & (chunk_size_ - 1); - local_backup_sets_[j].parity_after_puncture.XorFromRaw( - &db_chunk[offset * entry_size_]); + // Skip if backup set belongs to chunk i + if (j < i * backup_set_num_per_chunk_ || + j >= (i + 1) * backup_set_num_per_chunk_) { + const auto tmp = PRFEvalWithLongKeyAndTag( + long_key_, local_backup_sets_[j].tag, i); + const auto offset = tmp & (chunk_size_ - 1); + local_backup_sets_[j].parity_after_puncture.XorFromRaw( + &db_chunk[offset * entry_size_]); + } } }); } @@ -169,18 +174,6 @@ void QueryServiceClient::FetchFullDB() { } } - // For the i-th group of backups, leave the i-th chunk as blank - // To do that, we just XOR the i-th chunk's value again - for (uint64_t k = 0; k < backup_set_num_per_chunk_; k++) { - const auto tag = local_backup_set_groups_[i].sets[k].get().tag; - const auto tmp = PRFEvalWithLongKeyAndTag(long_key_, tag, i); - const auto offset = tmp & (chunk_size_ - 1); - local_backup_set_groups_[i] - .sets[k] - .get() - .parity_after_puncture.XorFromRaw(&db_chunk[offset * entry_size_]); - } - // Store the replacement yacl::crypto::Prg prg(yacl::crypto::SecureRandU64()); for (uint64_t k = 0; k < backup_set_num_per_chunk_; k++) { diff --git a/experimental/pir/piano/client.h b/experiment/pir/piano/client.h similarity index 92% rename from experimental/pir/piano/client.h rename to experiment/pir/piano/client.h index 83688192..e393b3ef 100644 --- a/experimental/pir/piano/client.h +++ b/experiment/pir/piano/client.h @@ -5,13 +5,20 @@ #include #include -#include "experimental/pir/piano/serialize.h" -#include "experimental/pir/piano/util.h" +#include "experiment/pir/piano/serialize.h" +#include "experiment/pir/piano/util.h" #include "yacl/crypto/tools/prg.h" #include "yacl/link/context.h" namespace pir::piano { +// Statistical security parameter as log base 2 +constexpr uint64_t kStatisticalSecurityLog2 = 40; + +// Natural logarithm of the security parameter, ln(x) = log2(x) * ln(2) +constexpr double kStatisticalSecurityLn = + std::log(2) * static_cast(kStatisticalSecurityLog2); + struct LocalSet { /** * @brief Represents a compressed set in the primary table. @@ -108,12 +115,6 @@ class QueryServiceClient { uint64_t GetTotalQueryNumber() const { return total_query_num_; }; private: - // Statistical security parameter, representing log base 2 - const uint64_t log2_k_ = 40; - // Converts log2_k_ from base-2 to natural logarithm using the change of base - // formula - const double natural_log_k_ = std::log(2) * static_cast(log2_k_); - std::shared_ptr context_; uint64_t total_query_num_{}; uint64_t entry_num_{}; diff --git a/experimental/pir/piano/piano.proto b/experiment/pir/piano/piano.proto similarity index 100% rename from experimental/pir/piano/piano.proto rename to experiment/pir/piano/piano.proto diff --git a/experimental/pir/piano/piano_benchmark.cc b/experiment/pir/piano/piano_benchmark.cc similarity index 96% rename from experimental/pir/piano/piano_benchmark.cc rename to experiment/pir/piano/piano_benchmark.cc index 40a10063..3fafe241 100644 --- a/experimental/pir/piano/piano_benchmark.cc +++ b/experiment/pir/piano/piano_benchmark.cc @@ -6,9 +6,9 @@ #include #include "benchmark/benchmark.h" -#include "experimental/pir/piano/client.h" -#include "experimental/pir/piano/server.h" -#include "experimental/pir/piano/util.h" +#include "experiment/pir/piano/client.h" +#include "experiment/pir/piano/server.h" +#include "experiment/pir/piano/util.h" #include "yacl/link/context.h" #include "yacl/link/test_util.h" diff --git a/experimental/pir/piano/piano_test.cc b/experiment/pir/piano/piano_test.cc similarity index 97% rename from experimental/pir/piano/piano_test.cc rename to experiment/pir/piano/piano_test.cc index fcefffe8..7311e182 100644 --- a/experimental/pir/piano/piano_test.cc +++ b/experiment/pir/piano/piano_test.cc @@ -8,9 +8,9 @@ #include #include -#include "experimental/pir/piano/client.h" -#include "experimental/pir/piano/server.h" -#include "experimental/pir/piano/util.h" +#include "experiment/pir/piano/client.h" +#include "experiment/pir/piano/server.h" +#include "experiment/pir/piano/util.h" #include "gtest/gtest.h" #include "yacl/link/context.h" #include "yacl/link/test_util.h" diff --git a/experimental/pir/piano/serialize.h b/experiment/pir/piano/serialize.h similarity index 96% rename from experimental/pir/piano/serialize.h rename to experiment/pir/piano/serialize.h index ea0ce2c9..0a1ce40f 100644 --- a/experimental/pir/piano/serialize.h +++ b/experiment/pir/piano/serialize.h @@ -2,10 +2,10 @@ #include -#include "experimental/pir/piano/util.h" +#include "experiment/pir/piano/util.h" #include "yacl/base/buffer.h" -#include "experimental/pir/piano/piano.pb.h" +#include "experiment/pir/piano/piano.pb.h" namespace pir::piano { diff --git a/experimental/pir/piano/server.cc b/experiment/pir/piano/server.cc similarity index 98% rename from experimental/pir/piano/server.cc rename to experiment/pir/piano/server.cc index 6a10d300..3e93e13c 100644 --- a/experimental/pir/piano/server.cc +++ b/experiment/pir/piano/server.cc @@ -1,4 +1,4 @@ -#include "experimental/pir/piano/server.h" +#include "experiment/pir/piano/server.h" namespace pir::piano { diff --git a/experimental/pir/piano/server.h b/experiment/pir/piano/server.h similarity index 95% rename from experimental/pir/piano/server.h rename to experiment/pir/piano/server.h index dc3e1fa6..4af0b6d0 100644 --- a/experimental/pir/piano/server.h +++ b/experiment/pir/piano/server.h @@ -7,8 +7,8 @@ #include #include -#include "experimental/pir/piano/serialize.h" -#include "experimental/pir/piano/util.h" +#include "experiment/pir/piano/serialize.h" +#include "experiment/pir/piano/util.h" #include "yacl/link/context.h" namespace pir::piano { diff --git a/experimental/pir/piano/util.cc b/experiment/pir/piano/util.cc similarity index 98% rename from experimental/pir/piano/util.cc rename to experiment/pir/piano/util.cc index c9dc302f..e4813999 100644 --- a/experimental/pir/piano/util.cc +++ b/experiment/pir/piano/util.cc @@ -1,4 +1,4 @@ -#include "experimental/pir/piano/util.h" +#include "experiment/pir/piano/util.h" namespace pir::piano { diff --git a/experimental/pir/piano/util.h b/experiment/pir/piano/util.h similarity index 100% rename from experimental/pir/piano/util.h rename to experiment/pir/piano/util.h