From 4157737265d391b13129aadbd00c22f795a98073 Mon Sep 17 00:00:00 2001 From: huocun Date: Mon, 18 Nov 2024 12:01:15 +0800 Subject: [PATCH 1/3] repo-sync-2024-11-18T12:01:08+0800 --- README.md | 4 + bazel/psi.bzl | 6 +- bazel/repositories.bzl | 72 +++--- benchmark/.env | 4 - benchmark/.gitignore | 2 - benchmark/Makefile | 39 ---- benchmark/docker-compose/.env | 19 -- .../config/rr22_sender_recovery.json | 43 ---- benchmark/docker-compose/docker-compose.yml | 36 --- benchmark/docker-compose/setup_wan.sh | 5 - benchmark/plot_csv_data.py | 104 --------- benchmark/stats.py | 74 ------- docker/build.sh | 2 +- docker/entry.sh | 5 +- examples/pir/config/apsi_receiver.json | 3 +- examples/pir/config/apsi_receiver_bucket.json | 3 +- experiment/pir/pps/server.h | 7 +- psi/apsi_wrapper/api/BUILD.bazel | 22 ++ psi/apsi_wrapper/api/sender.h | 19 ++ psi/apsi_wrapper/api/wrapper_util.cc | 2 +- psi/apsi_wrapper/cli/entry.cc | 208 +++++++++--------- psi/apsi_wrapper/cli/entry.h | 1 + psi/apsi_wrapper/utils/BUILD.bazel | 2 + psi/apsi_wrapper/utils/csv_reader.cc | 114 ++++++---- psi/apsi_wrapper/utils/csv_reader.h | 11 +- psi/apsi_wrapper/utils/group_db.cc | 61 ++++- psi/apsi_wrapper/utils/group_db.h | 12 + psi/ecdh/ub_psi/ecdh_oprf_psi.cc | 4 +- psi/interface.cc | 14 ++ psi/kwpir/BUILD.bazel | 9 +- psi/kwpir/client/BUILD.bazel | 31 +++ psi/kwpir/client/kw_pir_client.cc | 18 ++ psi/kwpir/client/kw_pir_client.h | 17 ++ psi/kwpir/common/BUILD.bazel | 27 +++ psi/kwpir/common/input_provider.cc | 46 ++++ psi/kwpir/common/input_provider.h | 42 ++++ psi/kwpir/server/BUILD.bazel | 58 +++++ psi/kwpir/server/apsi_kw_pir_server.cc | 67 ++++++ psi/kwpir/server/apsi_kw_pir_server.h | 42 ++++ psi/kwpir/server/kw_pir_server.cc | 64 ++++++ psi/kwpir/server/kw_pir_server.h | 59 +++++ psi/kwpir/server/kw_seal_pir_server.cc | 27 +++ psi/kwpir/server/kw_seal_pir_server.h | 31 +++ psi/launch.cc | 3 + psi/proto/BUILD.bazel | 66 +++++- psi/proto/apsi_wrapper.proto | 134 +++++++++++ psi/proto/common.proto | 79 +++++++ psi/proto/entry.proto | 2 +- psi/proto/kw_pir_client_service.proto | 48 ++++ psi/proto/kw_pir_server_service.proto | 49 +++++ psi/proto/kw_seal_pir.proto | 28 +++ psi/proto/pir.proto | 153 +++++-------- psi/rr22/rr22_oprf.h | 1 + psi/rr22/rr22_psi.h | 2 - psi/sealpir/BUILD.bazel | 4 +- psi/utils/BUILD.bazel | 1 + psi/utils/arrow_csv_batch_provider.cc | 19 +- psi/utils/arrow_helper.h | 2 + psi/utils/csv_header_parser_test.cc | 4 +- psi/utils/join_processor.cc | 18 +- psi/utils/table_utils.cc | 10 - psi/utils/table_utils.h | 2 - 62 files changed, 1403 insertions(+), 658 deletions(-) delete mode 100644 benchmark/.env delete mode 100644 benchmark/.gitignore delete mode 100644 benchmark/Makefile delete mode 100644 benchmark/docker-compose/.env delete mode 100644 benchmark/docker-compose/config/rr22_sender_recovery.json delete mode 100644 benchmark/docker-compose/docker-compose.yml delete mode 100644 benchmark/docker-compose/setup_wan.sh delete mode 100644 benchmark/plot_csv_data.py delete mode 100644 benchmark/stats.py create mode 100644 psi/kwpir/client/BUILD.bazel create mode 100644 psi/kwpir/client/kw_pir_client.cc create mode 100644 psi/kwpir/client/kw_pir_client.h create mode 100644 psi/kwpir/common/BUILD.bazel create mode 100644 psi/kwpir/common/input_provider.cc create mode 100644 psi/kwpir/common/input_provider.h create mode 100644 psi/kwpir/server/BUILD.bazel create mode 100644 psi/kwpir/server/apsi_kw_pir_server.cc create mode 100644 psi/kwpir/server/apsi_kw_pir_server.h create mode 100644 psi/kwpir/server/kw_pir_server.cc create mode 100644 psi/kwpir/server/kw_pir_server.h create mode 100644 psi/kwpir/server/kw_seal_pir_server.cc create mode 100644 psi/kwpir/server/kw_seal_pir_server.h create mode 100644 psi/proto/apsi_wrapper.proto create mode 100644 psi/proto/common.proto create mode 100644 psi/proto/kw_pir_client_service.proto create mode 100644 psi/proto/kw_pir_server_service.proto create mode 100644 psi/proto/kw_seal_pir.proto diff --git a/README.md b/README.md index f70d7b74..002c0f9c 100644 --- a/README.md +++ b/README.md @@ -205,6 +205,10 @@ chmod +x traceconv ``` 4. Open chrome://tracing in your chrome and load JSON file. + + + + ## PSI V2 Benchamrk Please refer to [PSI V2 Benchmark](docs/user_guide/psi_v2_benchmark.md) diff --git a/bazel/psi.bzl b/bazel/psi.bzl index 50f01510..dd9958e8 100644 --- a/bazel/psi.bzl +++ b/bazel/psi.bzl @@ -30,9 +30,9 @@ FAST_FLAGS = ["-O1"] def _psi_copts(): return select({ - "@psi//bazel:psi_build_as_release": RELEASE_FLAGS, - "@psi//bazel:psi_build_as_debug": DEBUG_FLAGS, - "@psi//bazel:psi_build_as_fast": FAST_FLAGS, + "//bazel:psi_build_as_release": RELEASE_FLAGS, + "//bazel:psi_build_as_debug": DEBUG_FLAGS, + "//bazel:psi_build_as_fast": FAST_FLAGS, "//conditions:default": FAST_FLAGS, }) + WARNING_FLAGS diff --git a/bazel/repositories.bzl b/bazel/repositories.bzl index a2e01d58..d9756a99 100644 --- a/bazel/repositories.bzl +++ b/bazel/repositories.bzl @@ -72,7 +72,7 @@ def _com_github_facebook_zstd(): maybe( http_archive, name = "com_github_facebook_zstd", - build_file = "@psi//bazel:zstd.BUILD", + build_file = "//bazel:zstd.BUILD", strip_prefix = "zstd-1.5.5", sha256 = "98e9c3d949d1b924e28e01eccb7deed865eefebf25c2f21c702e5cd5b63b85e1", type = ".tar.gz", @@ -93,7 +93,7 @@ def _upb(): ], patch_args = ["-p1"], patches = [ - "@psi//bazel/patches:upb.patch", + "//bazel/patches:upb.patch", ], ) @@ -106,14 +106,14 @@ def _com_github_emptoolkit_emp_tool(): type = "tar.gz", patch_args = ["-p1"], patches = [ - "@psi//bazel/patches:emp-tool.patch", - "@psi//bazel/patches:emp-tool-cmake.patch", - "@psi//bazel/patches:emp-tool-sse2neon.patch", + "//bazel/patches:emp-tool.patch", + "//bazel/patches:emp-tool-cmake.patch", + "//bazel/patches:emp-tool-sse2neon.patch", ], urls = [ "https://github.com/emp-toolkit/emp-tool/archive/refs/tags/0.2.5.tar.gz", ], - build_file = "@psi//bazel:emp-tool.BUILD", + build_file = "//bazel:emp-tool.BUILD", ) def _com_github_intel_ipp(): @@ -122,10 +122,10 @@ def _com_github_intel_ipp(): name = "com_github_intel_ipp", sha256 = "d70f42832337775edb022ca8ac1ac418f272e791ec147778ef7942aede414cdc", strip_prefix = "cryptography-primitives-ippcp_2021.8", - build_file = "@psi//bazel:ipp.BUILD", + build_file = "//bazel:ipp.BUILD", patch_args = ["-p1"], patches = [ - "@psi//bazel/patches:ippcp.patch", + "//bazel/patches:ippcp.patch", ], urls = [ "https://github.com/intel/cryptography-primitives/archive/refs/tags/ippcp_2021.8.tar.gz", @@ -140,11 +140,11 @@ def _com_github_microsoft_seal(): strip_prefix = "SEAL-4.1.1", type = "tar.gz", patch_args = ["-p1"], - patches = ["@psi//bazel/patches:seal.patch"], + patches = ["//bazel/patches:seal.patch"], urls = [ "https://github.com/microsoft/SEAL/archive/refs/tags/v4.1.1.tar.gz", ], - build_file = "@psi//bazel:seal.BUILD", + build_file = "//bazel:seal.BUILD", ) def _com_github_microsoft_apsi(): @@ -156,11 +156,11 @@ def _com_github_microsoft_apsi(): urls = [ "https://github.com/microsoft/APSI/archive/refs/tags/v0.11.0.tar.gz", ], - build_file = "@psi//bazel:microsoft_apsi.BUILD", + build_file = "//bazel:microsoft_apsi.BUILD", patch_args = ["-p1"], patches = [ - "@psi//bazel/patches:apsi.patch", - "@psi//bazel/patches:apsi-fourq.patch", + "//bazel/patches:apsi.patch", + "//bazel/patches:apsi-fourq.patch", ], patch_cmds = [ "rm -rf common/apsi/fourq", @@ -177,7 +177,7 @@ def _com_github_microsoft_gsl(): urls = [ "https://github.com/microsoft/GSL/archive/refs/tags/v4.0.0.tar.gz", ], - build_file = "@psi//bazel:microsoft_gsl.BUILD", + build_file = "//bazel:microsoft_gsl.BUILD", ) def _com_github_microsoft_kuku(): @@ -190,7 +190,7 @@ def _com_github_microsoft_kuku(): urls = [ "https://github.com/microsoft/Kuku/archive/refs/tags/v2.1.0.tar.gz", ], - build_file = "@psi//bazel:microsoft_kuku.BUILD", + build_file = "//bazel:microsoft_kuku.BUILD", ) def _com_google_flatbuffers(): @@ -208,7 +208,7 @@ def _com_google_flatbuffers(): "rm grpc/src/compiler/BUILD.bazel", "rm src/BUILD.bazel", ], - build_file = "@psi//bazel:flatbuffers.BUILD", + build_file = "//bazel:flatbuffers.BUILD", ) def _org_apache_arrow(): @@ -220,7 +220,7 @@ def _org_apache_arrow(): ], sha256 = "2852b21f93ee84185a9d838809c9a9c41bf6deca741bed1744e0fdba6cc19e3f", strip_prefix = "arrow-apache-arrow-10.0.0", - build_file = "@psi//bazel:arrow.BUILD", + build_file = "//bazel:arrow.BUILD", ) def _com_github_grpc_grpc(): @@ -231,7 +231,7 @@ def _com_github_grpc_grpc(): strip_prefix = "grpc-1.51.0", type = "tar.gz", patch_args = ["-p1"], - patches = ["@psi//bazel/patches:grpc.patch"], + patches = ["//bazel/patches:grpc.patch"], urls = [ "https://github.com/grpc/grpc/archive/refs/tags/v1.51.0.tar.gz", ], @@ -246,7 +246,7 @@ def _com_github_nelhage_rules_boost(): sha256 = "a7c42df432fae9db0587ff778d84f9dc46519d67a984eff8c79ae35e45f277c1", strip_prefix = "rules_boost-%s" % RULES_BOOST_COMMIT, patch_args = ["-p1"], - patches = ["@psi//bazel/patches:boost.patch"], + patches = ["//bazel/patches:boost.patch"], urls = [ "https://github.com/nelhage/rules_boost/archive/%s.tar.gz" % RULES_BOOST_COMMIT, ], @@ -261,7 +261,7 @@ def _com_github_tencent_rapidjson(): ], sha256 = "bf7ced29704a1e696fbccf2a2b4ea068e7774fa37f6d7dd4039d0787f8bed98e", strip_prefix = "rapidjson-1.1.0", - build_file = "@psi//bazel:rapidjson.BUILD", + build_file = "//bazel:rapidjson.BUILD", ) def _com_github_xtensor_xsimd(): @@ -274,14 +274,14 @@ def _com_github_xtensor_xsimd(): sha256 = "d52551360d37709675237d2a0418e28f70995b5b7cdad7c674626bcfbbf48328", type = "tar.gz", strip_prefix = "xsimd-8.1.0", - build_file = "@psi//bazel:xsimd.BUILD", + build_file = "//bazel:xsimd.BUILD", ) def _brotli(): maybe( http_archive, name = "brotli", - build_file = "@psi//bazel:brotli.BUILD", + build_file = "//bazel:brotli.BUILD", sha256 = "e720a6ca29428b803f4ad165371771f5398faba397edf6778837a18599ea13ff", strip_prefix = "brotli-1.1.0", urls = [ @@ -299,14 +299,14 @@ def _com_github_lz4_lz4(): sha256 = "030644df4611007ff7dc962d981f390361e6c97a34e5cbc393ddfbe019ffe2c1", type = "tar.gz", strip_prefix = "lz4-1.9.3", - build_file = "@psi//bazel:lz4.BUILD", + build_file = "//bazel:lz4.BUILD", ) def _org_apache_thrift(): maybe( http_archive, name = "org_apache_thrift", - build_file = "@psi//bazel:thrift.BUILD", + build_file = "//bazel:thrift.BUILD", sha256 = "31e46de96a7b36b8b8a457cecd2ee8266f81a83f8e238a9d324d8c6f42a717bc", strip_prefix = "thrift-0.21.0", urls = [ @@ -320,7 +320,7 @@ def _com_google_double_conversion(): name = "com_google_double_conversion", sha256 = "04ec44461850abbf33824da84978043b22554896b552c5fd11a9c5ae4b4d296e", strip_prefix = "double-conversion-3.3.0", - build_file = "@psi//bazel:double-conversion.BUILD", + build_file = "//bazel:double-conversion.BUILD", urls = [ "https://github.com/google/double-conversion/archive/refs/tags/v3.3.0.tar.gz", ], @@ -330,7 +330,7 @@ def _bzip2(): maybe( http_archive, name = "bzip2", - build_file = "@psi//bazel:bzip2.BUILD", + build_file = "//bazel:bzip2.BUILD", sha256 = "ab5a03176ee106d3f0fa90e381da478ddae405918153cca248e682cd0c4a2269", strip_prefix = "bzip2-1.0.8", urls = [ @@ -347,7 +347,7 @@ def _com_github_google_snappy(): ], sha256 = "75c1fbb3d618dd3a0483bff0e26d0a92b495bbe5059c8b4f1c962b478b6e06e7", strip_prefix = "snappy-1.1.9", - build_file = "@psi//bazel:snappy.BUILD", + build_file = "//bazel:snappy.BUILD", ) def _com_github_google_perfetto(): @@ -360,8 +360,8 @@ def _com_github_google_perfetto(): sha256 = "4c8fe8a609fcc77ca653ec85f387ab6c3a048fcd8df9275a1aa8087984b89db8", strip_prefix = "perfetto-41.0", patch_args = ["-p1"], - patches = ["@psi//bazel/patches:perfetto.patch"], - build_file = "@psi//bazel:perfetto.BUILD", + patches = ["//bazel/patches:perfetto.patch"], + build_file = "//bazel:perfetto.BUILD", ) def _com_github_floodyberry_curve25519_donna(): @@ -371,7 +371,7 @@ def _com_github_floodyberry_curve25519_donna(): strip_prefix = "curve25519-donna-2fe66b65ea1acb788024f40a3373b8b3e6f4bbb2", sha256 = "ba57d538c241ad30ff85f49102ab2c8dd996148456ed238a8c319f263b7b149a", type = "tar.gz", - build_file = "@psi//bazel:curve25519-donna.BUILD", + build_file = "//bazel:curve25519-donna.BUILD", urls = [ "https://github.com/floodyberry/curve25519-donna/archive/2fe66b65ea1acb788024f40a3373b8b3e6f4bbb2.tar.gz", ], @@ -386,7 +386,7 @@ def _com_github_ridiculousfish_libdivide(): ], sha256 = "01ffdf90bc475e42170741d381eb9cfb631d9d7ddac7337368bcd80df8c98356", strip_prefix = "libdivide-5.0", - build_file = "@psi//bazel:libdivide.BUILD", + build_file = "//bazel:libdivide.BUILD", ) def _com_github_sparsehash_sparsehash(): @@ -398,14 +398,14 @@ def _com_github_sparsehash_sparsehash(): ], sha256 = "8cd1a95827dfd8270927894eb77f62b4087735cbede953884647f16c521c7e58", strip_prefix = "sparsehash-sparsehash-2.0.4", - build_file = "@psi//bazel:sparsehash.BUILD", + build_file = "//bazel:sparsehash.BUILD", ) def _com_github_zeromq_cppzmq(): maybe( http_archive, name = "com_github_zeromq_cppzmq", - build_file = "@psi//bazel:cppzmq.BUILD", + build_file = "//bazel:cppzmq.BUILD", strip_prefix = "cppzmq-4.10.0", sha256 = "c81c81bba8a7644c84932225f018b5088743a22999c6d82a2b5f5cd1e6942b74", type = ".tar.gz", @@ -418,7 +418,7 @@ def _com_github_zeromq_libzmq(): maybe( http_archive, name = "com_github_zeromq_libzmq", - build_file = "@psi//bazel:libzmq.BUILD", + build_file = "//bazel:libzmq.BUILD", strip_prefix = "libzmq-4.3.5", sha256 = "6c972d1e6a91a0ecd79c3236f04cf0126f2f4dfbbad407d72b4606a7ba93f9c6", type = ".tar.gz", @@ -431,7 +431,7 @@ def _com_github_log4cplus_log4cplus(): maybe( http_archive, name = "com_github_log4cplus_log4cplus", - build_file = "@psi//bazel:log4cplus.BUILD", + build_file = "//bazel:log4cplus.BUILD", strip_prefix = "log4cplus-2.1.1", sha256 = "42dc435928917fd2f847046c4a0c6086b2af23664d198c7fc1b982c0bfe600c1", type = ".tar.gz", @@ -444,7 +444,7 @@ def _com_github_open_source_parsers_jsoncpp(): maybe( http_archive, name = "com_github_open_source_parsers_jsoncpp", - build_file = "@psi//bazel:jsoncpp.BUILD", + build_file = "//bazel:jsoncpp.BUILD", strip_prefix = "jsoncpp-1.9.6", sha256 = "f93b6dd7ce796b13d02c108bc9f79812245a82e577581c4c9aabe57075c90ea2", type = ".tar.gz", diff --git a/benchmark/.env b/benchmark/.env deleted file mode 100644 index ddceac6a..00000000 --- a/benchmark/.env +++ /dev/null @@ -1,4 +0,0 @@ -RECEIVER_ITEM_CNT=1e8 -SENDER_ITEM_CNT=1e8 -INTERSECTION_CNT=8e6 -ID_CNT=2 diff --git a/benchmark/.gitignore b/benchmark/.gitignore deleted file mode 100644 index c996948d..00000000 --- a/benchmark/.gitignore +++ /dev/null @@ -1,2 +0,0 @@ -docker-compose/tmp_data/* -docker-compose/logs/* \ No newline at end of file diff --git a/benchmark/Makefile b/benchmark/Makefile deleted file mode 100644 index 50569ce2..00000000 --- a/benchmark/Makefile +++ /dev/null @@ -1,39 +0,0 @@ -include docker-compose/.env -include .env - -USER := $(shell whoami) -DOCKER_PROJ_NAME := ${DOCKER_PROJ_PREFIX}_${USER} - - -.PHONY: start-docker clean mock-data analysis - -default: all - -clean: - @[ -d "$(PWD)/docker-compose/logs" ] && rm -rf $(PWD)/docker-compose/logs && echo "Directory removed." || echo "Directory not exists." - @[ -d "$(PWD)/docker-compose/tmp_data" ] && rm -rf $(PWD)/docker-compose/tmp_data && echo "Directory removed." || echo "Directory not exists." - @(docker compose -p ${DOCKER_PROJ_NAME} down) - -mock-data: - [ ! -d "$(PWD)/docker-compose/tmp_data" ] && mkdir -p "$(PWD)/docker-compose/tmp_data" && echo "Directory created." || echo "Directory already exists." - @(python $(PWD)/../examples/psi/generate_psi_data.py --receiver_item_cnt ${RECEIVER_ITEM_CNT} \ - --sender_item_cnt ${SENDER_ITEM_CNT} --intersection_cnt ${INTERSECTION_CNT} --id_cnt ${ID_CNT} \ - --receiver_path docker-compose/tmp_data/receiver_input.csv --sender_path docker-compose/tmp_data/sender_input.csv \ - --intersection_path docker-compose/tmp_data/intersection.csv) - -all: clean mock-data start-docker analysis - @echo "well done!" - -analysis: - @[ ! -d "$(PWD)/docker-compose/logs" ] && mkdir -p "$(PWD)/docker-compose/logs" && echo "Directory created." || echo "Directory already exists." - @[ ! -d "$(PWD)/docker-compose/logs/receiver" ] && mkdir -p "$(PWD)/docker-compose/logs/receiver" && echo "Directory created." || echo "Directory already exists." - @[ ! -d "$(PWD)/docker-compose/logs/sender" ] && mkdir -p "$(PWD)/docker-compose/logs/sender" && echo "Directory created." || echo "Directory already exists." - nohup python $(PWD)/stats.py $(DOCKER_PROJ_NAME)-psi-sender-1 $(PWD)/docker-compose/logs/sender/stats.csv >/dev/null 2>&1 & - python $(PWD)/stats.py $(DOCKER_PROJ_NAME)-psi-receiver-1 $(PWD)/docker-compose/logs/receiver/stats.csv - python $(PWD)/plot_csv_data.py $(PWD)/docker-compose/logs/receiver/stats.csv $(PWD)/docker-compose/logs/receiver - python $(PWD)/plot_csv_data.py $(PWD)/docker-compose/logs/sender/stats.csv $(PWD)/docker-compose/logs/sender - docker logs $(DOCKER_PROJ_NAME)-psi-sender-1 > $(PWD)/docker-compose/logs/sender/psi.log - docker logs $(DOCKER_PROJ_NAME)-psi-receiver-1 > $(PWD)/docker-compose/logs/receiver/psi.log - -start-docker: - @(cd $(PWD)/docker-compose && docker compose -p ${DOCKER_PROJ_NAME} up -d) diff --git a/benchmark/docker-compose/.env b/benchmark/docker-compose/.env deleted file mode 100644 index c9edf139..00000000 --- a/benchmark/docker-compose/.env +++ /dev/null @@ -1,19 +0,0 @@ -# OPENSOURCE-CLEANUP GSUB psi:latest secretflow/psi:latest -# docker env -IMAGE_WITH_TAG=secretflow/psi-anolis8:0.4.2b0 - -# network env -# LATENCY=10ms -# BANDWIDTH=100mbit - -# cpu/memory -ALICE_CPU_LIMIT=64 -BOB_CPU_LIMIT=64 -ALICE_MEMORY_LIMIT=8G -BOB_MEMORY_LIMIT=8G - -# other -DOCKER_PROJ_PREFIX=psi_bench - -SENDER_RUN="/root/main --config /home/admin/psi/conf/rr22_sender_recovery.json" -RECEIVER_RUN="/root/main --config /home/admin/psi/conf/rr22_receiver_recovery.json" diff --git a/benchmark/docker-compose/config/rr22_sender_recovery.json b/benchmark/docker-compose/config/rr22_sender_recovery.json deleted file mode 100644 index a3f18abf..00000000 --- a/benchmark/docker-compose/config/rr22_sender_recovery.json +++ /dev/null @@ -1,43 +0,0 @@ -{ - "psi_config": { - "protocol_config": { - "protocol": "PROTOCOL_RR22", - "role": "ROLE_SENDER", - "broadcast_result": false, - "rr22_config": { - "bucket_size": 1000000 - } - }, - "input_config": { - "type": "IO_TYPE_FILE_CSV", - "path": "/data/sender_input.csv" - }, - "output_config": { - "type": "IO_TYPE_FILE_CSV", - "path": "/tmp/rr22_sender_recovery_output.csv" - }, - "keys": ["id_0", "id_1"], - "debug_options": { - "trace_path": "/tmp/rr22_sender_recovery.trace" - }, - "skip_duplicates_check": true, - "disable_alignment": true, - "recovery_config": { - "enabled": true, - "folder": "/tmp/rr22_sender_cache" - } - }, - "link_config": { - "parties": [ - { - "id": "receiver", - "host": "psi-receiver:5300" - }, - { - "id": "sender", - "host": "0.0.0.0:5300" - } - ] - }, - "self_link_party": "sender" -} diff --git a/benchmark/docker-compose/docker-compose.yml b/benchmark/docker-compose/docker-compose.yml deleted file mode 100644 index 4e522bf5..00000000 --- a/benchmark/docker-compose/docker-compose.yml +++ /dev/null @@ -1,36 +0,0 @@ -version: '3.8' -services: - psi-sender: - entrypoint: - - bash - - -c - - "bash /data/setup_wan.sh && ${SENDER_RUN}" - image: ${IMAGE_WITH_TAG} - cap_add: - - NET_ADMIN - volumes: - - ./config/:/home/admin/psi/conf/ - - ./tmp_data/:/data - - ./setup_wan.sh:/data/setup_wan.sh - deploy: - resources: - limits: - cpus: '${ALICE_CPU_LIMIT}' - memory: ${ALICE_MEMORY_LIMIT} - psi-receiver: - entrypoint: - - bash - - -c - - "bash /data/setup_wan.sh && ${RECEIVER_RUN}" - image: ${IMAGE_WITH_TAG} - cap_add: - - NET_ADMIN - volumes: - - ./config/:/home/admin/psi/conf/ - - ./tmp_data/:/data - - ./setup_wan.sh:/data/setup_wan.sh - deploy: - resources: - limits: - cpus: '${BOB_CPU_LIMIT}' - memory: ${BOB_MEMORY_LIMIT} diff --git a/benchmark/docker-compose/setup_wan.sh b/benchmark/docker-compose/setup_wan.sh deleted file mode 100644 index 3a34e273..00000000 --- a/benchmark/docker-compose/setup_wan.sh +++ /dev/null @@ -1,5 +0,0 @@ -set -eu - -yum install iproute-tc -y; -tc qdisc add dev eth0 root handle 1: tbf rate 100mbit burst 128kb latency 10ms; -tc qdisc add dev eth0 parent 1:1 handle 10: netem delay 10msec limit 8000 diff --git a/benchmark/plot_csv_data.py b/benchmark/plot_csv_data.py deleted file mode 100644 index 305a0b42..00000000 --- a/benchmark/plot_csv_data.py +++ /dev/null @@ -1,104 +0,0 @@ -# Copyright 2024 Ant Group Co., Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import pandas as pd -import matplotlib.pyplot as plt -import sys -import os - - -def plot_cpu(docker_csv_path, output_path): - df1 = pd.read_csv(docker_csv_path) - - plt.plot(df1["running_time_s"], df1["cpu_percent"], marker="o", linestyle="-", color="b") - max_time_count = 10 - interval = 1 - if len(df1) > max_time_count: - interval = len(df1) // max_time_count - for i, row in df1.iterrows(): - if i % interval == 0: - plt.text( - row["running_time_s"] + 1, - 0, - str(row["time"]), - fontsize=9, - ha="right", - rotation=45, - ) - plt.title("cpu over Time") - plt.xlabel("running time sec") - plt.ylabel("cpu") - plt.grid(True) - - plt.savefig(output_path) - plt.clf() - -def plot_mem(docker_csv_path, output_path): - df1 = pd.read_csv(docker_csv_path) - - plt.plot(df1["running_time_s"], df1["mem_usage_MB"], marker="o", linestyle="-", color="b") - max_time_count = 10 - interval = 1 - if len(df1) > max_time_count: - interval = len(df1) // max_time_count - for i, row in df1.iterrows(): - if i % interval == 0: - plt.text( - row["running_time_s"] + 1, - 0, - str(row["time"]), - fontsize=9, - ha="right", - rotation=45, - ) - plt.title("memory over Time") - plt.xlabel("running time sec") - plt.ylabel("memory MB") - plt.grid(True) - - plt.savefig(output_path) - plt.clf() - -def plot_net(docker_csv_path, output_path): - df1 = pd.read_csv(docker_csv_path) - - plt.plot(df1["running_time_s"], df1["net_tx_kb"], marker="o", linestyle="-", color="b") - plt.plot(df1["running_time_s"], df1["net_rx_kb"], marker="*", linestyle="-", color="y") - max_time_count = 10 - interval = 1 - if len(df1) > max_time_count: - interval = len(df1) // max_time_count - for i, row in df1.iterrows(): - if i % interval == 0: - plt.text( - row["running_time_s"] + 1, - 0, - str(row["time"]), - fontsize=9, - ha="right", - rotation=45, - ) - plt.title("network over Time") - plt.xlabel("running time sec") - plt.ylabel("network kb") - plt.grid(True) - - plt.savefig(output_path) - plt.clf() - - -if __name__ == "__main__": - plot_cpu(sys.argv[1], os.path.join(sys.argv[2], "cpu.png")) - plot_mem(sys.argv[1], os.path.join(sys.argv[2], "mem.png")) - plot_net(sys.argv[1], os.path.join(sys.argv[2], "net.png")) diff --git a/benchmark/stats.py b/benchmark/stats.py deleted file mode 100644 index 8f4bc16b..00000000 --- a/benchmark/stats.py +++ /dev/null @@ -1,74 +0,0 @@ -# Copyright 2024 Ant Group Co., Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import docker -import json -import csv -import sys -import time -from datetime import datetime - -def stream_container_stats(container_name, output_file): - client = docker.from_env() - - try: - container = client.containers.get(container_name) - stats_stream = container.stats(stream=True) - - with open(output_file, 'w', newline='') as csvfile: - fieldnames = ['cpu_percent', 'mem_usage_MB', 'mem_limit_MB', 'net_tx_kb', 'net_rx_kb', 'running_time_s', 'time'] - writer = csv.DictWriter(csvfile, fieldnames=fieldnames) - - writer.writeheader() - prev_net_tx = 0 - prev_net_rx = 0 - prev_cpu_total = 0 - prev_cpu_system = 0 - start_unix_time = int(time.time()) - for stats in stats_stream: - data = json.loads(stats) - running_time_s = int(time.time()) - start_unix_time - cpu_percent = ((data['cpu_stats']['cpu_usage']['total_usage'] - prev_cpu_total) / - (data['cpu_stats']['system_cpu_usage'] - prev_cpu_system)) * 100 - mem_usage = (data['memory_stats']['usage'] - data['memory_stats']['stats']['inactive_file']) / 1024 / 1024 - mem_limit = data['memory_stats']['limit'] / 1024 / 1024 - net_tx = 0 - net_rx = 0 - for key, value in data['networks'].items(): - net_tx += value['tx_bytes'] / 1024 - net_rx += value['rx_bytes'] / 1024 - # skip first five seconds, due to running setting up network - if running_time_s > 5: - writer.writerow({ - 'cpu_percent': cpu_percent, - 'mem_usage_MB': int(mem_usage), - 'mem_limit_MB': int(mem_limit), - 'net_tx_kb': int((net_tx - prev_net_tx) * 8), - 'net_rx_kb': int((net_rx - prev_net_rx) * 8), - 'running_time_s': running_time_s, - 'time': datetime.fromtimestamp(time.time()).strftime('%H:%M:%S') - }) - prev_net_tx = net_tx - prev_net_rx = net_rx - prev_cpu_total = data['cpu_stats']['cpu_usage']['total_usage'] - prev_cpu_system = data['cpu_stats']['system_cpu_usage'] - - except docker.errors.NotFound: - print(f"Container {container_name} not found.") - except Exception as e: - if container.status != 'exited': - print(f"An error occurred: {e} container.status: {container.status}") - -if __name__ == "__main__": - stream_container_stats(sys.argv[1], sys.argv[2]) \ No newline at end of file diff --git a/docker/build.sh b/docker/build.sh index eb88fc64..a9d7bffd 100644 --- a/docker/build.sh +++ b/docker/build.sh @@ -72,7 +72,7 @@ echo -e "Build psi binary ${GREEN}PSI ${PSI_VERSION}${NO_COLOR}..." SCRIPT_DIR="$(realpath $(dirname $0))" if [[ SKIP -eq 0 ]]; then - docker run -it --rm --mount type=bind,source="${SCRIPT_DIR}/../",target=/home/admin/dev/src -w /home/admin/dev --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --cap-add=NET_ADMIN --privileged=true secretflow/release-ci:latest /home/admin/dev/src/docker/entry.sh + docker run -it --rm --mount type=bind,source="${SCRIPT_DIR}/../",target=/home/admin/dev/src -w /home/admin/dev --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --cap-add=NET_ADMIN --privileged=true secretflow/release-ci:1.7 /home/admin/dev/src/docker/entry.sh echo -e "Finish building psi binary ${GREEN}${IMAGE_LITE_TAG}${NO_COLOR}" fi diff --git a/docker/entry.sh b/docker/entry.sh index e53ae0af..fa4c2861 100755 --- a/docker/entry.sh +++ b/docker/entry.sh @@ -7,7 +7,10 @@ cd src_copied conda install -y perl=5.20.3.1 -bazel build psi:main -c opt --config=linux-release --repository_cache=/tmp/bazel_repo_cache + + + +bazel build psi:main -c opt --config=linux-release --repository_cache=/tmp/bazel_repo_cache --remote_timeout=300s --remote_retries=10 chmod 777 bazel-bin/psi/main mkdir -p ../src/docker/linux/amd64 cp bazel-bin/psi/main ../src/docker/linux/amd64 diff --git a/examples/pir/config/apsi_receiver.json b/examples/pir/config/apsi_receiver.json index 2c7f56ba..60863a78 100644 --- a/examples/pir/config/apsi_receiver.json +++ b/examples/pir/config/apsi_receiver.json @@ -4,7 +4,8 @@ "query_file": "/tmp/query.csv", "output_file": "/tmp/result.csv", "params_file": "/tmp/100K-1-16.json", - "log_level": "info" + "log_level": "info", + "query_batch_size": 1 }, "link_config": { "parties": [ diff --git a/examples/pir/config/apsi_receiver_bucket.json b/examples/pir/config/apsi_receiver_bucket.json index 6fc944db..bbfa2bf5 100644 --- a/examples/pir/config/apsi_receiver_bucket.json +++ b/examples/pir/config/apsi_receiver_bucket.json @@ -4,7 +4,8 @@ "output_file": "/tmp/result.csv", "params_file": "/tmp/100K-1-16.json", "experimental_enable_bucketize": true, - "experimental_bucket_cnt": 10000 + "experimental_bucket_cnt": 10000, + "query_batch_size": 1 }, "link_config": { "parties": [ diff --git a/experiment/pir/pps/server.h b/experiment/pir/pps/server.h index cfd3b60a..4e2a2792 100644 --- a/experiment/pir/pps/server.h +++ b/experiment/pir/pps/server.h @@ -22,10 +22,12 @@ namespace pir::pps { class PpsPirServer { public: - PpsPirServer() : pps_(), universe_size_(0) {} + PpsPirServer() : pps_(), universe_size_(0), set_size_(0) {} PpsPirServer(uint64_t universe_size, uint64_t set_size) - : pps_(universe_size, set_size), universe_size_(universe_size) {} + : pps_(universe_size, set_size), + universe_size_(universe_size), + set_size_(set_size) {} void Hint(PIRKey k, std::set& deltas, yacl::dynamic_bitset<>& bits, yacl::dynamic_bitset<>& h); @@ -41,5 +43,6 @@ class PpsPirServer { private: PPS pps_; uint64_t universe_size_; + uint64_t set_size_; }; } // namespace pir::pps diff --git a/psi/apsi_wrapper/api/BUILD.bazel b/psi/apsi_wrapper/api/BUILD.bazel index 02cefded..9bf56d5a 100644 --- a/psi/apsi_wrapper/api/BUILD.bazel +++ b/psi/apsi_wrapper/api/BUILD.bazel @@ -119,3 +119,25 @@ psi_cc_test( "@boost//:uuid", ], ) + +exports_files( + [ + "exported_symbols.lds", + ], + visibility = ["//visibility:private"], +) + +cc_shared_library( + name = "wrapper_shared", + additional_linker_inputs = [ + ":exported_symbols.lds", + ], + shared_lib_name = "wrapper.so", + user_link_flags = [ + "-Wl,--version-script=$(location :exported_symbols.lds)", + ], + deps = [ + ":receiver_c_wrapper", + ":sender_c_wrapper", + ], +) diff --git a/psi/apsi_wrapper/api/sender.h b/psi/apsi_wrapper/api/sender.h index d92c68cf..fe2786e0 100644 --- a/psi/apsi_wrapper/api/sender.h +++ b/psi/apsi_wrapper/api/sender.h @@ -18,6 +18,7 @@ #pragma once #include +#include #include #include "apsi/query.h" @@ -31,6 +32,7 @@ namespace psi::apsi_wrapper::api { class Sender { public: + // TODO(huocun): delete source_file option struct Option { std::string source_file; std::string db_path; @@ -40,11 +42,21 @@ class Sender { bool compress = true; std::string params_file; }; + struct KwPirOption { + std::shared_ptr provider; + std::string db_path; + size_t group_cnt = 1; + size_t num_buckets = 1; + uint32_t nonce_byte_count = 16; + bool compress = true; + std::string params_file; + }; public: Sender(std::string db_path, size_t thread_count = std::thread::hardware_concurrency()) : group_db_(db_path), thread_count_(thread_count) {} + Sender(Option option, size_t thread_count = std::thread::hardware_concurrency()) : group_db_(option.source_file, option.db_path, option.group_cnt, @@ -52,6 +64,13 @@ class Sender { option.params_file, option.compress), thread_count_(thread_count) {} + Sender(KwPirOption option, + size_t thread_count = std::thread::hardware_concurrency()) + : group_db_(option.provider, option.db_path, option.group_cnt, + option.num_buckets, option.nonce_byte_count, + option.params_file, option.compress), + thread_count_(thread_count) {} + void SetThreadCount(size_t threads); // Save sender db as file. diff --git a/psi/apsi_wrapper/api/wrapper_util.cc b/psi/apsi_wrapper/api/wrapper_util.cc index 6b19678f..3d7eeeac 100644 --- a/psi/apsi_wrapper/api/wrapper_util.cc +++ b/psi/apsi_wrapper/api/wrapper_util.cc @@ -17,7 +17,7 @@ #include "psi/apsi_wrapper/api/wrapper_util.h" -#include +#include #include "yacl/base/exception.h" diff --git a/psi/apsi_wrapper/cli/entry.cc b/psi/apsi_wrapper/cli/entry.cc index 1793fdde..fac053f7 100644 --- a/psi/apsi_wrapper/cli/entry.cc +++ b/psi/apsi_wrapper/cli/entry.cc @@ -166,54 +166,120 @@ int RunReceiver(const ReceiverOptions &options, psi::apsi_wrapper::Receiver receiver(*params); - auto [query_data, orig_items] = - psi::apsi_wrapper::load_db_with_orig_items(options.query_file); - - if (!query_data || - !holds_alternative(*query_data)) { - // Failed to read query file - SPDLOG_ERROR("Failed to read query file: terminating"); - return -1; - } + psi::apsi_wrapper::DBData db_data; + std::vector orig_items; + + psi::apsi_wrapper::ApsiCsvReader reader(options.query_file, + options.query_batch_size); + bool append_to_outfile = false; - auto &items = get(*query_data); + while (true) { + tie(db_data, orig_items) = reader.read_batch(); + if (orig_items.empty()) { + break; + } - if (options.experimental_enable_bucketize) { - std::unordered_map< - size_t, std::pair, std::vector>> - bucket_item_map; + if (!holds_alternative(db_data)) { + // Failed to read query file + SPDLOG_ERROR("Failed to read query file: terminating"); + return -1; + } - for (size_t i = 0; i < orig_items.size(); i++) { - int bucket_idx = std::hash()(orig_items[i]) % - options.experimental_bucket_cnt; + auto &items = get(db_data); - if (bucket_item_map.find(bucket_idx) == bucket_item_map.end()) { - bucket_item_map[bucket_idx] = std::make_pair( - std::vector<::apsi::Item>(), std::vector()); + if (options.experimental_enable_bucketize) { + std::unordered_map, + std::vector>> + bucket_item_map; + + for (size_t i = 0; i < orig_items.size(); i++) { + int bucket_idx = std::hash()(orig_items[i]) % + options.experimental_bucket_cnt; + + if (bucket_item_map.find(bucket_idx) == bucket_item_map.end()) { + bucket_item_map[bucket_idx] = std::make_pair( + std::vector<::apsi::Item>(), std::vector()); + } + + bucket_item_map[bucket_idx].first.emplace_back(items[i]); + bucket_item_map[bucket_idx].second.emplace_back(orig_items[i]); } - bucket_item_map[bucket_idx].first.emplace_back(items[i]); - bucket_item_map[bucket_idx].second.emplace_back(orig_items[i]); - } + size_t total_matches = 0; + + double total_time = 0; + for (const auto &pair : bucket_item_map) { + const size_t bucket_idx = pair.first; + const vector<::apsi::Item> &items_vec = pair.second.first; + auto bucket_start = std::chrono::high_resolution_clock::now(); + SPDLOG_INFO("Start deal with bucket {}", bucket_idx); + + vector<::apsi::HashedItem> oprf_items; + vector<::apsi::LabelKey> label_keys; + try { + SPDLOG_INFO("Sending OPRF request for {} {}", items_vec.size(), + " items "); + tie(oprf_items, label_keys) = + psi::apsi_wrapper::Receiver::RequestOPRF(items_vec, *channel, + bucket_idx); + SPDLOG_INFO("Received OPRF response for {} items", items_vec.size()); + } catch (const exception &ex) { + SPDLOG_WARN("OPRF request failed: {}", ex.what()); + return -1; + } + + vector<::apsi::receiver::MatchRecord> query_result; + try { + SPDLOG_INFO("Sending APSI query"); + query_result = + receiver.request_query(oprf_items, label_keys, *channel, + options.streaming_result, bucket_idx); + SPDLOG_INFO("Received APSI query response"); + } catch (const exception &ex) { + SPDLOG_WARN("Failed sending APSI query: {}", ex.what()); + return -1; + } + + auto bucket_time = + std::chrono::duration_cast( + std::chrono::high_resolution_clock::now() - bucket_start) + .count(); + total_time += bucket_time; + SPDLOG_INFO("End deal with bucket {}, time: {}ms", bucket_idx, + bucket_time); + + int cnt = psi::apsi_wrapper::print_intersection_results( + pair.second.second, items_vec, query_result, options.output_file, + append_to_outfile); + + if (cnt > 0 && !append_to_outfile) { + append_to_outfile = true; + } + + total_matches += cnt; + } + + SPDLOG_INFO("Average bucket time: {}ms/bucket", + total_time / bucket_item_map.size()); - bool append_to_outfile = false; - size_t total_matches = 0; + if (match_cnt != nullptr) { + *match_cnt = total_matches; + } - double total_time = 0; - for (const auto &pair : bucket_item_map) { - const size_t bucket_idx = pair.first; - const vector<::apsi::Item> &items_vec = pair.second.first; - auto bucket_start = std::chrono::high_resolution_clock::now(); - SPDLOG_INFO("Start deal with bucket {}", bucket_idx); + SPDLOG_INFO("Total matches {} items.", total_matches); + print_transmitted_data(*channel); + print_timing_report(::apsi::util::recv_stopwatch); + + } else { + vector<::apsi::Item> items_vec(items.begin(), items.end()); vector<::apsi::HashedItem> oprf_items; vector<::apsi::LabelKey> label_keys; try { - SPDLOG_INFO("Sending OPRF request for {} {}", items_vec.size(), - " items "); - tie(oprf_items, label_keys) = psi::apsi_wrapper::Receiver::RequestOPRF( - items_vec, *channel, bucket_idx); - SPDLOG_INFO("Received OPRF response for {} items", items_vec.size()); + SPDLOG_INFO("Sending OPRF request for {} items ", items_vec.size()); + tie(oprf_items, label_keys) = + psi::apsi_wrapper::Receiver::RequestOPRF(items_vec, *channel); + SPDLOG_INFO("Received OPRF response for {} items", items_vec.size()); } catch (const exception &ex) { SPDLOG_WARN("OPRF request failed: {}", ex.what()); return -1; @@ -222,80 +288,24 @@ int RunReceiver(const ReceiverOptions &options, vector<::apsi::receiver::MatchRecord> query_result; try { SPDLOG_INFO("Sending APSI query"); - query_result = - receiver.request_query(oprf_items, label_keys, *channel, - options.streaming_result, bucket_idx); + query_result = receiver.request_query(oprf_items, label_keys, *channel, + options.streaming_result); SPDLOG_INFO("Received APSI query response"); } catch (const exception &ex) { SPDLOG_WARN("Failed sending APSI query: {}", ex.what()); return -1; } - auto bucket_time = - std::chrono::duration_cast( - std::chrono::high_resolution_clock::now() - bucket_start) - .count(); - total_time += bucket_time; - SPDLOG_INFO("End deal with bucket {}, time: {}ms", bucket_idx, - bucket_time); - int cnt = psi::apsi_wrapper::print_intersection_results( - pair.second.second, items_vec, query_result, options.output_file, - append_to_outfile); + orig_items, items_vec, query_result, options.output_file); - if (cnt > 0 && !append_to_outfile) { - append_to_outfile = true; + if (match_cnt != nullptr) { + *match_cnt = cnt; } - total_matches += cnt; - } - - SPDLOG_INFO("Average bucket time: {}ms/bucket", - total_time / bucket_item_map.size()); - - if (match_cnt != nullptr) { - *match_cnt = total_matches; + print_transmitted_data(*channel); + print_timing_report(::apsi::util::recv_stopwatch); } - - SPDLOG_INFO("Total matches {} items.", total_matches); - - print_transmitted_data(*channel); - print_timing_report(::apsi::util::recv_stopwatch); - - } else { - vector<::apsi::Item> items_vec(items.begin(), items.end()); - vector<::apsi::HashedItem> oprf_items; - vector<::apsi::LabelKey> label_keys; - try { - SPDLOG_INFO("Sending OPRF request for {} items ", items_vec.size()); - tie(oprf_items, label_keys) = - psi::apsi_wrapper::Receiver::RequestOPRF(items_vec, *channel); - SPDLOG_INFO("Received OPRF response for {} items", items_vec.size()); - } catch (const exception &ex) { - SPDLOG_WARN("OPRF request failed: {}", ex.what()); - return -1; - } - - vector<::apsi::receiver::MatchRecord> query_result; - try { - SPDLOG_INFO("Sending APSI query"); - query_result = receiver.request_query(oprf_items, label_keys, *channel, - options.streaming_result); - SPDLOG_INFO("Received APSI query response"); - } catch (const exception &ex) { - SPDLOG_WARN("Failed sending APSI query: {}", ex.what()); - return -1; - } - - int cnt = psi::apsi_wrapper::print_intersection_results( - orig_items, items_vec, query_result, options.output_file); - - if (match_cnt != nullptr) { - *match_cnt = cnt; - } - - print_transmitted_data(*channel); - print_timing_report(::apsi::util::recv_stopwatch); } // NOTE(junfeng): Yacl channel need to send a empty oprf request with max diff --git a/psi/apsi_wrapper/cli/entry.h b/psi/apsi_wrapper/cli/entry.h index cc0b2ff1..aa426814 100644 --- a/psi/apsi_wrapper/cli/entry.h +++ b/psi/apsi_wrapper/cli/entry.h @@ -51,6 +51,7 @@ struct ReceiverOptions { // experimental bucketize bool experimental_enable_bucketize = false; size_t experimental_bucket_cnt; + size_t query_batch_size; }; struct SenderOptions { diff --git a/psi/apsi_wrapper/utils/BUILD.bazel b/psi/apsi_wrapper/utils/BUILD.bazel index dc9fef83..8fde25ba 100644 --- a/psi/apsi_wrapper/utils/BUILD.bazel +++ b/psi/apsi_wrapper/utils/BUILD.bazel @@ -43,6 +43,7 @@ psi_cc_library( hdrs = ["csv_reader.h"], deps = [ ":common", + "//psi/utils:arrow_csv_batch_provider", "//psi/utils:multiplex_disk_cache", "@org_apache_arrow//:arrow", ], @@ -65,6 +66,7 @@ psi_cc_library( ":csv_reader", ":group_db_status_cc_proto", ":sender_db", + "//psi/kwpir/common:input_provider", ], ) diff --git a/psi/apsi_wrapper/utils/csv_reader.cc b/psi/apsi_wrapper/utils/csv_reader.cc index 10f007d2..6c81e87a 100644 --- a/psi/apsi_wrapper/utils/csv_reader.cc +++ b/psi/apsi_wrapper/utils/csv_reader.cc @@ -17,6 +17,8 @@ #include "psi/apsi_wrapper/utils/csv_reader.h" +#include + #include "fmt/format.h" #include "fmt/ranges.h" @@ -27,6 +29,7 @@ #include #include #include +#include #include #include #include @@ -75,7 +78,8 @@ std::vector GetCsvColumnNames(const std::string& filename) { return column_names; } -ApsiCsvReader::ApsiCsvReader(const string& file_name) : file_name_(file_name) { +ApsiCsvReader::ApsiCsvReader(const string& file_name, size_t batch_size) + : file_name_(file_name), batch_size_(batch_size) { throw_if_file_invalid(file_name_); std::vector column_names = GetCsvColumnNames(file_name_); @@ -83,8 +87,17 @@ ApsiCsvReader::ApsiCsvReader(const string& file_name) : file_name_(file_name) { for (auto& col : column_names) { column_types_[col] = arrow::utf8(); } - reader_ = MakeArrowCsvReader(file_name_, column_names); + + std::vector keys{"key"}; + std::vector values{"value"}; + if (column_names.size() == 1) { + batch_provider_ = + std::make_shared(file_name_, keys, batch_size_); + } else { + batch_provider_ = std::make_shared( + file_name_, keys, batch_size_, values); + } } std::shared_ptr ApsiCsvReader::schema() const { @@ -96,56 +109,25 @@ auto ApsiCsvReader::read() -> pair> { DBData result; vector orig_items; - std::shared_ptr batch; - - bool result_type_decided = false; while (true) { // Attempt to read the first RecordBatch - arrow::Status status = reader_->ReadNext(&batch); + auto [batch_db, batch_orig_items] = read_batch(); - if (!status.ok()) { - APSI_LOG_ERROR("Read csv error."); - } - - if (batch == nullptr) { - // Handle end of file + if (batch_orig_items.empty()) { break; } - arrays_.clear(); - - for (int i = 0; i < min(2, batch->num_columns()); i++) { - arrays_.emplace_back( - std::dynamic_pointer_cast(batch->column(i))); - } - - row_cnt += batch->num_rows(); + row_cnt += batch_orig_items.size(); - if (!result_type_decided) { - result_type_decided = true; - - if (batch->num_columns() >= 2) { - result = LabeledData{}; - } else { - result = UnlabeledData{}; - } - } - - for (int i = 0; i < batch->num_rows(); i++) { - orig_items.emplace_back(arrays_[0]->Value(i)); + for (size_t i = 0; i < batch_orig_items.size(); i++) { + orig_items.emplace_back(batch_orig_items[i]); - if (holds_alternative(result)) { + if (holds_alternative(batch_db)) { get(result).emplace_back( - std::string(arrays_[0]->Value(i))); - } else if (holds_alternative(result)) { - Label label; - label.reserve(arrays_[1]->Value(i).size()); - copy(arrays_[1]->Value(i).begin(), arrays_[1]->Value(i).end(), - back_inserter(label)); - - get(result).emplace_back(std::string(arrays_[0]->Value(i)), - label); + get(batch_db)[i]); + } else if (holds_alternative(batch_db)) { + get(result).emplace_back(get(batch_db)[i]); } else { // Something is terribly wrong APSI_LOG_ERROR("Critical error reading data"); @@ -164,6 +146,54 @@ auto ApsiCsvReader::read() -> pair> { return {std::move(result), std::move(orig_items)}; } +auto ApsiCsvReader::read_batch() -> pair> { + int row_cnt = 0; + + DBData result; + vector orig_items; + std::shared_ptr batch; + + // Attempt to read the first RecordBatch + auto [keys, labels] = batch_provider_->ReadNextLabeledBatch(); + + if (keys.empty()) { + // Handle end of file + return {std::move(result), std::move(orig_items)}; + } + + row_cnt += keys.size(); + + if (!labels.empty()) { + result = LabeledData{}; + } else { + result = UnlabeledData{}; + } + + for (size_t i = 0; i < keys.size(); i++) { + orig_items.emplace_back(keys[i]); + + if (labels.empty()) { + get(result).emplace_back(keys[i]); + } else if (holds_alternative(result)) { + get(result).emplace_back( + apsi::Item(keys[i]), apsi::Label(labels[i].begin(), labels[i].end())); + } else { + // Something is terribly wrong + APSI_LOG_ERROR("Critical error reading data"); + throw runtime_error("variant is in bad state"); + } + } + + YACL_ENFORCE(row_cnt != 0, "empty file : {}", file_name_); + YACL_ENFORCE(orig_items.size() == std::unordered_set( + orig_items.begin(), orig_items.end()) + .size(), + "source file {} has duplicated keys", file_name_); + SPDLOG_INFO("Read csv file {}, batch row cnt is {}", file_name_, row_cnt); + + return {std::move(result), std::move(orig_items)}; +} + void ApsiCsvReader::bucketize(size_t bucket_cnt, const std::string& bucket_folder) { if (!std::filesystem::exists(bucket_folder)) { diff --git a/psi/apsi_wrapper/utils/csv_reader.h b/psi/apsi_wrapper/utils/csv_reader.h index 89ab1c25..94cb725f 100644 --- a/psi/apsi_wrapper/utils/csv_reader.h +++ b/psi/apsi_wrapper/utils/csv_reader.h @@ -35,6 +35,7 @@ #include "arrow/io/api.h" #include "psi/apsi_wrapper/utils/common.h" +#include "psi/utils/arrow_csv_batch_provider.h" #include "psi/utils/multiplex_disk_cache.h" namespace psi::apsi_wrapper { @@ -44,10 +45,14 @@ Simple CSV file parser */ class ApsiCsvReader { public: - explicit ApsiCsvReader(const std::string& file_name); + explicit ApsiCsvReader(const std::string& file_name, + size_t batch_size = 1 << 20); std::pair> read(); + auto read_batch() -> std::pair>; + + // TODO(huocun): delete these two function void bucketize(size_t bucket_cnt, const std::string& bucket_folder); void GroupBucketize(size_t bucket_cnt, const std::string& bucket_folder, @@ -58,6 +63,10 @@ class ApsiCsvReader { private: std::string file_name_; + size_t batch_size_ = 1; + + std::shared_ptr batch_provider_; + std::shared_ptr reader_; std::vector> arrays_; diff --git a/psi/apsi_wrapper/utils/group_db.cc b/psi/apsi_wrapper/utils/group_db.cc index 7e291943..73da2f02 100644 --- a/psi/apsi_wrapper/utils/group_db.cc +++ b/psi/apsi_wrapper/utils/group_db.cc @@ -80,6 +80,39 @@ std::string PidFileName(pid_t pid) { fmt::format("apsi_process_{}", pid); } +void GroupBucketize(std::shared_ptr provider, + size_t num_buckets, size_t group_cnt, + MultiplexDiskCache& disk_cache) { + std::vector> bucket_group_vec; + disk_cache.CreateOutputStreams(group_cnt, &bucket_group_vec); + for (auto& out : bucket_group_vec) { + if (provider->HasLabel()) { + out->Write("bucket_id,key,value\n"); + } else { + out->Write("bucket_id,key\n"); + } + } + + auto batch = provider->ReadNextBatch(); + auto per_group_bucket = (num_buckets + group_cnt - 1) / group_cnt; + + while (!batch.keys.empty()) { + for (size_t i = 0; i < batch.keys.size(); ++i) { + auto key_hash = std::hash()(batch.keys[i]); + auto bucket_id = key_hash % num_buckets; + auto group_id = bucket_id / per_group_bucket; + auto& out = bucket_group_vec[group_id]; + if (provider->HasLabel()) { + out->Write(fmt::format("{},\"{}\",\"{}\"\n", bucket_id, batch.keys[i], + batch.labels[i])); + } else { + out->Write(fmt::format("{},\"{}\"\n", bucket_id, batch.keys[i])); + } + } + batch = provider->ReadNextBatch(); + } +} + } // namespace // Based on testing, we found that multi-process processing is more @@ -415,12 +448,28 @@ GroupDB::GroupDB(const std::string& db_path) apsi::PSIParams::Load(status_.params_file_content())); } +GroupDB::GroupDB(std::shared_ptr provider, + const std::string& db_path, std::size_t group_cnt, + size_t num_buckets, uint32_t nonce_byte_count, + const std::string& params_file, bool compress) + : GroupDB(db_path, group_cnt, num_buckets, nonce_byte_count, params_file, + compress) { + provider_ = provider; +} + GroupDB::GroupDB(const std::string& source_file, const std::string& db_path, std::size_t group_cnt, size_t num_buckets, uint32_t nonce_byte_count, const std::string& params_file, bool compress) - : source_file_(source_file), - db_path_(db_path), + : GroupDB(db_path, group_cnt, num_buckets, nonce_byte_count, params_file, + compress) { + source_file_ = source_file; +} + +GroupDB::GroupDB(const std::string& db_path, std::size_t group_cnt, + size_t num_buckets, uint32_t nonce_byte_count, + const std::string& params_file, bool compress) + : db_path_(db_path), group_cnt_(group_cnt), num_buckets_(num_buckets), nonce_byte_count_(nonce_byte_count), @@ -483,8 +532,12 @@ void GroupDB::DivideGroup() { std::filesystem::create_directories(db_path_); } - ApsiCsvReader reader(source_file_); - reader.GroupBucketize(num_buckets_, db_path_, group_cnt_, disk_cache_); + if (source_file_.empty()) { + GroupBucketize(provider_, num_buckets_, group_cnt_, disk_cache_); + } else { + ApsiCsvReader reader(source_file_); + reader.GroupBucketize(num_buckets_, db_path_, group_cnt_, disk_cache_); + } status_.set_state(GroupDBState::GROUP_DB_STATE_BUCKETED); SaveStatus(status_file_path_, status_); diff --git a/psi/apsi_wrapper/utils/group_db.h b/psi/apsi_wrapper/utils/group_db.h index d7602c67..55a602b3 100644 --- a/psi/apsi_wrapper/utils/group_db.h +++ b/psi/apsi_wrapper/utils/group_db.h @@ -25,6 +25,7 @@ #include #include "psi/apsi_wrapper/utils/sender_db.h" +#include "psi/kwpir/common/input_provider.h" #include "psi/apsi_wrapper/utils/group_db_status.pb.h" @@ -76,6 +77,11 @@ class GroupDB { size_t cnt; }; + GroupDB(std::shared_ptr provider, + const std::string& db_path, std::size_t group_cnt, size_t num_buckets, + uint32_t nonce_byte_count = 16, const std::string& params_file = "", + bool compress = false); + GroupDB(const std::string& source_file, const std::string& db_path, std::size_t group_cnt, size_t num_buckets, uint32_t nonce_byte_count = 16, const std::string& params_file = "", @@ -108,9 +114,15 @@ class GroupDB { ~GroupDB(); private: + GroupDB(const std::string& db_path, std::size_t group_cnt, size_t num_buckets, + uint32_t nonce_byte_count = 16, const std::string& params_file = "", + bool compress = false); + static inline const std::string status_file_name = "db.status"; std::string source_file_; + std::shared_ptr provider_; + std::string db_path_; size_t group_cnt_; size_t num_buckets_; diff --git a/psi/ecdh/ub_psi/ecdh_oprf_psi.cc b/psi/ecdh/ub_psi/ecdh_oprf_psi.cc index 5c67cfb3..de81f09e 100644 --- a/psi/ecdh/ub_psi/ecdh_oprf_psi.cc +++ b/psi/ecdh/ub_psi/ecdh_oprf_psi.cc @@ -517,9 +517,7 @@ size_t EcdhOprfPsiClient::SendBlindedItems( SPDLOG_INFO("Begin Send BlindedItems items"); while (true) { - std::vector items; - std::unordered_map dup_cnt; - std::tie(items, dup_cnt) = batch_provider->ReadNextBatchWithDupCnt(); + auto [items, dup_cnt] = batch_provider->ReadNextBatchWithDupCnt(); PsiDataBatch blinded_batch; blinded_batch.is_last_batch = items.empty(); diff --git a/psi/interface.cc b/psi/interface.cc index 521f4e7a..911ad8a2 100644 --- a/psi/interface.cc +++ b/psi/interface.cc @@ -219,6 +219,20 @@ void AbstractPsiParty::CheckSelfConfig() { YACL_ENFORCE_EQ(static_cast(keys_set.size()), config_.keys().size(), "Duplicated key is not allowed."); + if (!config_.protocol_config().broadcast_result() && + config_.advanced_join_type() != + v2::PsiConfig::ADVANCED_JOIN_TYPE_UNSPECIFIED) { + SPDLOG_WARN( + "broadcast_result turns off while advanced join is enabled. " + "broadcast_result is modified to true since intersection has to be " + "sent to both parties."); + + YACL_ENFORCE(!config_.output_config().path().empty(), + "You have to provide path of output."); + + config_.mutable_protocol_config()->set_broadcast_result(true); + } + if (!config_.skip_duplicates_check() && config_.advanced_join_type() != v2::PsiConfig::ADVANCED_JOIN_TYPE_UNSPECIFIED) { diff --git a/psi/kwpir/BUILD.bazel b/psi/kwpir/BUILD.bazel index 6225898f..cc9f37a2 100644 --- a/psi/kwpir/BUILD.bazel +++ b/psi/kwpir/BUILD.bazel @@ -19,20 +19,17 @@ package(default_visibility = ["//visibility:public"]) psi_cc_library( name = "kw_pir", srcs = ["kw_pir.cc"], - hdrs = [ - "index_pir.h", - "kw_pir.h", - ], + hdrs = ["kw_pir.h","index_pir.h"], linkopts = [ "-ldl", "-lm", ], deps = [ - "//psi/utils:cuckoo_index", "@yacl//yacl/base:byte_container_view", "@yacl//yacl/base:exception", "@yacl//yacl/base:int128", "@yacl//yacl/crypto/rand", + "//psi/utils:cuckoo_index", ], ) @@ -43,4 +40,4 @@ psi_cc_test( ":kw_pir", "//psi/sealpir:seal_pir", ], -) +) \ No newline at end of file diff --git a/psi/kwpir/client/BUILD.bazel b/psi/kwpir/client/BUILD.bazel new file mode 100644 index 00000000..a2a6dffe --- /dev/null +++ b/psi/kwpir/client/BUILD.bazel @@ -0,0 +1,31 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("//bazel:psi.bzl", "psi_cc_library", "psi_cc_test") + +package(default_visibility = ["//visibility:public"]) + +psi_cc_library( + name = "kw_pir_client", + srcs = ["kw_pir_client.cc"], + hdrs = ["kw_pir_client.h"], + linkopts = [ + "-ldl", + "-lm", + ], + deps = [ + "@yacl//yacl/base:byte_container_view", + "@yacl//yacl/base:exception", + ], +) diff --git a/psi/kwpir/client/kw_pir_client.cc b/psi/kwpir/client/kw_pir_client.cc new file mode 100644 index 00000000..71ff3a85 --- /dev/null +++ b/psi/kwpir/client/kw_pir_client.cc @@ -0,0 +1,18 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "psi/kwpir/client/kw_pir_client.h" + +namespace psi::kwpir { +} // namespace psi::kwpir \ No newline at end of file diff --git a/psi/kwpir/client/kw_pir_client.h b/psi/kwpir/client/kw_pir_client.h new file mode 100644 index 00000000..0b793013 --- /dev/null +++ b/psi/kwpir/client/kw_pir_client.h @@ -0,0 +1,17 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +namespace psi::kwpir {} // namespace psi::kwpir \ No newline at end of file diff --git a/psi/kwpir/common/BUILD.bazel b/psi/kwpir/common/BUILD.bazel new file mode 100644 index 00000000..949285a2 --- /dev/null +++ b/psi/kwpir/common/BUILD.bazel @@ -0,0 +1,27 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("//bazel:psi.bzl", "psi_cc_library") + +package(default_visibility = ["//visibility:public"]) + +psi_cc_library( + name = "input_provider", + srcs = ["input_provider.cc"], + hdrs = ["input_provider.h"], + deps = [ + "//psi/proto:common_cc_proto", + "//psi/utils:arrow_csv_batch_provider", + ], +) diff --git a/psi/kwpir/common/input_provider.cc b/psi/kwpir/common/input_provider.cc new file mode 100644 index 00000000..a085f84c --- /dev/null +++ b/psi/kwpir/common/input_provider.cc @@ -0,0 +1,46 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "psi/kwpir/common/input_provider.h" + +#include +#include + +#include "yacl/base/exception.h" + +#include "psi/utils/arrow_helper.h" + +namespace psi::kwpir { + +InputProvider::InputProvider(InputConfig config, size_t batch_size) + : config_(config), batch_size_(batch_size) { + YACL_ENFORCE(config.file_type() == FileType::FILE_TYPE_CSV, + "Only support csv file now"); + + std::vector keys{config_.key_column_names().begin(), + config_.key_column_names().end()}; + std::vector values{config.value_column_names().begin(), + config.value_column_names().end()}; + has_label_ = !values.empty(); + csv_provider_ = std::make_shared( + config.file_name(), keys, batch_size_, values); +} + +InputProvider::Batch InputProvider::ReadNextBatch() { + Batch batch; + std::tie(batch.keys, batch.labels) = csv_provider_->ReadNextLabeledBatch(); + return batch; +} + +} // namespace psi::kwpir \ No newline at end of file diff --git a/psi/kwpir/common/input_provider.h b/psi/kwpir/common/input_provider.h new file mode 100644 index 00000000..490c9f9f --- /dev/null +++ b/psi/kwpir/common/input_provider.h @@ -0,0 +1,42 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "psi/utils/arrow_csv_batch_provider.h" + +#include "psi/proto/common.pb.h" + +namespace psi::kwpir { + +class InputProvider { + public: + InputProvider(InputConfig config, size_t batch_size = 1 << 20); + + struct Batch { + std::vector keys; + std::vector labels; + }; + Batch ReadNextBatch(); + + bool HasLabel() const { return has_label_; } + + protected: + InputConfig config_; + size_t batch_size_ = 1 << 20; + bool has_label_ = false; + std::shared_ptr csv_provider_; +}; + +} // namespace psi::kwpir \ No newline at end of file diff --git a/psi/kwpir/server/BUILD.bazel b/psi/kwpir/server/BUILD.bazel new file mode 100644 index 00000000..6321a2b2 --- /dev/null +++ b/psi/kwpir/server/BUILD.bazel @@ -0,0 +1,58 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("//bazel:psi.bzl", "psi_cc_library") + +package(default_visibility = ["//visibility:public"]) + +psi_cc_library( + name = "kw_pir_server", + srcs = ["kw_pir_server.cc"], + hdrs = ["kw_pir_server.h"], + linkopts = [ + "-ldl", + "-lm", + ], + deps = [ + "//psi/kwpir/common:input_provider", + "//psi/proto:kw_pir_server_service_cc_proto", + "//psi/proto:pir_cc_proto", + "@com_github_brpc_brpc//:brpc", + "@yacl//yacl/base:exception", + "@yacl//yacl/utils:elapsed_timer", + ], +) + +psi_cc_library( + name = "apsi_kw_pir_server", + srcs = ["apsi_kw_pir_server.cc"], + hdrs = ["apsi_kw_pir_server.h"], + deps = [ + ":kw_pir_server", + "//psi/apsi_wrapper/api:sender", + "@com_github_brpc_brpc//:brpc", + "@yacl//yacl/crypto/rand", + ], +) + +psi_cc_library( + name = "kw_seal_pir_server", + srcs = ["kw_seal_pir_server.cc"], + hdrs = ["kw_seal_pir_server.h"], + deps = [ + ":kw_pir_server", + "@com_github_brpc_brpc//:brpc", + "@yacl//yacl/crypto/rand", + ], +) diff --git a/psi/kwpir/server/apsi_kw_pir_server.cc b/psi/kwpir/server/apsi_kw_pir_server.cc new file mode 100644 index 00000000..8f754548 --- /dev/null +++ b/psi/kwpir/server/apsi_kw_pir_server.cc @@ -0,0 +1,67 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "psi/kwpir/server/apsi_kw_pir_server.h" + +#include + +#include "yacl/base/exception.h" + +#include "psi/proto/apsi_wrapper.pb.h" + +namespace psi::kwpir { + +ApsiKeywordPirServer::ApsiKeywordPirServer(InputConfig input, + std::string db_path, + ApsiSenderConfig apsi_config) + : KeywordPirServer(input, db_path) { + apsi_option_.group_cnt = apsi_config.experimental_bucket_group_cnt(); + apsi_option_.num_buckets = apsi_config.experimental_bucket_cnt(); + apsi_option_.nonce_byte_count = apsi_config.nonce_byte_count(); + apsi_option_.compress = apsi_config.compress(); + apsi_option_.params_file = apsi_config.params_file(); + apsi_option_.group_cnt = apsi_config.experimental_bucket_group_cnt(); +} + +void ApsiKeywordPirServer::DoQuery(const KeywordPirServerRequest* request, + KeywordPirServerResponse* response) { + if (request->step() == "oprf") { + response->mutable_reply()->Add(sender_->RunOPRF(request->query(0))); + } else if (request->step() == "params") { + std::vector oprf{request->query().begin(), + request->query().end()}; + auto oprf_res = sender_->RunOPRF(oprf); + response->mutable_reply()->Assign(oprf_res.begin(), oprf_res.end()); + } else if (request->step() == "query") { + std::vector query{request->query().begin(), + request->query().end()}; + auto result = sender_->RunQuery(query); + response->mutable_reply()->Assign(result.begin(), result.end()); + } else { + YACL_THROW("unknown step {}", request->step()); + } +} + +void ApsiKeywordPirServer::DoLoadDataBase(const std::string& db_path) { + sender_ = std::make_shared(db_path); +} + +void ApsiKeywordPirServer::DoGenerateDataBase( + std::shared_ptr input, const std::string& db_path) { + apsi_option_.provider = input; + apsi_option_.db_path = db_path; + sender_ = std::make_shared(apsi_option_); +} + +} // namespace psi::kwpir \ No newline at end of file diff --git a/psi/kwpir/server/apsi_kw_pir_server.h b/psi/kwpir/server/apsi_kw_pir_server.h new file mode 100644 index 00000000..61649c82 --- /dev/null +++ b/psi/kwpir/server/apsi_kw_pir_server.h @@ -0,0 +1,42 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "psi/apsi_wrapper/api/sender.h" +#include "psi/kwpir/server/kw_pir_server.h" + +#include "psi/proto/apsi_wrapper.pb.h" + +namespace psi::kwpir { + +class ApsiKeywordPirServer : public KeywordPirServer { + public: + ApsiKeywordPirServer(InputConfig input, std::string db_path, + ApsiSenderConfig apsi_config); + + void DoQuery(const KeywordPirServerRequest* request, + KeywordPirServerResponse* response) override; + + void DoLoadDataBase(const std::string& db_path) override; + + void DoGenerateDataBase(std::shared_ptr input, + const std::string& db_path) override; + + protected: + psi::apsi_wrapper::api::Sender::KwPirOption apsi_option_; + std::shared_ptr sender_; +}; + +} // namespace psi::kwpir \ No newline at end of file diff --git a/psi/kwpir/server/kw_pir_server.cc b/psi/kwpir/server/kw_pir_server.cc new file mode 100644 index 00000000..5c8d51a6 --- /dev/null +++ b/psi/kwpir/server/kw_pir_server.cc @@ -0,0 +1,64 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "psi/kwpir/server/kw_pir_server.h" + +#include + +#include "yacl/base/exception.h" +#include "yacl/utils/elapsed_timer.h" + +namespace psi::kwpir { + +void KeywordPirServer::SetUp() { + yacl::ElapsedTimer timer; + + if (input_.file_name().empty()) { + DoLoadDataBase(db_path_); + } else { + auto provider = std::make_shared(input_); + DoGenerateDataBase(provider, db_path_); + } + + SPDLOG_INFO("SetUp cost: {} ms", timer.CountMs()); + is_ready_ = true; +} + +void KeywordPirServer::Query(::google::protobuf::RpcController* /*cntl_base*/, + const KeywordPirServerRequest* request, + KeywordPirServerResponse* response, + ::google::protobuf::Closure* done) { + brpc::ClosureGuard done_guard(done); + + if (!is_ready_) { + response->mutable_status()->set_err_code(ErrorCode::NOT_READY); + response->mutable_status()->set_msg("Server is not ready"); + } + + SPDLOG_INFO("request step: {}", request->step()); + try { + DoQuery(request, response); + } catch (yacl::Exception& e) { + response->mutable_status()->set_err_code(ErrorCode::LOGIC_ERROR); + response->mutable_status()->set_msg(e.what()); + } catch (const std::exception& e) { + response->mutable_status()->set_err_code(ErrorCode::UNEXPECTED_ERROR); + response->mutable_status()->set_msg(e.what()); + return; + } + + response->mutable_status()->set_err_code(ErrorCode::OK); +} + +} // namespace psi::kwpir \ No newline at end of file diff --git a/psi/kwpir/server/kw_pir_server.h b/psi/kwpir/server/kw_pir_server.h new file mode 100644 index 00000000..5ce741f4 --- /dev/null +++ b/psi/kwpir/server/kw_pir_server.h @@ -0,0 +1,59 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "brpc/server.h" +#include "spdlog/spdlog.h" + +#include "psi/kwpir/common/input_provider.h" + +#include "psi/proto/kw_pir_server_service.pb.h" + +namespace psi::kwpir { + +class KeywordPirServer : public KeywordPirServerService { + public: + KeywordPirServer(InputConfig input, std::string db_path) + : input_(std::move(input)), db_path_(std::move(db_path)) {} + + ~KeywordPirServer() override = default; + + void Query(::google::protobuf::RpcController* /*cntl_base*/, + const KeywordPirServerRequest* request, + KeywordPirServerResponse* response, + ::google::protobuf::Closure* done) override; + + void SetUp(); + + protected: + virtual void DoQuery(const KeywordPirServerRequest* request, + KeywordPirServerResponse* response) = 0; + + virtual void DoGenerateDataBase(std::shared_ptr input, + const std::string& db_path) = 0; + + virtual void DoLoadDataBase(const std::string& db_path) = 0; + + protected: + InputConfig input_; + std::string db_path_; + + private: + bool is_ready_ = false; +}; + +} // namespace psi::kwpir \ No newline at end of file diff --git a/psi/kwpir/server/kw_seal_pir_server.cc b/psi/kwpir/server/kw_seal_pir_server.cc new file mode 100644 index 00000000..a9cb5e50 --- /dev/null +++ b/psi/kwpir/server/kw_seal_pir_server.cc @@ -0,0 +1,27 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "psi/kwpir/server/kw_seal_pir_server.h" + +#include "yacl/base/exception.h" + +namespace psi::kwpir { + +void KeywordSealPirServer::DoQuery(const KeywordPirServerRequest* request, + KeywordPirServerResponse* response) { + (void)request; + (void)response; +} + +} // namespace psi::kwpir \ No newline at end of file diff --git a/psi/kwpir/server/kw_seal_pir_server.h b/psi/kwpir/server/kw_seal_pir_server.h new file mode 100644 index 00000000..fa20da24 --- /dev/null +++ b/psi/kwpir/server/kw_seal_pir_server.h @@ -0,0 +1,31 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "psi/kwpir/server/kw_pir_server.h" + +namespace psi::kwpir { + +class KeywordSealPirServer : public KeywordPirServer { + public: + KeywordSealPirServer() = default; + + void DoQuery(const KeywordPirServerRequest* request, + KeywordPirServerResponse* response) override; + + protected: +}; + +} // namespace psi::kwpir \ No newline at end of file diff --git a/psi/launch.cc b/psi/launch.cc index a56265d1..3b9f63c2 100644 --- a/psi/launch.cc +++ b/psi/launch.cc @@ -184,6 +184,9 @@ PirResultReport RunPir(const ApsiReceiverConfig& apsi_receiver_config, apsi_receiver_config.experimental_enable_bucketize(); options.experimental_bucket_cnt = apsi_receiver_config.experimental_bucket_cnt(); + options.query_batch_size = apsi_receiver_config.query_batch_size() + ? apsi_receiver_config.query_batch_size() + : 1; int* match_cnt = new int(0); diff --git a/psi/proto/BUILD.bazel b/psi/proto/BUILD.bazel index f0b0d418..24d7a2c4 100644 --- a/psi/proto/BUILD.bazel +++ b/psi/proto/BUILD.bazel @@ -16,6 +16,52 @@ load("@rules_proto//proto:defs.bzl", "proto_library") package(default_visibility = ["//visibility:public"]) +proto_library( + name = "common_proto", + srcs = ["common.proto"], +) + +cc_proto_library( + name = "common_cc_proto", + deps = [":common_proto"], +) + +proto_library( + name = "kw_pir_client_service_proto", + srcs = ["kw_pir_client_service.proto"], + deps = [ + ":common_proto", + ], +) + +cc_proto_library( + name = "kw_pir_client_service_cc_proto", + deps = [":kw_pir_client_service_proto"], +) + +proto_library( + name = "kw_pir_server_service_proto", + srcs = ["kw_pir_server_service.proto"], + deps = [ + ":common_proto", + ], +) + +cc_proto_library( + name = "kw_pir_server_service_cc_proto", + deps = [":kw_pir_server_service_proto"], +) + +proto_library( + name = "kw_seal_pir_proto", + srcs = ["kw_seal_pir.proto"], +) + +cc_proto_library( + name = "kw_seal_pir_cc_proto", + deps = [":kw_seal_pir_proto"], +) + proto_library( name = "psi_proto", srcs = ["psi.proto"], @@ -26,14 +72,31 @@ cc_proto_library( deps = [":psi_proto"], ) +proto_library( + name = "apsi_wrapper_proto", + srcs = ["apsi_wrapper.proto"], +) + +cc_proto_library( + name = "apsi_wrapper_cc_proto", + deps = [":apsi_wrapper_proto"], +) + proto_library( name = "pir_proto", srcs = ["pir.proto"], + deps = [ + "kw_seal_pir_proto", + ":apsi_wrapper_proto", + ":common_proto", + ], ) cc_proto_library( name = "pir_cc_proto", - deps = [":pir_proto"], + deps = [ + ":pir_proto", + ], ) proto_library( @@ -53,6 +116,7 @@ proto_library( name = "entry_proto", srcs = ["entry.proto"], deps = [ + ":apsi_wrapper_proto", ":pir_proto", ":psi_proto", ":psi_v2_proto", diff --git a/psi/proto/apsi_wrapper.proto b/psi/proto/apsi_wrapper.proto new file mode 100644 index 00000000..b57e6013 --- /dev/null +++ b/psi/proto/apsi_wrapper.proto @@ -0,0 +1,134 @@ +// +// Copyright 2022 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +syntax = "proto3"; + +package psi; + +// NOTE(junfeng): We provide a config identical to original APSI CLI. +// Please check +// https://github.com/microsoft/APSI?tab=readme-ov-file#command-line-interface-cli +// for details. +message ApsiSenderConfig { + // Number of threads to use + uint32 threads = 1; + + // Log file path. For APSI only. + string log_file = 2; + + // Do not write output to console. For APSI only. + bool silent = 3; + + // One of 'all', 'debug', 'info' (default), 'warning', 'error', 'off'. For + // APSI only. + string log_level = 4; + + // Path to a CSV file describing the sender's dataset (an item-label pair on + // each row) or a file containing a serialized SenderDB; the CLI will first + // attempt to load the data as a serialized SenderDB, and – upon failure – + // will proceed to attempt to read it as a CSV file + // For CSV File: + // 1. the first col is processed as item while the second col as label. OTHER + // COLS ARE IGNORED. + // 2. NO HEADERS ARE ALLOWED. + string db_file = 5; + + // Path to a JSON file describing the parameters to be used by the sender. + // Not required if db_file points to a serialized SenderDB. + string params_file = 6; + + // Save the SenderDB in the given file. + // Required if gen_db_only is set true. + // Use experimental_bucket_folder instead if you turn + // experimental_enable_bucketize on. + string sdb_out_file = 7; + + // Number of bytes used for the nonce in labeled mode (default is 16) + uint32 nonce_byte_count = 8; + + // Whether to compress the SenderDB in memory; this will make the memory + // footprint smaller at the cost of increased computation. + bool compress = 9; + + // Whether to save sender db only. + bool save_db_only = 10; + + // [experimental] Whether to split data in buckets and Each bucket would be a + // seperate SenderDB. If set, experimental_bucket_folder must be a valid + // folder. + bool experimental_enable_bucketize = 13; + + // [experimental] The number of bucket to fit data. + uint32 experimental_bucket_cnt = 14; + + // [experimental] Folder to save bucketized small csv files and db files. + string experimental_bucket_folder = 15; + + // [experimental] The number of processes to use for generating db. + int32 experimental_db_generating_process_num = 16; + + // Source file used to genenerate sender db. + // Currently only support csv file. + string source_file = 17; + + // [experimental] The number of group of bucket, each group has a db_file, + // default 1024. + int32 experimental_bucket_group_cnt = 18; +} + +message ApsiReceiverConfig { + // Number of threads to use + uint32 threads = 1; + + // Log file path. For APSI only. + string log_file = 2; + + // Do not write output to console. For APSI only. + bool silent = 3; + + // One of 'all', 'debug', 'info' (default), 'warning', 'error', 'off'. For + // APSI only. + string log_level = 4; + + // Path to a text file containing query data (one per line). + // Header is not needed. + string query_file = 5; + + // Path to a file where intersection result will be written. + string output_file = 6; + + // Path to a JSON file describing the parameters to be used by the sender. + // If not set, receiver will ask sender, which results in additional + // communication. + string params_file = 7; + + // Must be same as sender config. + bool experimental_enable_bucketize = 8; + + // Must be same as sender config. + uint32 experimental_bucket_cnt = 9; + + // The number of query in a batch. default 1. + uint32 query_batch_size = 10; +} + +// The report of pir task. +message PirResultReport { + int64 match_cnt = 1; +} diff --git a/psi/proto/common.proto b/psi/proto/common.proto new file mode 100644 index 00000000..4b8ef298 --- /dev/null +++ b/psi/proto/common.proto @@ -0,0 +1,79 @@ +// +// Copyright 2022 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +syntax = "proto3"; + +package psi; + +option cc_generic_services = true; + +enum FileType { + FILE_TYPE_UNKNOWN = 0; + FILE_TYPE_CSV = 1; +} + +enum ErrorCode { + // Placeholder for proto3 default value, do not use it + UNKNOWN = 0; + + // 001-099 for general code + + OK = 1; + UNEXPECTED_ERROR = 2; + INVALID_ARGUMENT = 3; + NETWORK_ERROR = 4; + // Some requested entity (e.g., file or directory) was not found. + NOT_FOUND = 5; + NOT_IMPLEMENTED = 6; + LOGIC_ERROR = 7; + SERIALIZE_FAILED = 8; + DESERIALIZE_FAILED = 9; + IO_ERROR = 10; + NOT_READY = 11; +} + +message Header { + map data = 1; +} + +message Status { + ErrorCode err_code = 1; + string msg = 2; +} + +message InputConfig { + string file_name = 1; + // if not set, default is csv + FileType file_type = 2; + + repeated string key_column_names = 3; + repeated string value_column_names = 4; +} + +message OutputConfig { + string file_name = 1; + // if not set, default is csv + FileType file_type = 2; +} + +enum KeyWordPirBackend { + BACKEND_UNKNOWN = 0; + BACKEND_APSI = 1; + BACKEND_SEAL_PIR = 2; +} \ No newline at end of file diff --git a/psi/proto/entry.proto b/psi/proto/entry.proto index 75c37d4a..dee4aecc 100644 --- a/psi/proto/entry.proto +++ b/psi/proto/entry.proto @@ -16,8 +16,8 @@ syntax = "proto3"; import "psi/proto/psi.proto"; -import "psi/proto/pir.proto"; import "psi/proto/psi_v2.proto"; +import "psi/proto/apsi_wrapper.proto"; import "yacl/link/link.proto"; package psi; diff --git a/psi/proto/kw_pir_client_service.proto b/psi/proto/kw_pir_client_service.proto new file mode 100644 index 00000000..4b6374de --- /dev/null +++ b/psi/proto/kw_pir_client_service.proto @@ -0,0 +1,48 @@ +// +// Copyright 2022 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +syntax = "proto3"; + +import "psi/proto/common.proto"; + +package psi; + +option cc_generic_services = true; + +// User -> Pir client, used when client is in service mode. +service PirClientService { + rpc Query(KeywordPirClientRequest) returns (KeywordPirClientResponse); +} + +message KeywordPirClientRequest { + Header header = 1; + + reserved 2 to 9; + + repeated string keyword = 10; +} + +message KeywordPirClientResponse { + Header header = 1; + Status status = 2; + + reserved 3 to 9; + + repeated string result = 10; +} diff --git a/psi/proto/kw_pir_server_service.proto b/psi/proto/kw_pir_server_service.proto new file mode 100644 index 00000000..8e54da4e --- /dev/null +++ b/psi/proto/kw_pir_server_service.proto @@ -0,0 +1,49 @@ +// +// Copyright 2022 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +syntax = "proto3"; + +import "psi/proto/common.proto"; + +package psi; + +option cc_generic_services = true; + +// Pir client -> [Pir server] +service KeywordPirServerService { + rpc Query(KeywordPirServerRequest) returns (KeywordPirServerResponse); +} + +message KeywordPirServerRequest { + Header header = 1; + + reserved 2 to 9; + + string step = 10; + repeated bytes query = 11; +} + +message KeywordPirServerResponse { + Header header = 1; + Status status = 2; + + reserved 3 to 9; + + repeated bytes reply = 10; +} diff --git a/psi/proto/kw_seal_pir.proto b/psi/proto/kw_seal_pir.proto new file mode 100644 index 00000000..f97a7b80 --- /dev/null +++ b/psi/proto/kw_seal_pir.proto @@ -0,0 +1,28 @@ +// +// Copyright 2022 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +syntax = "proto3"; + +package psi; + +option cc_generic_services = true; + +message SealPirClientConfig {} + +message SealPirServerConfig {} diff --git a/psi/proto/pir.proto b/psi/proto/pir.proto index fb3b8873..a81e554a 100644 --- a/psi/proto/pir.proto +++ b/psi/proto/pir.proto @@ -19,113 +19,76 @@ syntax = "proto3"; +import "psi/proto/common.proto"; +import "psi/proto/apsi_wrapper.proto"; +import "psi/proto/kw_seal_pir.proto"; + package psi; -// NOTE(junfeng): We provide a config identical to original APSI CLI. -// Please check -// https://github.com/microsoft/APSI?tab=readme-ov-file#command-line-interface-cli -// for details. -message ApsiSenderConfig { - // Number of threads to use - uint32 threads = 1; - - // Log file path. For APSI only. - string log_file = 2; - - // Do not write output to console. For APSI only. - bool silent = 3; - - // One of 'all', 'debug', 'info' (default), 'warning', 'error', 'off'. For - // APSI only. - string log_level = 4; - - // Path to a CSV file describing the sender's dataset (an item-label pair on - // each row) or a file containing a serialized SenderDB; the CLI will first - // attempt to load the data as a serialized SenderDB, and – upon failure – - // will proceed to attempt to read it as a CSV file - // For CSV File: - // 1. the first col is processed as item while the second col as label. OTHER - // COLS ARE IGNORED. - // 2. NO HEADERS ARE ALLOWED. - string db_file = 5; - - // Path to a JSON file describing the parameters to be used by the sender. - // Not required if db_file points to a serialized SenderDB. - string params_file = 6; - - // Save the SenderDB in the given file. - // Required if gen_db_only is set true. - // Use experimental_bucket_folder instead if you turn - // experimental_enable_bucketize on. - string sdb_out_file = 7; - - // Number of bytes used for the nonce in labeled mode (default is 16) - uint32 nonce_byte_count = 8; - - // Whether to compress the SenderDB in memory; this will make the memory - // footprint smaller at the cost of increased computation. - bool compress = 9; - - // Whether to save sender db only. - bool save_db_only = 10; - - // [experimental] Whether to split data in buckets and Each bucket would be a - // seperate SenderDB. If set, experimental_bucket_folder must be a valid - // folder. - bool experimental_enable_bucketize = 13; - - // [experimental] The number of bucket to fit data. - uint32 experimental_bucket_cnt = 14; - - // [experimental] Folder to save bucketized small csv files and db files. - string experimental_bucket_folder = 15; - - // [experimental] The number of processes to use for generating db. - int32 experimental_db_generating_process_num = 16; - - // Source file used to genenerate sender db. - // Currently only support csv file. - string source_file = 17; - - // [experimental] The number of group of bucket, each group has a db_file, - // default 1024. - int32 experimental_bucket_group_cnt = 18; +option cc_generic_services = true; + +message KeywordPirServerBackendConfig { + KeyWordPirBackend backend_type = 1; + + oneof backend { + ApsiSenderConfig apsi = 2; + SealPirServerConfig seal_pir = 3; + } } -message ApsiReceiverConfig { - // Number of threads to use - uint32 threads = 1; +message KeywordPirServerConfig { + string host = 1; - // Log file path. For APSI only. - string log_file = 2; + InputConfig input = 2; - // Do not write output to console. For APSI only. - bool silent = 3; + string db_path = 3; - // One of 'all', 'debug', 'info' (default), 'warning', 'error', 'off'. For - // APSI only. - string log_level = 4; + uint32 server_port = 4; - // Path to a text file containing query data (one per line). - // Header is not needed. - string query_file = 5; + reserved 5, 9; - // Path to a file where intersection result will be written. - string output_file = 6; + KeywordPirServerBackendConfig backend = 10; +} - // Path to a JSON file describing the parameters to be used by the sender. - // If not set, receiver will ask sender, which results in additional - // communication. - string params_file = 7; +message KeywordPirClientBackendConfig { + KeyWordPirBackend backend_type = 1; - // Must be same as sender config. - bool experimental_enable_bucketize = 8; + oneof backend { + ApsiReceiverConfig apsi = 2; + SealPirClientConfig seal_pir = 3; + } +} - // Must be same as sender config. - uint32 experimental_bucket_cnt = 9; +enum KeywordPirClientMode { + CLIENT_MODE_UNKNOWN = 0; + CLIENT_MODE_FILE = 1; + CLIENT_MODE_SERVICE = 2; } -// The report of pir task. -message PirResultReport { - int64 match_cnt = 1; +message KeywordPirClientFileModeConfig { + InputConfig input = 1; + + OutputConfig output = 2; } + +message KeywordPirClientServiceModeConfig { + string host = 1; + + uint32 port = 2; +} + +message KeywordPirClientConfig { + string server_address = 1; + + KeywordPirClientMode mode = 2; + + oneof config { + KeywordPirClientFileModeConfig file_config = 3; + + KeywordPirClientServiceModeConfig service_config = 4; + } + + reserved 5 to 9; + + KeywordPirClientBackendConfig backend = 10; +} \ No newline at end of file diff --git a/psi/rr22/rr22_oprf.h b/psi/rr22/rr22_oprf.h index db436dcd..765c82c0 100644 --- a/psi/rr22/rr22_oprf.h +++ b/psi/rr22/rr22_oprf.h @@ -189,6 +189,7 @@ class Rr22OprfReceiver : public Rr22Oprf { const absl::Span& inputs); private: + size_t init_size_ = 0; size_t num_threads_ = 0; okvs::Baxos baxos_; okvs::Paxos paxos_; diff --git a/psi/rr22/rr22_psi.h b/psi/rr22/rr22_psi.h index c5384721..890fde05 100644 --- a/psi/rr22/rr22_psi.h +++ b/psi/rr22/rr22_psi.h @@ -91,8 +91,6 @@ class BucketRr22Core { broadcast_result_(broadcast_result), bucket_idx_(bucket_idx) {} - virtual ~BucketRr22Core() = default; - virtual void Prepare(const std::shared_ptr& lctx) = 0; virtual void RunOprf(const std::shared_ptr& lctx) = 0; virtual void GetIntersection( diff --git a/psi/sealpir/BUILD.bazel b/psi/sealpir/BUILD.bazel index a48675ec..5193b9a4 100644 --- a/psi/sealpir/BUILD.bazel +++ b/psi/sealpir/BUILD.bazel @@ -40,8 +40,8 @@ psi_cc_library( ], deps = [ ":seal_pir_utils", - "//psi/kwpir:kw_pir", "//psi/sealpir:serializable_cc_proto", + "//psi/kwpir:kw_pir", "@com_github_microsoft_seal//:seal", "@com_github_openssl_openssl//:openssl", "@yacl//yacl/base:byte_container_view", @@ -68,4 +68,4 @@ psi_cc_test( ":seal_pir", "@com_github_microsoft_seal//:seal", ], -) +) \ No newline at end of file diff --git a/psi/utils/BUILD.bazel b/psi/utils/BUILD.bazel index c9ca51f3..9b74f521 100644 --- a/psi/utils/BUILD.bazel +++ b/psi/utils/BUILD.bazel @@ -331,6 +331,7 @@ psi_cc_library( srcs = ["arrow_csv_batch_provider.cc"], hdrs = ["arrow_csv_batch_provider.h"], deps = [ + ":arrow_helper", ":batch_provider", ":key", "@org_apache_arrow//:arrow", diff --git a/psi/utils/arrow_csv_batch_provider.cc b/psi/utils/arrow_csv_batch_provider.cc index 93a5ba4d..96e00ccb 100644 --- a/psi/utils/arrow_csv_batch_provider.cc +++ b/psi/utils/arrow_csv_batch_provider.cc @@ -16,6 +16,7 @@ #include #include +#include #include "arrow/array.h" #include "arrow/compute/api.h" @@ -23,6 +24,7 @@ #include "spdlog/spdlog.h" #include "yacl/base/exception.h" +#include "psi/utils/arrow_helper.h" #include "psi/utils/key.h" namespace psi { @@ -100,7 +102,7 @@ void ArrowCsvBatchProvider::ReadNextBatch( read_keys->emplace_back(KeysJoin(values)); } - if (read_labels) { + if (!labels_.empty()) { std::vector values; for (size_t i = keys_.size(); i < arrays_.size(); i++) { values.emplace_back(arrays_[i]->Value(idx_in_batch_)); @@ -119,24 +121,31 @@ void ArrowCsvBatchProvider::Init() { YACL_ENFORCE(!keys_.empty(), "You must provide keys."); - arrow::io::IOContext io_context = arrow::io::default_io_context(); - infile_ = - arrow::io::ReadableFile::Open(file_path_, arrow::default_memory_pool()) - .ValueOrDie(); + auto columns = GetCsvColumnsNames(file_path_); + std::unordered_set columns_set(columns.begin(), columns.end()); auto read_options = arrow::csv::ReadOptions::Defaults(); auto parse_options = arrow::csv::ParseOptions::Defaults(); auto convert_options = arrow::csv::ConvertOptions::Defaults(); for (const auto& key : keys_) { + YACL_ENFORCE(columns_set.find(key) != columns_set.end(), + "Key column {} not found in csv file.", key); convert_options.column_types[key] = arrow::utf8(); } for (const auto& label : labels_) { + YACL_ENFORCE(columns_set.find(label) != columns_set.end(), + "label column {} not found in csv file.", label); convert_options.column_types[label] = arrow::utf8(); } + convert_options.include_columns = keys_; convert_options.include_columns.insert(convert_options.include_columns.end(), labels_.begin(), labels_.end()); + auto io_context = arrow::io::default_io_context(); + infile_ = + arrow::io::ReadableFile::Open(file_path_, arrow::default_memory_pool()) + .ValueOrDie(); reader_ = arrow::csv::StreamingReader::Make(io_context, infile_, read_options, parse_options, convert_options) diff --git a/psi/utils/arrow_helper.h b/psi/utils/arrow_helper.h index 0348b681..1cebb30d 100644 --- a/psi/utils/arrow_helper.h +++ b/psi/utils/arrow_helper.h @@ -20,6 +20,8 @@ #include "arrow/csv/api.h" +#include "psi/utils/arrow_helper.h" + namespace psi { #define PSI_ARROW_GET_RESULT(value, maker) \ diff --git a/psi/utils/csv_header_parser_test.cc b/psi/utils/csv_header_parser_test.cc index 181509a1..8631bcf5 100644 --- a/psi/utils/csv_header_parser_test.cc +++ b/psi/utils/csv_header_parser_test.cc @@ -51,7 +51,9 @@ TEST(CsvHeaderParserTest, Works) { std::vector{"y1", "id2", "id2", "id1"}, 1), (std::vector{3, 2, 2, 1})); - { std::filesystem::remove(csv_path); } + { + std::filesystem::remove(csv_path); + } } } // namespace psi diff --git a/psi/utils/join_processor.cc b/psi/utils/join_processor.cc index bcc65078..d516ab83 100644 --- a/psi/utils/join_processor.cc +++ b/psi/utils/join_processor.cc @@ -121,17 +121,12 @@ JoinProcessor::JoinProcessor(const v2::UbPsiConfig& ub_psi_config, v2::UbPsiConfig::MODE_FULL, }; - bool gen_output = - (role_ == v2::ROLE_SERVER && ub_psi_config.server_get_result()) || - (role_ == v2::ROLE_CLIENT && ub_psi_config.client_get_result()); - if (gen_output) { - if (gen_output_mode.find(ub_psi_config.mode()) != gen_output_mode.end()) { - YACL_ENFORCE( - ub_psi_config.output_config().type() == v2::IoType::IO_TYPE_FILE_CSV, - "unsupport output format {}", - v2::IoType_Name(ub_psi_config.input_config().type())); - output_path_ = ub_psi_config.output_config().path(); - } + if (gen_output_mode.find(ub_psi_config.mode()) != gen_output_mode.end()) { + YACL_ENFORCE( + ub_psi_config.output_config().type() == v2::IoType::IO_TYPE_FILE_CSV, + "unsupport output format {}", + v2::IoType_Name(ub_psi_config.input_config().type())); + output_path_ = ub_psi_config.output_config().path(); } if (!std::filesystem::exists(ub_psi_config.cache_path())) { @@ -270,7 +265,6 @@ std::shared_ptr JoinProcessor::GetUniqueKeysInfo() { KeyInfo::StatInfo JoinProcessor::DealResultIndex(IndexReader& index) { ResultDumper dumper(sorted_intersect_path_, sorted_except_path_); auto stat = GetUniqueKeysInfo()->ApplyPeerDupCnt(index, dumper); - dumper.Flush(); if (is_input_key_unique_ && align_output_) { if (!sorted_intersect_path_.empty()) { Table::MakeFromCsv(sorted_intersect_path_) diff --git a/psi/utils/table_utils.cc b/psi/utils/table_utils.cc index 98058b45..213a5b77 100644 --- a/psi/utils/table_utils.cc +++ b/psi/utils/table_utils.cc @@ -26,7 +26,6 @@ #include #include #include -#include #include #include #include @@ -481,15 +480,6 @@ void ResultDumper::Dump(const std::string& line, int64_t duplicate_cnt, } } -void ResultDumper::Flush() { - if (intersect_file_) { - intersect_file_->flush(); - } - if (except_file_) { - except_file_->flush(); - } -} - std::vector KeyInfo::SourceFileColumns() const { return table_->Columns(); } diff --git a/psi/utils/table_utils.h b/psi/utils/table_utils.h index f8f5734a..2217309d 100644 --- a/psi/utils/table_utils.h +++ b/psi/utils/table_utils.h @@ -129,8 +129,6 @@ struct ResultDumper { int64_t except_cnt() const { return except_cnt_; } int64_t intersect_cnt() const { return intersect_cnt_; } - void Flush(); - private: void Dump(const std::string& line, int64_t duplicate_cnt, std::shared_ptr& file, int64_t* total_dump_cnt); From cfe5b6dd22d0e62a5da371f5601f4968dcbd644a Mon Sep 17 00:00:00 2001 From: huocun Date: Mon, 18 Nov 2024 12:12:48 +0800 Subject: [PATCH 2/3] repo-sync-2024-11-18T12:12:41+0800 --- psi/kwpir/client/BUILD.bazel | 31 ------------ psi/kwpir/client/kw_pir_client.cc | 18 ------- psi/kwpir/client/kw_pir_client.h | 17 ------- psi/kwpir/server/BUILD.bazel | 58 ---------------------- psi/kwpir/server/apsi_kw_pir_server.cc | 67 -------------------------- psi/kwpir/server/apsi_kw_pir_server.h | 42 ---------------- psi/kwpir/server/kw_pir_server.cc | 64 ------------------------ psi/kwpir/server/kw_pir_server.h | 59 ----------------------- psi/kwpir/server/kw_seal_pir_server.cc | 27 ----------- psi/kwpir/server/kw_seal_pir_server.h | 31 ------------ 10 files changed, 414 deletions(-) delete mode 100644 psi/kwpir/client/BUILD.bazel delete mode 100644 psi/kwpir/client/kw_pir_client.cc delete mode 100644 psi/kwpir/client/kw_pir_client.h delete mode 100644 psi/kwpir/server/BUILD.bazel delete mode 100644 psi/kwpir/server/apsi_kw_pir_server.cc delete mode 100644 psi/kwpir/server/apsi_kw_pir_server.h delete mode 100644 psi/kwpir/server/kw_pir_server.cc delete mode 100644 psi/kwpir/server/kw_pir_server.h delete mode 100644 psi/kwpir/server/kw_seal_pir_server.cc delete mode 100644 psi/kwpir/server/kw_seal_pir_server.h diff --git a/psi/kwpir/client/BUILD.bazel b/psi/kwpir/client/BUILD.bazel deleted file mode 100644 index a2a6dffe..00000000 --- a/psi/kwpir/client/BUILD.bazel +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright 2024 Ant Group Co., Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -load("//bazel:psi.bzl", "psi_cc_library", "psi_cc_test") - -package(default_visibility = ["//visibility:public"]) - -psi_cc_library( - name = "kw_pir_client", - srcs = ["kw_pir_client.cc"], - hdrs = ["kw_pir_client.h"], - linkopts = [ - "-ldl", - "-lm", - ], - deps = [ - "@yacl//yacl/base:byte_container_view", - "@yacl//yacl/base:exception", - ], -) diff --git a/psi/kwpir/client/kw_pir_client.cc b/psi/kwpir/client/kw_pir_client.cc deleted file mode 100644 index 71ff3a85..00000000 --- a/psi/kwpir/client/kw_pir_client.cc +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright 2024 Ant Group Co., Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "psi/kwpir/client/kw_pir_client.h" - -namespace psi::kwpir { -} // namespace psi::kwpir \ No newline at end of file diff --git a/psi/kwpir/client/kw_pir_client.h b/psi/kwpir/client/kw_pir_client.h deleted file mode 100644 index 0b793013..00000000 --- a/psi/kwpir/client/kw_pir_client.h +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright 2024 Ant Group Co., Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -namespace psi::kwpir {} // namespace psi::kwpir \ No newline at end of file diff --git a/psi/kwpir/server/BUILD.bazel b/psi/kwpir/server/BUILD.bazel deleted file mode 100644 index 6321a2b2..00000000 --- a/psi/kwpir/server/BUILD.bazel +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright 2024 Ant Group Co., Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -load("//bazel:psi.bzl", "psi_cc_library") - -package(default_visibility = ["//visibility:public"]) - -psi_cc_library( - name = "kw_pir_server", - srcs = ["kw_pir_server.cc"], - hdrs = ["kw_pir_server.h"], - linkopts = [ - "-ldl", - "-lm", - ], - deps = [ - "//psi/kwpir/common:input_provider", - "//psi/proto:kw_pir_server_service_cc_proto", - "//psi/proto:pir_cc_proto", - "@com_github_brpc_brpc//:brpc", - "@yacl//yacl/base:exception", - "@yacl//yacl/utils:elapsed_timer", - ], -) - -psi_cc_library( - name = "apsi_kw_pir_server", - srcs = ["apsi_kw_pir_server.cc"], - hdrs = ["apsi_kw_pir_server.h"], - deps = [ - ":kw_pir_server", - "//psi/apsi_wrapper/api:sender", - "@com_github_brpc_brpc//:brpc", - "@yacl//yacl/crypto/rand", - ], -) - -psi_cc_library( - name = "kw_seal_pir_server", - srcs = ["kw_seal_pir_server.cc"], - hdrs = ["kw_seal_pir_server.h"], - deps = [ - ":kw_pir_server", - "@com_github_brpc_brpc//:brpc", - "@yacl//yacl/crypto/rand", - ], -) diff --git a/psi/kwpir/server/apsi_kw_pir_server.cc b/psi/kwpir/server/apsi_kw_pir_server.cc deleted file mode 100644 index 8f754548..00000000 --- a/psi/kwpir/server/apsi_kw_pir_server.cc +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright 2024 Ant Group Co., Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "psi/kwpir/server/apsi_kw_pir_server.h" - -#include - -#include "yacl/base/exception.h" - -#include "psi/proto/apsi_wrapper.pb.h" - -namespace psi::kwpir { - -ApsiKeywordPirServer::ApsiKeywordPirServer(InputConfig input, - std::string db_path, - ApsiSenderConfig apsi_config) - : KeywordPirServer(input, db_path) { - apsi_option_.group_cnt = apsi_config.experimental_bucket_group_cnt(); - apsi_option_.num_buckets = apsi_config.experimental_bucket_cnt(); - apsi_option_.nonce_byte_count = apsi_config.nonce_byte_count(); - apsi_option_.compress = apsi_config.compress(); - apsi_option_.params_file = apsi_config.params_file(); - apsi_option_.group_cnt = apsi_config.experimental_bucket_group_cnt(); -} - -void ApsiKeywordPirServer::DoQuery(const KeywordPirServerRequest* request, - KeywordPirServerResponse* response) { - if (request->step() == "oprf") { - response->mutable_reply()->Add(sender_->RunOPRF(request->query(0))); - } else if (request->step() == "params") { - std::vector oprf{request->query().begin(), - request->query().end()}; - auto oprf_res = sender_->RunOPRF(oprf); - response->mutable_reply()->Assign(oprf_res.begin(), oprf_res.end()); - } else if (request->step() == "query") { - std::vector query{request->query().begin(), - request->query().end()}; - auto result = sender_->RunQuery(query); - response->mutable_reply()->Assign(result.begin(), result.end()); - } else { - YACL_THROW("unknown step {}", request->step()); - } -} - -void ApsiKeywordPirServer::DoLoadDataBase(const std::string& db_path) { - sender_ = std::make_shared(db_path); -} - -void ApsiKeywordPirServer::DoGenerateDataBase( - std::shared_ptr input, const std::string& db_path) { - apsi_option_.provider = input; - apsi_option_.db_path = db_path; - sender_ = std::make_shared(apsi_option_); -} - -} // namespace psi::kwpir \ No newline at end of file diff --git a/psi/kwpir/server/apsi_kw_pir_server.h b/psi/kwpir/server/apsi_kw_pir_server.h deleted file mode 100644 index 61649c82..00000000 --- a/psi/kwpir/server/apsi_kw_pir_server.h +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright 2024 Ant Group Co., Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include "psi/apsi_wrapper/api/sender.h" -#include "psi/kwpir/server/kw_pir_server.h" - -#include "psi/proto/apsi_wrapper.pb.h" - -namespace psi::kwpir { - -class ApsiKeywordPirServer : public KeywordPirServer { - public: - ApsiKeywordPirServer(InputConfig input, std::string db_path, - ApsiSenderConfig apsi_config); - - void DoQuery(const KeywordPirServerRequest* request, - KeywordPirServerResponse* response) override; - - void DoLoadDataBase(const std::string& db_path) override; - - void DoGenerateDataBase(std::shared_ptr input, - const std::string& db_path) override; - - protected: - psi::apsi_wrapper::api::Sender::KwPirOption apsi_option_; - std::shared_ptr sender_; -}; - -} // namespace psi::kwpir \ No newline at end of file diff --git a/psi/kwpir/server/kw_pir_server.cc b/psi/kwpir/server/kw_pir_server.cc deleted file mode 100644 index 5c8d51a6..00000000 --- a/psi/kwpir/server/kw_pir_server.cc +++ /dev/null @@ -1,64 +0,0 @@ -// Copyright 2024 Ant Group Co., Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "psi/kwpir/server/kw_pir_server.h" - -#include - -#include "yacl/base/exception.h" -#include "yacl/utils/elapsed_timer.h" - -namespace psi::kwpir { - -void KeywordPirServer::SetUp() { - yacl::ElapsedTimer timer; - - if (input_.file_name().empty()) { - DoLoadDataBase(db_path_); - } else { - auto provider = std::make_shared(input_); - DoGenerateDataBase(provider, db_path_); - } - - SPDLOG_INFO("SetUp cost: {} ms", timer.CountMs()); - is_ready_ = true; -} - -void KeywordPirServer::Query(::google::protobuf::RpcController* /*cntl_base*/, - const KeywordPirServerRequest* request, - KeywordPirServerResponse* response, - ::google::protobuf::Closure* done) { - brpc::ClosureGuard done_guard(done); - - if (!is_ready_) { - response->mutable_status()->set_err_code(ErrorCode::NOT_READY); - response->mutable_status()->set_msg("Server is not ready"); - } - - SPDLOG_INFO("request step: {}", request->step()); - try { - DoQuery(request, response); - } catch (yacl::Exception& e) { - response->mutable_status()->set_err_code(ErrorCode::LOGIC_ERROR); - response->mutable_status()->set_msg(e.what()); - } catch (const std::exception& e) { - response->mutable_status()->set_err_code(ErrorCode::UNEXPECTED_ERROR); - response->mutable_status()->set_msg(e.what()); - return; - } - - response->mutable_status()->set_err_code(ErrorCode::OK); -} - -} // namespace psi::kwpir \ No newline at end of file diff --git a/psi/kwpir/server/kw_pir_server.h b/psi/kwpir/server/kw_pir_server.h deleted file mode 100644 index 5ce741f4..00000000 --- a/psi/kwpir/server/kw_pir_server.h +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright 2024 Ant Group Co., Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include - -#include "brpc/server.h" -#include "spdlog/spdlog.h" - -#include "psi/kwpir/common/input_provider.h" - -#include "psi/proto/kw_pir_server_service.pb.h" - -namespace psi::kwpir { - -class KeywordPirServer : public KeywordPirServerService { - public: - KeywordPirServer(InputConfig input, std::string db_path) - : input_(std::move(input)), db_path_(std::move(db_path)) {} - - ~KeywordPirServer() override = default; - - void Query(::google::protobuf::RpcController* /*cntl_base*/, - const KeywordPirServerRequest* request, - KeywordPirServerResponse* response, - ::google::protobuf::Closure* done) override; - - void SetUp(); - - protected: - virtual void DoQuery(const KeywordPirServerRequest* request, - KeywordPirServerResponse* response) = 0; - - virtual void DoGenerateDataBase(std::shared_ptr input, - const std::string& db_path) = 0; - - virtual void DoLoadDataBase(const std::string& db_path) = 0; - - protected: - InputConfig input_; - std::string db_path_; - - private: - bool is_ready_ = false; -}; - -} // namespace psi::kwpir \ No newline at end of file diff --git a/psi/kwpir/server/kw_seal_pir_server.cc b/psi/kwpir/server/kw_seal_pir_server.cc deleted file mode 100644 index a9cb5e50..00000000 --- a/psi/kwpir/server/kw_seal_pir_server.cc +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright 2024 Ant Group Co., Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "psi/kwpir/server/kw_seal_pir_server.h" - -#include "yacl/base/exception.h" - -namespace psi::kwpir { - -void KeywordSealPirServer::DoQuery(const KeywordPirServerRequest* request, - KeywordPirServerResponse* response) { - (void)request; - (void)response; -} - -} // namespace psi::kwpir \ No newline at end of file diff --git a/psi/kwpir/server/kw_seal_pir_server.h b/psi/kwpir/server/kw_seal_pir_server.h deleted file mode 100644 index fa20da24..00000000 --- a/psi/kwpir/server/kw_seal_pir_server.h +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright 2024 Ant Group Co., Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include "psi/kwpir/server/kw_pir_server.h" - -namespace psi::kwpir { - -class KeywordSealPirServer : public KeywordPirServer { - public: - KeywordSealPirServer() = default; - - void DoQuery(const KeywordPirServerRequest* request, - KeywordPirServerResponse* response) override; - - protected: -}; - -} // namespace psi::kwpir \ No newline at end of file From 70aaa9e0061aaacd86ef9cff36f7ceddbb4e9ebb Mon Sep 17 00:00:00 2001 From: huocun Date: Wed, 20 Nov 2024 19:23:30 +0800 Subject: [PATCH 3/3] repo-sync-2024-11-20T19:23:24+0800 --- .../config/rr22_receiver_recovery.json | 43 ------------------- psi/apsi_wrapper/cli/entry.h | 6 +-- psi/apsi_wrapper/cli/sender.cc | 12 +++++- psi/proto/pir.proto | 4 +- 4 files changed, 16 insertions(+), 49 deletions(-) delete mode 100644 benchmark/docker-compose/config/rr22_receiver_recovery.json diff --git a/benchmark/docker-compose/config/rr22_receiver_recovery.json b/benchmark/docker-compose/config/rr22_receiver_recovery.json deleted file mode 100644 index c6680553..00000000 --- a/benchmark/docker-compose/config/rr22_receiver_recovery.json +++ /dev/null @@ -1,43 +0,0 @@ -{ - "psi_config": { - "protocol_config": { - "protocol": "PROTOCOL_RR22", - "role": "ROLE_RECEIVER", - "broadcast_result": false, - "rr22_config": { - "bucket_size": 1000000 - } - }, - "input_config": { - "type": "IO_TYPE_FILE_CSV", - "path": "/data/receiver_input.csv" - }, - "output_config": { - "type": "IO_TYPE_FILE_CSV", - "path": "/tmp/rr22_receiver_recovery_output.csv" - }, - "keys": ["id_0", "id_1"], - "debug_options": { - "trace_path": "/tmp/rr22_receiver_recovery.trace" - }, - "skip_duplicates_check": true, - "disable_alignment": true, - "recovery_config": { - "enabled": true, - "folder": "/tmp/rr22_receiver_cache" - } - }, - "link_config": { - "parties": [ - { - "id": "receiver", - "host": "0.0.0.0:5300" - }, - { - "id": "sender", - "host": "psi-sender:5300" - } - ] - }, - "self_link_party": "receiver" -} diff --git a/psi/apsi_wrapper/cli/entry.h b/psi/apsi_wrapper/cli/entry.h index aa426814..bbc4a9ee 100644 --- a/psi/apsi_wrapper/cli/entry.h +++ b/psi/apsi_wrapper/cli/entry.h @@ -50,8 +50,8 @@ struct ReceiverOptions { // experimental bucketize bool experimental_enable_bucketize = false; - size_t experimental_bucket_cnt; - size_t query_batch_size; + size_t experimental_bucket_cnt = 10; + size_t query_batch_size = 1; }; struct SenderOptions { @@ -89,7 +89,7 @@ struct SenderOptions { size_t experimental_bucket_cnt; std::string experimental_bucket_folder; int experimental_db_generating_process_num = 8; - int experimental_bucket_group_cnt = 1024; + int experimental_bucket_group_cnt = 512; }; int RunReceiver(const ReceiverOptions& options, diff --git a/psi/apsi_wrapper/cli/sender.cc b/psi/apsi_wrapper/cli/sender.cc index 6cab2394..fe8c89f5 100644 --- a/psi/apsi_wrapper/cli/sender.cc +++ b/psi/apsi_wrapper/cli/sender.cc @@ -31,10 +31,13 @@ DEFINE_string( DEFINE_string( db_file, "examples/pir/apsi/data/db.csv", - "Path to a CSV file describing the sender's dataset (an item-label pair on " - "each row) or a file containing a serialized SenderDB; the CLI will first " + "A file containing a serialized SenderDB; the CLI will first " "attempt to load the data as a serialized SenderDB, and – upon failure – " "will proceed to attempt to read it as a CSV file"); +DEFINE_string( + source_file, "examples/pir/apsi/data/db.csv", + "Path to a CSV file describing the sender's dataset (an item-label pair on " + "each row)"); DEFINE_string( params_file, "examples/pir/apsi/parameters/1M-256.json", "Path to a JSON file describing the parameters to be used by the sender"); @@ -67,6 +70,8 @@ DEFINE_bool(experimental_enable_bucketize, false, "Whether to split data in buckets and Each bucket would be a " "seperate SenderDB."); DEFINE_uint64(experimental_bucket_cnt, 0, "The number of bucket to fit data."); +DEFINE_uint64(experimental_db_generating_process_num, 0, + "The number of process num to generate db."); DEFINE_string(experimental_bucket_folder, "", "Folder to save bucketized small csv files and db files."); @@ -90,11 +95,14 @@ int main(int argc, char *argv[]) { options.channel = FLAGS_channel; options.streaming_result = FLAGS_streaming_result; options.db_file = FLAGS_db_file; + options.source_file = FLAGS_source_file; options.params_file = FLAGS_params_file; options.sdb_out_file = FLAGS_sdb_out_file; options.experimental_enable_bucketize = FLAGS_experimental_enable_bucketize; options.experimental_bucket_cnt = FLAGS_experimental_bucket_cnt; + options.experimental_db_generating_process_num = + FLAGS_experimental_db_generating_process_num; options.experimental_bucket_folder = FLAGS_experimental_bucket_folder; return psi::apsi_wrapper::cli::RunSender(options); diff --git a/psi/proto/pir.proto b/psi/proto/pir.proto index a81e554a..3212d248 100644 --- a/psi/proto/pir.proto +++ b/psi/proto/pir.proto @@ -45,7 +45,9 @@ message KeywordPirServerConfig { uint32 server_port = 4; - reserved 5, 9; + uint32 num_threads = 5; + + reserved 6, 9; KeywordPirServerBackendConfig backend = 10; }