diff --git a/.bazelrc b/.bazelrc index bf1ea9f..227f97a 100644 --- a/.bazelrc +++ b/.bazelrc @@ -14,11 +14,19 @@ common --experimental_repo_remote_exec common --experimental_remote_download_regex='.*libserving\.so$|.*_pb2\.py$|.*\/secretflow_serving$|.*sf_serving\.tar\.gz$|.*\/simple_feature_service$' +common --experimental_cc_shared_library build --incompatible_new_actions_api=false build --copt=-fdiagnostics-color=always build --enable_platform_specific_config +# default off CUDA build +build --@rules_cuda//cuda:enable=false + +# Only on when asked +build:gpu --@rules_cuda//cuda:archs=compute_80:compute_80 +build:gpu --@rules_cuda//cuda:enable=true + build --cxxopt=-std=c++17 build --host_cxxopt=-std=c++17 diff --git a/.ci/accuracy_test.py b/.ci/accuracy_test.py index b6efbdb..8466548 100644 --- a/.ci/accuracy_test.py +++ b/.ci/accuracy_test.py @@ -54,7 +54,7 @@ def dump_json(obj, filename, indent=2): json.dump(obj, ofile, indent=indent) -def is_approximately_equal(a, b, epsilon) -> bool: +def is_approximately_equal(a, b, epsilon=0.0001) -> bool: return abs(a - b) < epsilon @@ -328,7 +328,7 @@ def _make_request_body(self): return json.dumps(body_dict) - def exec(self): + def exec(self, epsilon=0.0001): try: self.start_server() @@ -364,7 +364,7 @@ def exec(self): ] ) assert is_approximately_equal( - expect_score, s, 0.0001 + expect_score, s, epsilon ), f'result not match, {s} vs {expect_score}' finally: self.stop_server() @@ -459,3 +459,36 @@ def exec(self): query_ids=['1', '2', '3', '4', '5', '6', '7', '8', '9', '15'], score_col_name='pred', ).exec() + + AccuracyTestCase( + service_id="phe_sgd", + parties=['alice', 'bob'], + case_dir='.ci/test_data/phe_sgd', + package_name='s_model.tar.gz', + input_csv_names={'alice': 'alice.csv', 'bob': 'bob.csv'}, + expect_csv_name='predict.csv', + query_ids=['1', '2', '3', '4', '5', '6', '7', '8', '9', '15'], + score_col_name='pred', + ).exec() + + AccuracyTestCase( + service_id="phe_sgd_no_feature", + parties=['alice', 
'bob'], + case_dir='.ci/test_data/phe_sgd_no_feature', + package_name='s_model.tar.gz', + input_csv_names={'alice': 'alice.csv', 'bob': 'bob.csv'}, + expect_csv_name='predict.csv', + query_ids=['1', '2', '3', '4', '5', '6', '7', '8', '9', '15'], + score_col_name='pred', + ).exec() + + AccuracyTestCase( + service_id="phe_glm", + parties=['alice', 'bob'], + case_dir='.ci/test_data/phe_glm', + package_name='s_model.tar.gz', + input_csv_names={'alice': 'alice.csv', 'bob': 'bob.csv'}, + expect_csv_name='predict.csv', + query_ids=['1', '2', '3', '4', '5', '6', '7', '8', '9', '15'], + score_col_name='predict_score', + ).exec(0.1) diff --git a/.ci/inferencer_test.py b/.ci/inferencer_test.py new file mode 100644 index 0000000..a789f95 --- /dev/null +++ b/.ci/inferencer_test.py @@ -0,0 +1,78 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from importlib import resources +import asyncio +import os + +current_file_path = os.path.abspath(__file__) +code_dir = os.path.dirname(os.path.dirname(current_file_path)) + +alice_serving_config_file_path = os.path.join( + code_dir, + "secretflow_serving/tools/inferencer/example/alice/serving.config", +) +alice_inference_config_file_path = os.path.join( + code_dir, + "secretflow_serving/tools/inferencer/example/alice/inference.config", +) +bob_serving_config_file_path = os.path.join( + code_dir, + "secretflow_serving/tools/inferencer/example/bob/serving.config", +) +bob_inference_config_file_path = os.path.join( + code_dir, + "secretflow_serving/tools/inferencer/example/bob/inference.config", +) + + +async def run_process(command): + process = await asyncio.create_subprocess_exec( + *command, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE + ) + stdout, stderr = await process.communicate() + if process.returncode == 0: + print( + f"Process {' '.join(command)} completed successfully:\n{stdout.decode().strip()}" + ) + else: + print( + f"Process {' '.join(command)} failed with exit code {process.returncode}:\n{stderr.decode().strip()}" + ) + + +async def main(): + with resources.path( + 'secretflow_serving.tools.inferencer', 'inferencer' + ) as tool_path: + alice_command = [ + str(tool_path), + f'--serving_config_file={alice_serving_config_file_path}', + f'--inference_config_file={alice_inference_config_file_path}', + ] + bob_command = [ + str(tool_path), + f'--serving_config_file={bob_serving_config_file_path}', + f'--inference_config_file={bob_inference_config_file_path}', + ] + commands = [alice_command, bob_command] + + tasks = [run_process(command) for command in commands] + + await asyncio.gather(*tasks) + + +if __name__ == '__main__': + asyncio.run(main()) diff --git a/.ci/test_data/bin_onehot_glm/alice/alice.csv b/.ci/test_data/bin_onehot_glm/alice/alice.csv index d83612e..e839d87 100644 --- a/.ci/test_data/bin_onehot_glm/alice/alice.csv 
+++ b/.ci/test_data/bin_onehot_glm/alice/alice.csv @@ -18,4 +18,3 @@ id,f1,f2,f3,f4,f5,f6,f7,f8,b1,o1,y 17,0.07245618290940148,0.15828585207480095,0.6058844851447147,-0.7307832385640813,-0.7047890815652045,0.7604535581160456,0.800552689601107,0.5620156263473772,0.18038921596661117,D,1.0 18,0.9462315279587412,-0.16484385288965275,0.9134186991705504,-0.8007932195395477,0.9007963088557205,-0.2762376803205897,0.08557406326172878,0.2904407156393314,0.39477993406748413,D,0.0 19,-0.24293124558329304,-0.6951771796317134,0.7522260924861339,-0.894295158530586,-0.14577787102095408,-0.2960307812622618,0.28186485638698366,-0.2308367011118997,0.014778052261847086,C,0.0 -20,0.104081262546454,-0.3402694280442802,0.6262946992127485,-0.8635426536480353,0.6833201820883588,0.011286885685380721,0.4785141057961586,0.9729862354474565,0.08367496429611809,D,0.0 diff --git a/.ci/test_data/phe_glm/alice/alice.csv b/.ci/test_data/phe_glm/alice/alice.csv new file mode 100644 index 0000000..0dfa098 --- /dev/null +++ b/.ci/test_data/phe_glm/alice/alice.csv @@ -0,0 +1,33 @@ +id,f1,f2,f3,f4,b1,o1,y,unused1 +0,-0.1561563606294591,0.4091436724298469,0.7527352529453377,-0.4663496200949144,0.4168724773199035,C,1.0,0.6044702778742779 +1,-0.9404055611238592,-0.9083512326886756,-0.3706442384030441,0.2819235971596161,0.3517699625436855,D,0.0,0.7281280567505588 +2,-0.5627240503927933,-0.5442034486969063,0.3108773305897601,-0.7768956528082471,0.305838882862975,A,0.0,0.621498633148778 +3,0.0107105762067247,-0.4212240727957856,-0.2087361978786714,-0.13046949866179,0.4936165318157521,C,1.0,-0.4663885808110559 +4,-0.9469280606322728,-0.840416046152745,0.829095179481087,-0.0925525873415871,0.3269881588553663,C,1.0,0.5747490182709423 +5,-0.602324698626703,-0.5344182272779396,-0.0822962948252024,0.9076318550421604,0.0039115535760789,B,1.0,-0.7838087471940858 +6,0.2997688755590464,-0.7979971411805418,-0.4702396670038951,0.7517058807563881,0.4085520675577308,D,0.0,0.7443335658121795 
+7,0.0898829612064333,-0.4440527937798157,-0.5067449846120331,-0.4732218984978185,0.1496893760999889,C,1.0,0.7171865026755633 +8,-0.5591187559186066,0.2713688885288003,0.1227362683263015,0.0011722261005966,0.3316943574830386,D,1.0,-0.5551325649086711 +9,0.1785313677518174,-0.2703356420598315,-0.4745167829541294,-0.6426962389397373,0.4694650019635519,D,0.0,0.6331732111938579 +10,0.6188609133556533,-0.2596380657662347,0.169171980447081,0.825255678689641,0.0671455571966838,D,1.0,-0.0793935306421158 +11,-0.987002480643878,-0.5809859384570246,0.795645767204954,0.7410371396735338,0.0577143352095511,B,1.0,-0.3896182653227988 +12,0.6116385036656158,-0.4660443559017733,-0.2011989897192054,-0.4031104171027342,0.0535179888547088,D,0.0,0.5906909983057236 +13,0.3962787899764537,0.873309175424988,-0.5613584816854333,0.2778989897320103,0.2766118204424079,B,1.0,-0.5448090251844593 +14,-0.3194989669640162,0.2960707704931871,0.9950752129902204,0.2179404228763446,0.1361741061574081,A,1.0,-0.952671130597097 +15,-0.6890410003764369,0.2182620113339763,0.019052587352929,-0.6943214629007304,0.3024149135151119,C,0.0,-0.6137404233445827 +16,0.9144261444135624,-0.657722703603806,-0.8181811756524122,0.5250216001503025,0.3588060935693989,C,0.0,-0.3434760976045869 +17,-0.3268109097747464,0.4582535959006983,-0.9057672491505308,0.0787580602392514,0.1017986561637264,A,1.0,0.7287058840605727 +18,-0.8145083132397042,-0.6731950124761432,-0.7807017392986817,0.5572529572611165,0.3171189794425398,B,1.0,0.9337782080967224 +19,-0.806567246333072,-0.2410891164847044,0.2548920834061801,0.0607073443903549,0.1319919508152047,C,1.0,-0.4417500145562572 +20,0.6949887326949196,0.9790467012731904,0.5841587287259282,-0.998856207744113,0.2442659260746882,B,1.0,0.2829634772152554 +21,0.2074520627337821,0.2799995197081857,-0.1556800664006319,-0.3516878859906538,0.4526682455396616,B,0.0,-0.2006432312798782 
+22,0.6142565465487604,0.1138994875492924,-0.8729445876960857,-0.9610465152283354,0.4230518566474277,B,1.0,0.9622993743965202 +23,0.4594635733876357,0.3692285019797492,-0.2367614269869264,0.8581972325292342,0.0461492338563667,A,1.0,0.0724314649574437 +24,0.0724561829094014,0.6857038403796192,0.9922427604801936,0.7574437556463685,0.2117878862818631,A,1.0,0.8784742806494314 +25,0.9462315279587412,0.5519998230924896,0.058228690198274,0.6633310587223589,0.1383401119861258,B,1.0,-0.7693164962971448 +26,-0.242931245583293,-0.5419038560717913,0.9421567552272364,-0.3849717491946771,0.0017728445438911,D,0.0,0.940801222044456 +27,0.104081262546454,-0.9357995121919244,0.7215594044689961,-0.8841496670116249,0.3855596115098135,A,0.0,-0.6428643676550727 +28,0.6588093285059897,-0.3690939038818361,-0.9770379561143608,0.7560191984080811,0.3185566886506898,D,0.0,0.9250686315231108 +29,0.2370395047284921,-0.4645182480485945,0.4414436387203893,0.8938988905959881,0.1309776312171741,D,1.0,-0.4690672749540627 +30,0.7234138006215545,-0.5780343128273471,0.3634207380531495,-0.8286930958642424,0.3706154541739654,B,1.0,-0.7831949055705778 +31,0.154704290513524,0.8858194286701089,0.0739406608175903,-0.0280190733667724,0.2758402105631956,D,1.0,-0.1308724828707113 diff --git a/.ci/test_data/phe_glm/alice/s_model.tar.gz b/.ci/test_data/phe_glm/alice/s_model.tar.gz new file mode 100644 index 0000000..6103d4d Binary files /dev/null and b/.ci/test_data/phe_glm/alice/s_model.tar.gz differ diff --git a/.ci/test_data/phe_glm/bob/bob.csv b/.ci/test_data/phe_glm/bob/bob.csv new file mode 100644 index 0000000..9ba5fa4 --- /dev/null +++ b/.ci/test_data/phe_glm/bob/bob.csv @@ -0,0 +1,33 @@ +id,f5,f6,f7,f8,b2,o2,unused2 +0,-0.8615749630632328,-0.5715263852591228,0.9379925145695024,0.1900702129000553,0.2138434594903396,D,0.4570901213054086 +1,0.5212043305144631,-0.73537630254995,0.8527339660162552,0.3504251072081803,0.0048348498041699,D,-0.3726453716100175 
+2,0.5316688586139755,0.871028481161342,0.6973914688286109,-0.5295922099981376,0.0376219300368835,B,0.2124177066122865 +3,-0.7432170710004744,0.1420861866505689,-0.667377778792172,-0.7602267721057516,0.4415531966500715,B,0.0228461193389561 +4,-0.0494352438025373,-0.0546579473764117,-0.028717749098563,0.780574628258875,0.4519642857799466,A,-0.2296091333105456 +5,0.0996071869898878,0.5692388485815068,-0.5725054016016367,-0.5075693044227503,0.2727951446027611,A,0.153176086993199 +6,-0.4698867421198818,0.6149939955332868,-0.1979194149010947,0.1890383070668824,0.4172975099430083,D,-0.4905549877228361 +7,0.7448660821705149,-0.6191801712762446,-0.8827292000556421,0.2387630206642061,0.291254783244897,B,0.4175705676683412 +8,-0.1537241195982261,-0.8061383715423533,-0.2420537620461678,-0.161550169328255,0.0740468927837413,B,-0.9966174435627412 +9,-0.5764035891158359,-0.1378976351872449,0.970617687559452,0.1673445785824494,0.0637227596410693,D,0.8511503309981654 +10,0.0785921775589166,-0.1528427539601584,-0.4695938836556961,0.0455654310639177,0.1541291749650668,A,0.0769039941855838 +11,0.4598621381799523,-0.06595066392665,0.5681412038971387,0.8694125154728545,0.449490744371295,B,0.438859998289691 +12,-0.5976978732206082,0.4581516989197012,-0.0899832653217134,-0.5914816011529271,0.3980611524440208,A,0.483900155678953 +13,-0.3765674173982101,0.346729094586603,-0.1539850280196741,0.4323836015788296,0.4303512910004514,D,0.3412570088659989 +14,0.9902987133217892,0.9683304227319324,0.9146352817193464,-0.522628094768308,0.4494623182632373,B,-0.2715570564374716 +15,0.299756115278907,-0.8031642576960822,0.9908453789854276,-0.208428306417491,0.105038269169877,D,-0.8600523777473796 +16,-0.1237998321709918,-0.1947574357954624,0.1115366468112364,0.3433804459199425,0.1247648696114622,C,0.3284753698225446 +17,0.0351516820711812,-0.3213947892100737,0.436816550592652,-0.4000058404024755,0.0513968108358928,D,-0.3395999279148072 
+18,-0.7579916082634686,0.7233450727055821,-0.6904063494518717,-0.3676456074562919,0.3900581209357213,D,-0.3721687098832806 +19,-0.5506053259368853,-0.5026873321594287,-0.4065843490108716,0.5037289848288042,0.4420673507255044,D,0.696030559012671 +20,-0.3238288757050893,-0.619582183118377,0.9374187299383177,-0.8549137710136854,0.2031886949160584,B,0.4395085260279003 +21,0.1766174369144666,-0.1027729043337362,0.1583605816325124,-0.0834289547628277,0.3103307550753564,D,-0.3993554635774716 +22,-0.539770534806846,-0.1562367203311916,0.0843904027485484,0.9969088817088848,0.0772766691661023,D,-0.3814306755826935 +23,-0.559565231096881,-0.442909710666119,0.4959511207581282,0.9921928957101888,0.4649405078468372,C,-0.1832141827615663 +24,-0.8580138279819349,-0.500387104235799,-0.8856694541850338,-0.853478557800734,0.432302848109982,B,-0.1951992258845507 +25,0.2622059145401978,0.8465311985520256,0.1683551889179424,-0.5736913754659192,0.4881030164654814,C,-0.4086895949481059 +26,-0.5421164323776912,-0.113738509893086,0.0057007658390271,-0.4695991704991973,0.4053858599701984,D,-0.7454244018816936 +27,0.8108400260122559,0.7226982095236612,0.7054397840965707,0.8665187559874181,0.4407081023316622,B,-0.1591073324541834 +28,0.719270800507493,0.1006506248996961,-0.6851345441210335,0.7617283473728791,0.0123931809490943,C,0.880727341460366 +29,-0.8582853002226931,-0.8988233409502375,0.9215578065489008,0.7585404849690855,0.368282235877541,D,0.3546358905454658 +30,-0.5239907312620096,0.9985649368254532,-0.8397770695188262,-0.2609458225222321,0.1660927339732143,A,0.8056110914651653 +31,0.3379555565925611,0.6720551701599038,-0.6283500780385536,-0.6845063352855361,0.4654079430241628,C,0.231029831902761 diff --git a/.ci/test_data/phe_glm/bob/s_model.tar.gz b/.ci/test_data/phe_glm/bob/s_model.tar.gz new file mode 100644 index 0000000..783dd86 Binary files /dev/null and b/.ci/test_data/phe_glm/bob/s_model.tar.gz differ diff --git a/.ci/test_data/phe_glm/predict.csv 
b/.ci/test_data/phe_glm/predict.csv new file mode 100644 index 0000000..7881fea --- /dev/null +++ b/.ci/test_data/phe_glm/predict.csv @@ -0,0 +1,33 @@ +id,pred,y,predict_score +0,0.47686505,1.0,516.2334798971686 +1,0.16220368,0.0,560.9372970797582 +2,0.26960367,0.0,542.3182257710636 +3,0.5657716,1.0,505.9261270108796 +4,0.17707317,1.0,557.889809489704 +5,0.27827334,1.0,541.0604345211162 +6,0.6727009,0.0,492.7743294058237 +7,0.58304656,1.0,503.8869425254623 +8,0.29150468,1.0,539.1862235038279 +9,0.5982337,0.0,502.07438558814715 +10,0.7883281,1.0,475.62210737259727 +11,0.17442402,1.0,558.4174829840767 +12,0.7779472,0.0,477.38605063822337 +13,0.7261041,1.0,485.4304663269026 +14,0.39520478,1.0,525.8383646190198 +15,0.23407295,0.0,547.7664128967078 +16,0.8454458,0.0,464.5294473247045 +17,0.36718696,1.0,529.2667277402132 +18,0.20782131,1.0,552.171084113332 +19,0.21610352,1.0,550.7402489345043 +20,0.8168119,1.0,470.4278724665369 +21,0.6409019,0.0,496.84692465294745 +22,0.7754037,1.0,477.8091725130769 +23,0.7400456,1.0,483.37432514560277 +24,0.63102967,1.0,498.07737805471197 +25,0.8610095,1.0,460.94058571637396 +26,0.42536157,0.0,522.2407333155013 +27,0.57785815,0.0,504.5016875280319 +28,0.77958846,0.0,477.111178457454 +29,0.66062826,1.0,494.3419931815475 +30,0.8147475,1.0,470.8242355794075 +31,0.6115879,1.0,500.4620059054651 diff --git a/.ci/test_data/phe_sgd/alice/alice.csv b/.ci/test_data/phe_sgd/alice/alice.csv new file mode 100644 index 0000000..b99a909 --- /dev/null +++ b/.ci/test_data/phe_sgd/alice/alice.csv @@ -0,0 +1,33 @@ +id,a0,a1,a2,a3,a4,a5,a6,a7,a8,a9,a10,a11,a12,a13,a14,y +0,1.0970639814699807,-2.073335014697593,1.2699336881399383,0.9843749048031144,1.568466329243428,3.2835146709868264,2.652873983743168,2.532475216403245,2.2175150059646405,2.255746885296269,2.4897339267376197,-0.5652650590684639,2.833030865855184,2.4875775569611043,-0.2140016466689538,0 
+1,1.8298206075464456,-0.3536324082438112,1.6859547105508974,1.9087082542365936,-0.8269624468508425,-0.4870716725758942,-0.0238458551987692,0.5481441558908369,0.0013923632994608,-0.8686524574634664,0.4992546006760562,-0.8762436030602548,0.263326965842778,0.7424019483418791,-0.6053508469797809,0 +2,1.5798881149312178,0.4561869517641946,1.5665031298586416,1.5588836327586924,0.9422104400684552,1.05292554434161,1.3634784515699176,2.037230755700812,0.939684816618985,-0.3980079103689868,1.2286759457296228,-0.7800833765050336,0.8509283007136554,1.181336055653447,-0.2970050119818975,0 +3,-0.7689092872596208,0.2537321117621929,-0.5926871666544732,-0.7644637923250287,3.283553480279431,3.402908991274548,1.9158971800569968,1.451707356849496,2.867382930831859,4.9109192850190375,0.3263734407153149,-0.1104090440232948,0.286593404544489,-0.2883781482770153,0.6897016600113287,0 +4,1.7502966326234184,-1.1518164326195182,1.7765731510760563,1.826229278440991,0.2803718299176319,0.5393404523102987,1.3710114342311053,1.4284927727540695,-0.0095604668949302,-0.562449981040552,1.2705427819622863,-0.7902437023297363,1.2731894116191806,1.1903567566057145,1.483067159789666,0 +5,-0.4763746652213425,-0.8353353034209873,-0.3871480674633165,-0.5056504544836544,2.237421483589421,1.2443354863901803,0.8663015959315467,0.8246556464496959,1.005401797785333,1.8900050384577884,-0.2550702935159049,-0.5926616519172156,-0.3213041853640514,-0.2892582166626024,0.1563467021771524,0 +6,1.170907672469935,0.1606494267038018,1.13812504737607,1.095294906735132,-0.1231362259485147,0.0882952423344605,0.3000723992322905,0.646935108208041,-0.0643246178668869,-0.7623321531499545,0.1498830707345162,-0.8049398878976097,0.1554102927156918,0.2986274649095823,-0.909029826096615,0 
+7,-0.1185167780677197,0.3584501324528832,-0.072866839641968,-0.2189649110285938,1.6040490502192788,1.140102349631058,0.061025749450609,0.2819502582632787,1.403354628181551,1.6603531811406034,0.6436230014783456,0.2905609572730053,0.4900509855317937,0.2337224214725335,0.588030871174189,0 +8,-0.3201668573368246,0.5888297779724025,-0.1840803802864819,-0.3842072728811636,2.20183876261357,1.6840098087195687,1.2190962838971586,1.150691583078798,1.965599911493639,1.5724617295747656,-0.3568500160815189,-0.3898180042026168,-0.2277433999465317,-0.3524031233284771,-0.4366773415647225,0 +9,-0.4735345232598054,1.105438680046475,-0.329481787129124,-0.5090633776200244,1.5826994176337676,2.563358453378346,1.738872087519092,0.941760326219959,0.7972980240918982,2.783095594691288,-0.3882501432560169,0.6933453024665736,-0.4094196340641493,-0.3607637729915545,0.0360084898158162,0 +10,0.537556015047254,0.9192733099296918,0.4420106633418918,0.4064532537111665,-1.0176858312814026,-0.7135418515343508,-0.700684347306461,-0.404685551314783,-1.0354755617695854,-0.8261243357380614,-0.0926558426133296,-0.0541643832079764,-0.1980415633060492,0.0038045557379028,-1.004033677960828,0 +11,0.4693926079703736,-0.3257076027262936,0.4790818435567294,0.3586723298019898,0.0526424156721877,0.4711151264316007,0.1348489795302461,0.442130888521722,0.1109206652433743,-0.2803467735953673,0.3631873829198987,-0.4208432848459051,0.3455020472147793,0.3041278923195017,-0.4233434676188681,0 +12,1.4322007329313104,1.2822957816574192,1.6653596104315431,1.3313554236673744,0.0739920482576977,2.6808576257249923,1.4777286885979273,1.6219476402159576,2.1371942512057704,2.155096997212811,1.986249128939636,4.265788436187907,4.061201810939133,1.669113958365099,-1.3007123732560884,0 
+13,0.4892736017011304,1.084495075908337,0.4832008635806006,0.3635073042451804,-0.8789132194755842,-0.078477776480135,0.1328401841539294,0.1217696280048353,0.1291753822340265,-1.3350441923854053,-0.0067566441359675,-0.2519278680102239,0.0182868135587977,-0.08266216314603,0.90937723326928,0 +14,-0.1128364941446465,0.7726680809627255,0.0671798411696414,-0.2178272699831373,1.1912894868994104,2.3681582154476257,1.5568250065403952,0.8081474977596147,0.939684816618985,1.9878197184262187,-0.6968375999709101,-0.0868225733588063,-0.3985289606293058,-0.4648318595872295,-0.2040012412095631,0 +15,0.1172150047398245,1.9199121743074004,0.1961051679168003,0.0111229904150014,1.2482218404607712,1.0453449525773104,0.9428869196536188,0.6376492745698704,1.794005571781509,1.1301692636305567,-0.1269433378038733,-0.3335733433872988,0.0064060789026048,-0.1713290529939307,-0.4780123507968712,0 +16,0.1569769922013383,0.1955554336006982,0.1141366694417695,0.0842164275855874,0.1643721595363581,-0.612909495863271,-0.1864327309693982,0.0946859465601709,-0.8237208446780198,-0.5071634227975256,0.2437225312560043,0.0419958433472446,0.1628357518758122,0.1113929158759261,-0.4410108505971252,0 +17,0.5687975766241574,0.3235441255559867,0.664437744630919,0.4092973563248079,1.468834710511046,1.8545731234163136,1.0470931798000451,1.389801799261692,1.2865244394413775,1.5256807956768208,0.5920112981915273,-0.2609995874965657,0.4890609243104445,0.3045679265122953,-0.0049931725676859,0 +18,1.6139698184696574,0.6656229931455752,1.5665031298586416,1.7209974817362566,0.1387526004337456,-0.0310990779532612,0.7420073820219539,1.188092857454763,-0.8383246182705409,-1.254240761107136,1.274151991982344,-0.3626028457435921,1.484567482377281,1.585507461734324,-0.1823336960475497,0 
+19,-0.1667991914138439,-1.1471622983665986,-0.1857279882960305,-0.251956501346835,0.1017465706188614,-0.4368502521374081,-0.2782095697248652,-0.0286092889688724,0.2679112313629833,-0.7283096557696301,-0.4882252608116023,-0.7769989918796776,-0.40001405246133,-0.3691244226546321,0.4736929020884875,1 +20,-0.2974457216445311,-0.833008236294527,-0.2611060547328675,-0.3836384523584352,0.792763011969882,0.4294218717279516,-0.5413617640223466,-0.4596267336739594,0.5672885900096801,0.7530865843319685,-0.7939253495104497,-0.8512056572779519,-0.7341597146667573,-0.5647196213513657,-0.9813660922528752,1 +21,-1.31308048709005,-1.5939591866468776,-1.3028062187698126,-1.0835721055756025,0.4298192580162052,-0.7470859700913773,-0.7437478981862491,-0.7263365109480825,0.012345193493852,0.8863413657382362,-0.4615171066631787,-0.4355394704137786,-0.4737736134518613,-0.5420578604224978,0.8550416969399234,1 +22,0.3444263616627596,-1.1704329696311964,0.4337726232941499,0.1408140695970527,0.7785299235795423,2.0687248407577834,1.492794653920302,1.2546413318616525,2.5899112325739453,1.0663770810424489,0.1213703115760642,-0.9203321597638752,0.2563965372933319,0.1006120781524841,-0.0839963756968734,0 +23,1.99738898327711,0.8727319674004963,1.8630725715773444,2.130548258100629,-0.1480441306316104,-0.040574817658636,0.2624074859263535,0.9647169704921028,-0.1555982028201481,-1.420100435836214,1.034139525648538,-0.163025017044076,0.711329668503388,1.1804559872678595,-0.7710242307570218,0 +24,0.7164849586240649,0.4864388244081716,0.7426991250844659,0.7102034128480763,1.120124044947709,0.7838145367089672,0.7997602490910576,1.1034885954180975,0.6695150051573325,0.0712190326679779,1.449559598957125,-0.5717966970986299,1.2816049320006508,1.3698907072654836,-0.3310063905438262,0 
+25,0.8556519147393631,-0.6724406045688023,0.9898403265167186,0.7332406440185723,1.5826994176337676,2.3359407004493518,1.683630214670384,2.3519173401054827,4.484750856203646,1.60648422695509,2.312882635754815,-0.4369909455315933,2.183055674039293,1.5635057520946465,0.3293537166246132,0 +26,0.1285755725859715,0.5213448313050689,0.224114504079122,-0.0286944461759792,0.6433155838713087,1.5627203404907717,0.6742105380712667,1.0036658838077632,1.607807458476857,0.913275842830992,-0.5438070951204835,-0.4239276694712613,-0.3742724607062452,-0.4243487138502227,-0.8630279609834172,0 +27,1.2731527830852554,0.2234802391182159,1.241100547972842,1.2488764478717715,-0.1395042775974066,0.0428116917486617,0.7558178502341308,0.7323131897145546,-0.4184661274855406,-0.8232891276230346,1.6159441808817636,1.146931276783655,1.369225350090074,1.170555217930005,1.2363904917913593,0 +28,0.333065793816613,1.3916679366010287,0.4296536032702793,0.220448942779014,0.8425788213360734,1.238650042566955,0.998128792502327,0.9954118094627226,0.4175999106863312,0.3689158847458118,0.1220921535800756,-0.3716745652299335,0.3128300269102487,0.0695896675605388,-0.60268407219061,0 +29,0.9777780190854404,-0.986594666640873,0.9486501262780098,0.8538305948369708,0.1501390711460184,0.2152701543864822,0.1249305523596827,0.7895758304832735,-0.2651265047640616,-0.1853673017419629,0.7042577298153069,-0.7154927337622812,0.8855804434608852,0.4568197572188642,-0.4713454138239439,0 +30,1.2788330670083286,1.354434862577672,1.352314088617356,1.2318118321899227,0.7144810258230102,1.598728151371196,1.7966249545881956,1.9469518175519296,1.355892364005855,-0.1173223069813161,1.535819718436493,0.4520375641298862,1.340513574670941,1.4226948104007096,-0.2643370208145543,0 
+31,-0.6496233248750801,-0.1372151654830511,-0.5782705965709247,-0.6094601998815683,1.0347255146056669,0.8956282652323896,0.4143226362603002,0.074824580167417,1.786703684985248,2.153679393155298,0.2790927894525651,-0.3390163750791036,0.3014443228647304,0.0145853934613448,-0.4966797743210672,0 diff --git a/.ci/test_data/phe_sgd/alice/s_model.tar.gz b/.ci/test_data/phe_sgd/alice/s_model.tar.gz new file mode 100644 index 0000000..1949ed8 Binary files /dev/null and b/.ci/test_data/phe_sgd/alice/s_model.tar.gz differ diff --git a/.ci/test_data/phe_sgd/bob/bob.csv b/.ci/test_data/phe_sgd/bob/bob.csv new file mode 100644 index 0000000..b813309 --- /dev/null +++ b/.ci/test_data/phe_sgd/bob/bob.csv @@ -0,0 +1,33 @@ +id,b0,b1,b2,b3,b4,b5,b6,b7,b8,b9,b10,b11,b12,b13,b14 +0,1.3168615683959484,0.72402615808036,0.6608199414286064,1.1487566671861758,0.907083080997336,1.886689625179276,-1.3592934737640827,2.303600623622561,2.0012374893299207,1.307686271071539,2.616665023512603,2.109526346572256,2.296076127561788,2.750622244124955,1.9370146123781784 +1,-0.6929262695890712,-0.4407800577847933,0.2601620674590054,-0.8054503802819919,-0.0994437403202747,1.805927438479428,-0.3692032217294088,1.535125992343437,1.8904889885289908,-0.3756119566608087,-0.4304442186927949,-0.1467489683154696,1.087084295170027,-0.2438896675666794,0.2811899865404747 +2,0.8149735042940163,0.2130764345824376,1.42482746628562,0.2370355353748186,0.2935594041175298,1.5118702458799815,-0.0239743838488975,1.3474752102869063,1.456284548880901,0.5274074050914401,1.0829321669453351,0.8549739441841201,1.9550003461313663,1.1522550000669671,0.2013912093916699 +3,2.7442804054965437,0.8195183841461625,1.115007005037871,4.732680372580089,2.0475108774169515,-0.2814644639166429,0.1339840938605815,-0.2499393042673343,-0.5500212283270541,3.3942746991980925,3.8933974345995,1.9895882583898328,2.175786008218023,6.046041349536007,4.935010337204809 
+4,-0.0485198799348087,0.8284707803398315,1.1442047448413235,-0.3610922722145709,0.4993281342177852,1.2985752399803827,-1.4667703761231092,1.3385394587604047,1.2207242455900345,0.2205561656610641,-0.3133945108502493,0.6131787584083571,0.7292592566157908,-0.8683529835650433,-0.3970996192243676 +5,0.4455436486466032,0.1600251978792141,-0.0691235536577115,0.1341188073483042,0.4868458399286153,-0.1654982471168611,-0.3138363326353646,-0.1150094562171625,-0.2443202078622653,2.04851283483916,1.7216164423470517,1.263243196357085,0.905887786285116,1.7540693875058049,2.2418016084326413 +6,-0.6515680104091791,-0.3101413874031051,-0.2280890259209542,-0.8296660809941128,-0.6112178061762416,1.3689833001802505,0.3228828919461442,1.368325297182076,1.2752195396349364,0.5186402268220005,0.0212149800474625,0.5095522502187443,1.196715796344091,0.2624756637998691,-0.0147304787196769 +7,0.2689327040405782,-0.2325539537246406,0.4353485062797216,-0.6880042318282048,0.6116687828203147,0.1637629755825187,0.4010479118436185,0.0994485804188717,0.0288594274466947,1.4479611233825669,0.7247855065358073,-0.0210538519002913,0.6241957346573126,0.4776404851153671,1.7264345060132755 +8,0.533290225555293,0.1205683405811914,0.0752430487038047,0.1074815365649713,-0.01736319908543,-0.161356596516869,0.8228133317070733,-0.0316091086364824,-0.2483634070978549,1.6627569909838305,1.8183096792604587,1.2800345287026242,1.39161624287576,2.3898571677839318,1.2886495480441378 +9,2.6095866154647336,1.5098476017468596,0.4093949597877636,-0.3211363660395712,2.3773460477247146,-0.2441896085167131,2.443109056665133,-0.2862780271417735,-0.2974091717382667,2.3202953611917776,5.112877271198196,3.9954328451526617,1.62001520365506,2.370443800447194,6.846856039728261 
+10,-0.9059213043655148,-0.6924418618957103,-0.6821138798646442,-0.7194846427539621,-0.2847868979473433,0.604848764481689,1.3357712747842485,0.4926216475849352,0.4736113433615398,-0.6254765373398288,-0.6308282294015006,-0.6058719698777817,-0.2262097293109468,0.0764308934894756,0.0318188079504589 +11,0.8457127509817739,-0.1320881742179109,0.1660804614256576,-0.0559744432418457,0.1320460810425129,0.8595602763812097,0.2610022511939773,0.870901795540163,0.7355403373192945,0.3169951266248968,1.9506267402998576,0.596387426062818,1.0109513082435937,1.4418377295066225,1.1556515861294627 +12,3.2131936413334268,1.8901586548630944,4.720927870764283,2.941929304918738,3.4213197519098357,0.971384842580999,0.6941667364591471,1.323646539549569,0.7935514567864481,-1.2567133727394586,0.8653723838901689,0.4399881590729394,0.9454769394868612,0.4452848728874726,1.0171120424683429 +13,0.3231455572898961,0.6172605442151224,1.317769087006293,1.1221193964028424,-0.2999169516311856,0.1182048189826047,0.3228828919461442,0.1411487542092119,-0.0071777831313855,-0.8446559940758109,-0.3935481151337316,-0.1918456894720605,-0.041206571079714,-0.1484406114943913,-1.1679336401548384 +14,1.8936416153371445,0.766467147442939,0.7273259043142488,-0.1128813399153303,1.625760630979847,-0.2566145603166899,1.031253384767004,0.0458340712598631,-0.321492575880691,1.434810355978408,3.296698380489132,2.025089932491831,1.6169698841780027,1.1247527296732565,3.2780773950178173 +15,0.9457550265655666,0.5144737731026268,-0.1453620964778381,-0.2388029836183596,0.6320943552935019,0.2465959875823632,1.86501359700673,0.5015573991114366,0.11007499470071,1.553167262615839,2.5664099859062928,2.0649093777683944,0.8617306538677845,2.131012269960775,2.779335037837786 
+16,-0.7745249971602093,-0.3950233661282629,-0.1145422600186381,-0.7800238945342648,-0.646773432333271,0.5799988608817356,0.8472399004250337,0.4807073122162669,0.4525163908280294,0.6150791877858333,-0.427263520110117,0.0921677033439151,0.7048967007993319,0.2074711230124482,-0.0989625212656381 +17,-0.0261640641618941,-0.0004547931480373,0.1904119112618683,-0.4422148696001764,0.1312895783583206,0.971384842580999,0.9449461752968767,0.8798375470666647,0.7636669406973083,2.0397456565697203,1.0752984903469085,0.9893046029484328,1.411410819476633,1.3027085969266765,1.6765602702952729 +18,-0.3659724639101956,0.0668539634191777,0.55376156214928,-0.8454062864569916,-0.680059550437724,2.28842973337852,0.8472399004250337,2.369129468150238,2.667486406846625,0.8254914662523765,0.3863591773388819,1.2713989863534898,1.8910486371131627,-0.2147696165615749,-0.4320115842269701 +19,-0.6079741696519958,-0.2660425468935505,0.219609651065321,-0.089876424238815,-0.5654493937826185,-0.240047957916721,-1.0450049562596546,-0.225217058377347,-0.2977607542804917,0.509873048552561,-0.4896052123306031,-0.1592225294864416,0.2161229247316304,0.1233465312199229,-0.6292918944004043 +20,-0.3631779869385813,-0.4944944349468072,-0.86070672166243,-0.4555335049918428,-0.5181679760206112,-0.3663683012164831,-0.8447070927723773,-0.332743935079581,-0.4396243100683495,-0.0512263606915536,0.1484429233545774,-0.3990987058512859,-0.6361097309228636,0.4582271177786307,-0.1172497410289058 +21,-0.623623240693036,-0.3993337791103998,0.3915518965745426,-0.0329695275653308,-0.3127774972624515,-1.250610704314819,-1.6312426054907117,-1.25491349261453,-0.994421561699674,0.0013767089250812,-0.8871925351653367,-0.8804342413449825,-0.7969025993114908,-0.7292238509850971,-0.3444545926331425 
+22,1.5538332155888428,1.0798010142213534,1.73951421750061,1.9587718560066243,0.2266089165665275,0.3729163308821253,-1.074316838721208,0.5313432375331081,0.1763483039101552,0.2906935918165793,2.170094942504631,1.7190079314502889,1.898661935805806,2.8573957644770065,0.8597311208693111 +23,-0.7203121439108915,-0.4888577410470897,-0.2297111225767015,-1.1759506011774437,-0.6838420638586846,2.671532413877799,1.6142341581690005,2.404872474256244,3.048953465160939,0.3389130722984951,0.0364823332443163,0.2077880203517701,1.313960596210798,-0.1274094635462595,-0.4813316617703289 +24,-0.3721203132477471,-0.1486666856876682,-0.0804782302479432,-0.7097983624691137,-0.3759454713924933,2.1103387575788557,0.957973678613122,2.077228251617857,2.3457883807105917,2.1098830827252346,0.6586269760161075,0.9466066435554904,1.4449093337242631,1.1522550000669671,0.6480426981551198 +25,0.6992821576691837,0.1799194116429228,1.97471823258398,0.3072610674399696,1.380275509959505,1.238521306280496,-0.6965192425500825,1.3444966264447396,1.020322196521685,0.9701499076981248,0.8946348108508052,0.5426551625570927,2.1377195147548065,1.8851096170287776,1.2166089853403548 +26,0.2834639842929726,-0.1685608994513771,0.2796272273279739,-0.7267493529675983,-0.0317367500850804,0.2797291923823009,1.22666593451069,0.4509214737945954,0.0286836361755821,0.8824781250037318,2.608395207197641,1.3515176292593474,2.3676411352726348,2.2054301780849324,2.41359064257243 +27,0.0973518179834588,0.627207651096977,1.1863792578907555,0.2890992919058788,0.1596584290155251,1.043863728080862,0.2577453753649159,0.9721736461738462,0.918363259276385,0.0627469568111569,-0.2707731498423658,0.3473959549961096,0.5237001919144207,-0.905561937627122,-0.5395182701079989 
+28,0.2845817750816183,0.1281944558572799,-0.1567167730680698,-0.346562851787298,-0.3131557486045476,0.8284978968812677,1.7966192045964398,1.252160527337558,0.6828029559855183,1.3909744646312114,2.2693327382841804,1.7334005020321797,1.3368004922887282,1.822016173184383,0.8209400486441972 +29,0.2711682856178695,0.0721590870895,0.2828714206394686,-0.1564696011971481,-0.0200109584801025,0.7746564390813695,-1.0026655704818563,0.8232444540654887,0.6089706221182318,-0.3010909413705743,0.171343953149858,-0.1117270465662024,0.4719297608044463,-0.2341829838983112,-0.2635474991350485 +30,0.4617516150819662,0.6653382274774189,-0.0350595238870167,-0.0571852282774517,0.2893986393544729,1.4248955832801449,1.3569409676731468,1.5857619176602782,1.387725953146992,0.7334360944232642,1.0905658435437622,1.6364905267807826,1.068812378307683,0.8788500767412585,0.7688491802276164 +31,0.4841074308548806,0.3367521301468276,-0.2199785426422174,0.2648835911937579,0.7081228750548095,0.1140631683826125,0.3977910360145572,0.361563958529581,0.0142687519443501,1.3734401080923335,2.056225933244763,2.0313267130773163,0.608969137272026,3.009467141948111,3.1173715243709177 diff --git a/.ci/test_data/phe_sgd/bob/s_model.tar.gz b/.ci/test_data/phe_sgd/bob/s_model.tar.gz new file mode 100644 index 0000000..d2340a9 Binary files /dev/null and b/.ci/test_data/phe_sgd/bob/s_model.tar.gz differ diff --git a/.ci/test_data/phe_sgd/predict.csv b/.ci/test_data/phe_sgd/predict.csv new file mode 100644 index 0000000..a6b94c4 --- /dev/null +++ b/.ci/test_data/phe_sgd/predict.csv @@ -0,0 +1,33 @@ +id,pred,y +0,0.36531496,0.0 +1,0.36864355,0.0 +2,0.3082614,0.0 +3,0.40345556,0.0 +4,0.333171,0.0 +5,0.5010476,0.0 +6,0.41521746,0.0 +7,0.466192,0.0 +8,0.46786073,0.0 +9,0.40951777,0.0 +10,0.46939817,0.0 +11,0.425126,0.0 +12,0.15717828,0.0 +13,0.37976164,0.0 +14,0.4048596,0.0 +15,0.416929,0.0 +16,0.48516732,0.0 +17,0.42658806,0.0 +18,0.3497193,0.0 +19,0.50896066,1.0 +20,0.5335696,1.0 +21,0.58377063,1.0 +22,0.37492034,0.0 
+23,0.3530418,0.0 +24,0.42830193,0.0 +25,0.3619516,0.0 +26,0.44949424,0.0 +27,0.34697467,0.0 +28,0.42690903,0.0 +29,0.4074396,0.0 +30,0.3484857,0.0 +31,0.50262034,0.0 diff --git a/.ci/test_data/phe_sgd_no_feature/alice/alice.csv b/.ci/test_data/phe_sgd_no_feature/alice/alice.csv new file mode 100644 index 0000000..3e4ba32 --- /dev/null +++ b/.ci/test_data/phe_sgd_no_feature/alice/alice.csv @@ -0,0 +1,33 @@ +id,y +0,0 +1,0 +2,0 +3,0 +4,0 +5,0 +6,0 +7,0 +8,0 +9,0 +10,0 +11,0 +12,0 +13,0 +14,0 +15,0 +16,0 +17,0 +18,0 +19,1 +20,1 +21,1 +22,0 +23,0 +24,0 +25,0 +26,0 +27,0 +28,0 +29,0 +30,0 +31,0 diff --git a/.ci/test_data/phe_sgd_no_feature/alice/s_model.tar.gz b/.ci/test_data/phe_sgd_no_feature/alice/s_model.tar.gz new file mode 100644 index 0000000..cade321 Binary files /dev/null and b/.ci/test_data/phe_sgd_no_feature/alice/s_model.tar.gz differ diff --git a/.ci/test_data/phe_sgd_no_feature/bob/bob.csv b/.ci/test_data/phe_sgd_no_feature/bob/bob.csv new file mode 100644 index 0000000..1bf9fc2 --- /dev/null +++ b/.ci/test_data/phe_sgd_no_feature/bob/bob.csv @@ -0,0 +1,33 @@ +id,a0,a1,a2,a3,a4,a5,a6,a7,a8,a9,a10,a11,a12,a13,a14,b0,b1,b2,b3,b4,b5,b6,b7,b8,b9,b10,b11,b12,b13,b14 +0,1.0970639814699807,-2.073335014697593,1.2699336881399383,0.9843749048031144,1.568466329243428,3.2835146709868264,2.652873983743168,2.532475216403245,2.2175150059646405,2.255746885296269,2.4897339267376197,-0.5652650590684639,2.833030865855184,2.4875775569611043,-0.2140016466689538,1.3168615683959484,0.72402615808036,0.6608199414286064,1.1487566671861758,0.907083080997336,1.886689625179276,-1.3592934737640827,2.303600623622561,2.0012374893299207,1.307686271071539,2.616665023512603,2.109526346572256,2.296076127561788,2.750622244124955,1.9370146123781784 
+1,1.8298206075464456,-0.3536324082438112,1.6859547105508974,1.9087082542365936,-0.8269624468508425,-0.4870716725758942,-0.0238458551987692,0.5481441558908369,0.0013923632994608,-0.8686524574634664,0.4992546006760562,-0.8762436030602548,0.263326965842778,0.7424019483418791,-0.6053508469797809,-0.6929262695890712,-0.4407800577847933,0.2601620674590054,-0.8054503802819919,-0.0994437403202747,1.805927438479428,-0.3692032217294088,1.535125992343437,1.8904889885289908,-0.3756119566608087,-0.4304442186927949,-0.1467489683154696,1.087084295170027,-0.2438896675666794,0.2811899865404747 +2,1.5798881149312178,0.4561869517641946,1.5665031298586416,1.5588836327586924,0.9422104400684552,1.05292554434161,1.3634784515699176,2.037230755700812,0.939684816618985,-0.3980079103689868,1.2286759457296228,-0.7800833765050336,0.8509283007136554,1.181336055653447,-0.2970050119818975,0.8149735042940163,0.2130764345824376,1.42482746628562,0.2370355353748186,0.2935594041175298,1.5118702458799815,-0.0239743838488975,1.3474752102869063,1.456284548880901,0.5274074050914401,1.0829321669453351,0.8549739441841201,1.9550003461313663,1.1522550000669671,0.2013912093916699 +3,-0.7689092872596208,0.2537321117621929,-0.5926871666544732,-0.7644637923250287,3.283553480279431,3.402908991274548,1.9158971800569968,1.451707356849496,2.867382930831859,4.9109192850190375,0.3263734407153149,-0.1104090440232948,0.286593404544489,-0.2883781482770153,0.6897016600113287,2.7442804054965437,0.8195183841461625,1.115007005037871,4.732680372580089,2.0475108774169515,-0.2814644639166429,0.1339840938605815,-0.2499393042673343,-0.5500212283270541,3.3942746991980925,3.8933974345995,1.9895882583898328,2.175786008218023,6.046041349536007,4.935010337204809 
+4,1.7502966326234184,-1.1518164326195182,1.7765731510760563,1.826229278440991,0.2803718299176319,0.5393404523102987,1.3710114342311053,1.4284927727540695,-0.0095604668949302,-0.562449981040552,1.2705427819622863,-0.7902437023297363,1.2731894116191806,1.1903567566057145,1.483067159789666,-0.0485198799348087,0.8284707803398315,1.1442047448413235,-0.3610922722145709,0.4993281342177852,1.2985752399803827,-1.4667703761231092,1.3385394587604047,1.2207242455900345,0.2205561656610641,-0.3133945108502493,0.6131787584083571,0.7292592566157908,-0.8683529835650433,-0.3970996192243676 +5,-0.4763746652213425,-0.8353353034209873,-0.3871480674633165,-0.5056504544836544,2.237421483589421,1.2443354863901803,0.8663015959315467,0.8246556464496959,1.005401797785333,1.8900050384577884,-0.2550702935159049,-0.5926616519172156,-0.3213041853640514,-0.2892582166626024,0.1563467021771524,0.4455436486466032,0.1600251978792141,-0.0691235536577115,0.1341188073483042,0.4868458399286153,-0.1654982471168611,-0.3138363326353646,-0.1150094562171625,-0.2443202078622653,2.04851283483916,1.7216164423470517,1.263243196357085,0.905887786285116,1.7540693875058049,2.2418016084326413 +6,1.170907672469935,0.1606494267038018,1.13812504737607,1.095294906735132,-0.1231362259485147,0.0882952423344605,0.3000723992322905,0.646935108208041,-0.0643246178668869,-0.7623321531499545,0.1498830707345162,-0.8049398878976097,0.1554102927156918,0.2986274649095823,-0.909029826096615,-0.6515680104091791,-0.3101413874031051,-0.2280890259209542,-0.8296660809941128,-0.6112178061762416,1.3689833001802505,0.3228828919461442,1.368325297182076,1.2752195396349364,0.5186402268220005,0.0212149800474625,0.5095522502187443,1.196715796344091,0.2624756637998691,-0.0147304787196769 
+7,-0.1185167780677197,0.3584501324528832,-0.072866839641968,-0.2189649110285938,1.6040490502192788,1.140102349631058,0.061025749450609,0.2819502582632787,1.403354628181551,1.6603531811406034,0.6436230014783456,0.2905609572730053,0.4900509855317937,0.2337224214725335,0.588030871174189,0.2689327040405782,-0.2325539537246406,0.4353485062797216,-0.6880042318282048,0.6116687828203147,0.1637629755825187,0.4010479118436185,0.0994485804188717,0.0288594274466947,1.4479611233825669,0.7247855065358073,-0.0210538519002913,0.6241957346573126,0.4776404851153671,1.7264345060132755 +8,-0.3201668573368246,0.5888297779724025,-0.1840803802864819,-0.3842072728811636,2.20183876261357,1.6840098087195687,1.2190962838971586,1.150691583078798,1.965599911493639,1.5724617295747656,-0.3568500160815189,-0.3898180042026168,-0.2277433999465317,-0.3524031233284771,-0.4366773415647225,0.533290225555293,0.1205683405811914,0.0752430487038047,0.1074815365649713,-0.01736319908543,-0.161356596516869,0.8228133317070733,-0.0316091086364824,-0.2483634070978549,1.6627569909838305,1.8183096792604587,1.2800345287026242,1.39161624287576,2.3898571677839318,1.2886495480441378 +9,-0.4735345232598054,1.105438680046475,-0.329481787129124,-0.5090633776200244,1.5826994176337676,2.563358453378346,1.738872087519092,0.941760326219959,0.7972980240918982,2.783095594691288,-0.3882501432560169,0.6933453024665736,-0.4094196340641493,-0.3607637729915545,0.0360084898158162,2.6095866154647336,1.5098476017468596,0.4093949597877636,-0.3211363660395712,2.3773460477247146,-0.2441896085167131,2.443109056665133,-0.2862780271417735,-0.2974091717382667,2.3202953611917776,5.112877271198196,3.9954328451526617,1.62001520365506,2.370443800447194,6.846856039728261 
+10,0.537556015047254,0.9192733099296918,0.4420106633418918,0.4064532537111665,-1.0176858312814026,-0.7135418515343508,-0.700684347306461,-0.404685551314783,-1.0354755617695854,-0.8261243357380614,-0.0926558426133296,-0.0541643832079764,-0.1980415633060492,0.0038045557379028,-1.004033677960828,-0.9059213043655148,-0.6924418618957103,-0.6821138798646442,-0.7194846427539621,-0.2847868979473433,0.604848764481689,1.3357712747842485,0.4926216475849352,0.4736113433615398,-0.6254765373398288,-0.6308282294015006,-0.6058719698777817,-0.2262097293109468,0.0764308934894756,0.0318188079504589 +11,0.4693926079703736,-0.3257076027262936,0.4790818435567294,0.3586723298019898,0.0526424156721877,0.4711151264316007,0.1348489795302461,0.442130888521722,0.1109206652433743,-0.2803467735953673,0.3631873829198987,-0.4208432848459051,0.3455020472147793,0.3041278923195017,-0.4233434676188681,0.8457127509817739,-0.1320881742179109,0.1660804614256576,-0.0559744432418457,0.1320460810425129,0.8595602763812097,0.2610022511939773,0.870901795540163,0.7355403373192945,0.3169951266248968,1.9506267402998576,0.596387426062818,1.0109513082435937,1.4418377295066225,1.1556515861294627 +12,1.4322007329313104,1.2822957816574192,1.6653596104315431,1.3313554236673744,0.0739920482576977,2.6808576257249923,1.4777286885979273,1.6219476402159576,2.1371942512057704,2.155096997212811,1.986249128939636,4.265788436187907,4.061201810939133,1.669113958365099,-1.3007123732560884,3.2131936413334268,1.8901586548630944,4.720927870764283,2.941929304918738,3.4213197519098357,0.971384842580999,0.6941667364591471,1.323646539549569,0.7935514567864481,-1.2567133727394586,0.8653723838901689,0.4399881590729394,0.9454769394868612,0.4452848728874726,1.0171120424683429 
+13,0.4892736017011304,1.084495075908337,0.4832008635806006,0.3635073042451804,-0.8789132194755842,-0.078477776480135,0.1328401841539294,0.1217696280048353,0.1291753822340265,-1.3350441923854053,-0.0067566441359675,-0.2519278680102239,0.0182868135587977,-0.08266216314603,0.90937723326928,0.3231455572898961,0.6172605442151224,1.317769087006293,1.1221193964028424,-0.2999169516311856,0.1182048189826047,0.3228828919461442,0.1411487542092119,-0.0071777831313855,-0.8446559940758109,-0.3935481151337316,-0.1918456894720605,-0.041206571079714,-0.1484406114943913,-1.1679336401548384 +14,-0.1128364941446465,0.7726680809627255,0.0671798411696414,-0.2178272699831373,1.1912894868994104,2.3681582154476257,1.5568250065403952,0.8081474977596147,0.939684816618985,1.9878197184262187,-0.6968375999709101,-0.0868225733588063,-0.3985289606293058,-0.4648318595872295,-0.2040012412095631,1.8936416153371445,0.766467147442939,0.7273259043142488,-0.1128813399153303,1.625760630979847,-0.2566145603166899,1.031253384767004,0.0458340712598631,-0.321492575880691,1.434810355978408,3.296698380489132,2.025089932491831,1.6169698841780027,1.1247527296732565,3.2780773950178173 +15,0.1172150047398245,1.9199121743074004,0.1961051679168003,0.0111229904150014,1.2482218404607712,1.0453449525773104,0.9428869196536188,0.6376492745698704,1.794005571781509,1.1301692636305567,-0.1269433378038733,-0.3335733433872988,0.0064060789026048,-0.1713290529939307,-0.4780123507968712,0.9457550265655666,0.5144737731026268,-0.1453620964778381,-0.2388029836183596,0.6320943552935019,0.2465959875823632,1.86501359700673,0.5015573991114366,0.11007499470071,1.553167262615839,2.5664099859062928,2.0649093777683944,0.8617306538677845,2.131012269960775,2.779335037837786 
+16,0.1569769922013383,0.1955554336006982,0.1141366694417695,0.0842164275855874,0.1643721595363581,-0.612909495863271,-0.1864327309693982,0.0946859465601709,-0.8237208446780198,-0.5071634227975256,0.2437225312560043,0.0419958433472446,0.1628357518758122,0.1113929158759261,-0.4410108505971252,-0.7745249971602093,-0.3950233661282629,-0.1145422600186381,-0.7800238945342648,-0.646773432333271,0.5799988608817356,0.8472399004250337,0.4807073122162669,0.4525163908280294,0.6150791877858333,-0.427263520110117,0.0921677033439151,0.7048967007993319,0.2074711230124482,-0.0989625212656381 +17,0.5687975766241574,0.3235441255559867,0.664437744630919,0.4092973563248079,1.468834710511046,1.8545731234163136,1.0470931798000451,1.389801799261692,1.2865244394413775,1.5256807956768208,0.5920112981915273,-0.2609995874965657,0.4890609243104445,0.3045679265122953,-0.0049931725676859,-0.0261640641618941,-0.0004547931480373,0.1904119112618683,-0.4422148696001764,0.1312895783583206,0.971384842580999,0.9449461752968767,0.8798375470666647,0.7636669406973083,2.0397456565697203,1.0752984903469085,0.9893046029484328,1.411410819476633,1.3027085969266765,1.6765602702952729 +18,1.6139698184696574,0.6656229931455752,1.5665031298586416,1.7209974817362566,0.1387526004337456,-0.0310990779532612,0.7420073820219539,1.188092857454763,-0.8383246182705409,-1.254240761107136,1.274151991982344,-0.3626028457435921,1.484567482377281,1.585507461734324,-0.1823336960475497,-0.3659724639101956,0.0668539634191777,0.55376156214928,-0.8454062864569916,-0.680059550437724,2.28842973337852,0.8472399004250337,2.369129468150238,2.667486406846625,0.8254914662523765,0.3863591773388819,1.2713989863534898,1.8910486371131627,-0.2147696165615749,-0.4320115842269701 
+19,-0.1667991914138439,-1.1471622983665986,-0.1857279882960305,-0.251956501346835,0.1017465706188614,-0.4368502521374081,-0.2782095697248652,-0.0286092889688724,0.2679112313629833,-0.7283096557696301,-0.4882252608116023,-0.7769989918796776,-0.40001405246133,-0.3691244226546321,0.4736929020884875,-0.6079741696519958,-0.2660425468935505,0.219609651065321,-0.089876424238815,-0.5654493937826185,-0.240047957916721,-1.0450049562596546,-0.225217058377347,-0.2977607542804917,0.509873048552561,-0.4896052123306031,-0.1592225294864416,0.2161229247316304,0.1233465312199229,-0.6292918944004043 +20,-0.2974457216445311,-0.833008236294527,-0.2611060547328675,-0.3836384523584352,0.792763011969882,0.4294218717279516,-0.5413617640223466,-0.4596267336739594,0.5672885900096801,0.7530865843319685,-0.7939253495104497,-0.8512056572779519,-0.7341597146667573,-0.5647196213513657,-0.9813660922528752,-0.3631779869385813,-0.4944944349468072,-0.86070672166243,-0.4555335049918428,-0.5181679760206112,-0.3663683012164831,-0.8447070927723773,-0.332743935079581,-0.4396243100683495,-0.0512263606915536,0.1484429233545774,-0.3990987058512859,-0.6361097309228636,0.4582271177786307,-0.1172497410289058 +21,-1.31308048709005,-1.5939591866468776,-1.3028062187698126,-1.0835721055756025,0.4298192580162052,-0.7470859700913773,-0.7437478981862491,-0.7263365109480825,0.012345193493852,0.8863413657382362,-0.4615171066631787,-0.4355394704137786,-0.4737736134518613,-0.5420578604224978,0.8550416969399234,-0.623623240693036,-0.3993337791103998,0.3915518965745426,-0.0329695275653308,-0.3127774972624515,-1.250610704314819,-1.6312426054907117,-1.25491349261453,-0.994421561699674,0.0013767089250812,-0.8871925351653367,-0.8804342413449825,-0.7969025993114908,-0.7292238509850971,-0.3444545926331425 
+22,0.3444263616627596,-1.1704329696311964,0.4337726232941499,0.1408140695970527,0.7785299235795423,2.0687248407577834,1.492794653920302,1.2546413318616525,2.5899112325739453,1.0663770810424489,0.1213703115760642,-0.9203321597638752,0.2563965372933319,0.1006120781524841,-0.0839963756968734,1.5538332155888428,1.0798010142213534,1.73951421750061,1.9587718560066243,0.2266089165665275,0.3729163308821253,-1.074316838721208,0.5313432375331081,0.1763483039101552,0.2906935918165793,2.170094942504631,1.7190079314502889,1.898661935805806,2.8573957644770065,0.8597311208693111 +23,1.99738898327711,0.8727319674004963,1.8630725715773444,2.130548258100629,-0.1480441306316104,-0.040574817658636,0.2624074859263535,0.9647169704921028,-0.1555982028201481,-1.420100435836214,1.034139525648538,-0.163025017044076,0.711329668503388,1.1804559872678595,-0.7710242307570218,-0.7203121439108915,-0.4888577410470897,-0.2297111225767015,-1.1759506011774437,-0.6838420638586846,2.671532413877799,1.6142341581690005,2.404872474256244,3.048953465160939,0.3389130722984951,0.0364823332443163,0.2077880203517701,1.313960596210798,-0.1274094635462595,-0.4813316617703289 +24,0.7164849586240649,0.4864388244081716,0.7426991250844659,0.7102034128480763,1.120124044947709,0.7838145367089672,0.7997602490910576,1.1034885954180975,0.6695150051573325,0.0712190326679779,1.449559598957125,-0.5717966970986299,1.2816049320006508,1.3698907072654836,-0.3310063905438262,-0.3721203132477471,-0.1486666856876682,-0.0804782302479432,-0.7097983624691137,-0.3759454713924933,2.1103387575788557,0.957973678613122,2.077228251617857,2.3457883807105917,2.1098830827252346,0.6586269760161075,0.9466066435554904,1.4449093337242631,1.1522550000669671,0.6480426981551198 
+25,0.8556519147393631,-0.6724406045688023,0.9898403265167186,0.7332406440185723,1.5826994176337676,2.3359407004493518,1.683630214670384,2.3519173401054827,4.484750856203646,1.60648422695509,2.312882635754815,-0.4369909455315933,2.183055674039293,1.5635057520946465,0.3293537166246132,0.6992821576691837,0.1799194116429228,1.97471823258398,0.3072610674399696,1.380275509959505,1.238521306280496,-0.6965192425500825,1.3444966264447396,1.020322196521685,0.9701499076981248,0.8946348108508052,0.5426551625570927,2.1377195147548065,1.8851096170287776,1.2166089853403548 +26,0.1285755725859715,0.5213448313050689,0.224114504079122,-0.0286944461759792,0.6433155838713087,1.5627203404907717,0.6742105380712667,1.0036658838077632,1.607807458476857,0.913275842830992,-0.5438070951204835,-0.4239276694712613,-0.3742724607062452,-0.4243487138502227,-0.8630279609834172,0.2834639842929726,-0.1685608994513771,0.2796272273279739,-0.7267493529675983,-0.0317367500850804,0.2797291923823009,1.22666593451069,0.4509214737945954,0.0286836361755821,0.8824781250037318,2.608395207197641,1.3515176292593474,2.3676411352726348,2.2054301780849324,2.41359064257243 +27,1.2731527830852554,0.2234802391182159,1.241100547972842,1.2488764478717715,-0.1395042775974066,0.0428116917486617,0.7558178502341308,0.7323131897145546,-0.4184661274855406,-0.8232891276230346,1.6159441808817636,1.146931276783655,1.369225350090074,1.170555217930005,1.2363904917913593,0.0973518179834588,0.627207651096977,1.1863792578907555,0.2890992919058788,0.1596584290155251,1.043863728080862,0.2577453753649159,0.9721736461738462,0.918363259276385,0.0627469568111569,-0.2707731498423658,0.3473959549961096,0.5237001919144207,-0.905561937627122,-0.5395182701079989 
+28,0.333065793816613,1.3916679366010287,0.4296536032702793,0.220448942779014,0.8425788213360734,1.238650042566955,0.998128792502327,0.9954118094627226,0.4175999106863312,0.3689158847458118,0.1220921535800756,-0.3716745652299335,0.3128300269102487,0.0695896675605388,-0.60268407219061,0.2845817750816183,0.1281944558572799,-0.1567167730680698,-0.346562851787298,-0.3131557486045476,0.8284978968812677,1.7966192045964398,1.252160527337558,0.6828029559855183,1.3909744646312114,2.2693327382841804,1.7334005020321797,1.3368004922887282,1.822016173184383,0.8209400486441972 +29,0.9777780190854404,-0.986594666640873,0.9486501262780098,0.8538305948369708,0.1501390711460184,0.2152701543864822,0.1249305523596827,0.7895758304832735,-0.2651265047640616,-0.1853673017419629,0.7042577298153069,-0.7154927337622812,0.8855804434608852,0.4568197572188642,-0.4713454138239439,0.2711682856178695,0.0721590870895,0.2828714206394686,-0.1564696011971481,-0.0200109584801025,0.7746564390813695,-1.0026655704818563,0.8232444540654887,0.6089706221182318,-0.3010909413705743,0.171343953149858,-0.1117270465662024,0.4719297608044463,-0.2341829838983112,-0.2635474991350485 +30,1.2788330670083286,1.354434862577672,1.352314088617356,1.2318118321899227,0.7144810258230102,1.598728151371196,1.7966249545881956,1.9469518175519296,1.355892364005855,-0.1173223069813161,1.535819718436493,0.4520375641298862,1.340513574670941,1.4226948104007096,-0.2643370208145543,0.4617516150819662,0.6653382274774189,-0.0350595238870167,-0.0571852282774517,0.2893986393544729,1.4248955832801449,1.3569409676731468,1.5857619176602782,1.387725953146992,0.7334360944232642,1.0905658435437622,1.6364905267807826,1.068812378307683,0.8788500767412585,0.7688491802276164 
+31,-0.6496233248750801,-0.1372151654830511,-0.5782705965709247,-0.6094601998815683,1.0347255146056669,0.8956282652323896,0.4143226362603002,0.074824580167417,1.786703684985248,2.153679393155298,0.2790927894525651,-0.3390163750791036,0.3014443228647304,0.0145853934613448,-0.4966797743210672,0.4841074308548806,0.3367521301468276,-0.2199785426422174,0.2648835911937579,0.7081228750548095,0.1140631683826125,0.3977910360145572,0.361563958529581,0.0142687519443501,1.3734401080923335,2.056225933244763,2.0313267130773163,0.608969137272026,3.009467141948111,3.1173715243709177 diff --git a/.ci/test_data/phe_sgd_no_feature/bob/s_model.tar.gz b/.ci/test_data/phe_sgd_no_feature/bob/s_model.tar.gz new file mode 100644 index 0000000..a82a514 Binary files /dev/null and b/.ci/test_data/phe_sgd_no_feature/bob/s_model.tar.gz differ diff --git a/.ci/test_data/phe_sgd_no_feature/predict.csv b/.ci/test_data/phe_sgd_no_feature/predict.csv new file mode 100644 index 0000000..52617ed --- /dev/null +++ b/.ci/test_data/phe_sgd_no_feature/predict.csv @@ -0,0 +1,33 @@ +id,pred,y +0,0.36531496,0.0 +1,0.36864352,0.0 +2,0.3082614,0.0 +3,0.4034556,0.0 +4,0.33317098,0.0 +5,0.50104755,0.0 +6,0.41521746,0.0 +7,0.466192,0.0 +8,0.46786073,0.0 +9,0.40951777,0.0 +10,0.46939817,0.0 +11,0.425126,0.0 +12,0.1571783,0.0 +13,0.37976164,0.0 +14,0.4048596,0.0 +15,0.416929,0.0 +16,0.4851673,0.0 +17,0.42658806,0.0 +18,0.34971926,0.0 +19,0.50896066,1.0 +20,0.5335696,1.0 +21,0.58377063,1.0 +22,0.37492037,0.0 +23,0.3530418,0.0 +24,0.42830193,0.0 +25,0.36195156,0.0 +26,0.44949424,0.0 +27,0.34697467,0.0 +28,0.42690903,0.0 +29,0.4074396,0.0 +30,0.3484857,0.0 +31,0.5026204,0.0 diff --git a/.circleci/continue-config.yml b/.circleci/continue-config.yml index de1afc4..973187e 100644 --- a/.circleci/continue-config.yml +++ b/.circleci/continue-config.yml @@ -51,13 +51,13 @@ commands: name: "build" command: | sh ./build_wheel_entrypoint.sh - bazel build <> -c opt + bazel build <> -c opt --jobs 16 - run: name: "test" command: 
| set +e declare -i test_status - bazel test <> -c opt --ui_event_filters=-info,-debug,-warning --test_output=errors | tee test_result.log; test_status=${PIPESTATUS[0]} + bazel test <> -c opt --jobs 16 --ui_event_filters=-info,-debug,-warning --test_output=errors | tee test_result.log; test_status=${PIPESTATUS[0]} sh ../devtools/rename-junit-xml.sh find bazel-testlogs/ -type f -name "test.log" -print0 | xargs -0 tar -cvzf test_logs.tar.gz @@ -109,6 +109,15 @@ jobs: exit ${test_status} - store_artifacts: path: accuracy_test.py.log + - run: + name: "inferencer test" + command: | + set +e + declare -i test_status + python3 .ci/inferencer_test.py 2>&1 | tee inferencer_test.py.log; test_status=${PIPESTATUS[0]} + exit ${test_status} + - store_artifacts: + path: inferencer_test.py.log macOS_lib_ut: macos: diff --git a/WORKSPACE b/WORKSPACE index 02fd983..4efd167 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -18,10 +18,20 @@ load("//bazel:repositories.bzl", "sf_serving_deps") sf_serving_deps() +load("@com_alipay_sf_heu//third_party/bazel_cpp:repositories.bzl", "heu_cpp_deps") + +heu_cpp_deps() + load("@yacl//bazel:repositories.bzl", "yacl_deps") yacl_deps() +load("@rules_cuda//cuda:repositories.bzl", "register_detected_cuda_toolchains", "rules_cuda_dependencies") + +rules_cuda_dependencies() + +register_detected_cuda_toolchains() + load("@dataproxy//dataproxy_sdk/bazel:repositories.bzl", "dataproxy_deps") dataproxy_deps() diff --git a/bazel/curl.BUILD b/bazel/curl.BUILD deleted file mode 100644 index 6e9fb6a..0000000 --- a/bazel/curl.BUILD +++ /dev/null @@ -1,671 +0,0 @@ -# Copyright 2023 Ant Group Co., Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -package( - default_visibility = ["//visibility:public"], -) - -CURL_WIN_COPTS = [ - "/Iexternal/curl/lib", - "/DBUILDING_LIBCURL", - "/DHAVE_CONFIG_H", - "/DCURL_DISABLE_FTP", - "/DCURL_DISABLE_NTLM", - "/DCURL_DISABLE_PROXY", - "/DHAVE_LIBZ", - "/DHAVE_ZLIB_H", - # Defining _USING_V110_SDK71_ is hackery to defeat curl's incorrect - # detection of what OS releases we can build on with VC 2012. This - # may not be needed (or may have to change) if the WINVER setting - # changes in //third_party/msvc/vc_12_0/CROSSTOOL. - "/D_USING_V110_SDK71_", -] - -CURL_WIN_SRCS = [ - "lib/asyn-thread.c", - "lib/inet_ntop.c", - "lib/system_win32.c", - "lib/vtls/schannel.c", - "lib/idn_win32.c", -] - -cc_library( - name = "curl", - srcs = [ - "include/curl_config.h", - "lib/amigaos.h", - "lib/arpa_telnet.h", - "lib/asyn.h", - "lib/asyn-ares.c", - "lib/base64.c", - "lib/config-win32.h", - "lib/conncache.c", - "lib/conncache.h", - "lib/connect.c", - "lib/connect.h", - "lib/content_encoding.c", - "lib/content_encoding.h", - "lib/cookie.c", - "lib/cookie.h", - "lib/curl_addrinfo.c", - "lib/curl_addrinfo.h", - "lib/curl_base64.h", - "lib/curl_ctype.c", - "lib/curl_ctype.h", - "lib/curl_des.h", - "lib/curl_endian.h", - "lib/curl_fnmatch.c", - "lib/curl_fnmatch.h", - "lib/curl_gethostname.c", - "lib/curl_gethostname.h", - "lib/curl_gssapi.h", - "lib/curl_hmac.h", - "lib/curl_ldap.h", - "lib/curl_md4.h", - "lib/curl_md5.h", - "lib/curl_memory.h", - "lib/curl_memrchr.c", - "lib/curl_memrchr.h", - "lib/curl_multibyte.c", - "lib/curl_multibyte.h", - "lib/curl_ntlm_core.h", - 
"lib/curl_ntlm_wb.h", - "lib/curl_printf.h", - "lib/curl_rtmp.c", - "lib/curl_rtmp.h", - "lib/curl_sasl.c", - "lib/curl_sasl.h", - "lib/curl_sec.h", - "lib/curl_setup.h", - "lib/curl_setup_once.h", - "lib/curl_sha256.h", - "lib/curl_sspi.c", - "lib/curl_sspi.h", - "lib/curl_threads.c", - "lib/curl_threads.h", - "lib/curlx.h", - "lib/dict.h", - "lib/dotdot.c", - "lib/dotdot.h", - "lib/easy.c", - "lib/easyif.h", - "lib/escape.c", - "lib/escape.h", - "lib/file.h", - "lib/fileinfo.c", - "lib/fileinfo.h", - "lib/formdata.c", - "lib/formdata.h", - "lib/ftp.h", - "lib/ftplistparser.h", - "lib/getenv.c", - "lib/getinfo.c", - "lib/getinfo.h", - "lib/gopher.h", - "lib/hash.c", - "lib/hash.h", - "lib/hmac.c", - "lib/hostasyn.c", - "lib/hostcheck.c", - "lib/hostcheck.h", - "lib/hostip.c", - "lib/hostip.h", - "lib/hostip4.c", - "lib/hostip6.c", - "lib/hostsyn.c", - "lib/http.c", - "lib/http.h", - "lib/http2.c", - "lib/http2.h", - "lib/http_chunks.c", - "lib/http_chunks.h", - "lib/http_digest.c", - "lib/http_digest.h", - "lib/http_negotiate.h", - "lib/http_ntlm.h", - "lib/http_proxy.c", - "lib/http_proxy.h", - "lib/if2ip.c", - "lib/if2ip.h", - "lib/imap.h", - "lib/inet_ntop.h", - "lib/inet_pton.c", - "lib/inet_pton.h", - "lib/krb5.c", - "lib/llist.c", - "lib/llist.h", - "lib/md4.c", - "lib/md5.c", - "lib/memdebug.c", - "lib/memdebug.h", - "lib/mime.c", - "lib/mime.h", - "lib/mprintf.c", - "lib/multi.c", - "lib/multihandle.h", - "lib/multiif.h", - "lib/netrc.c", - "lib/netrc.h", - "lib/non-ascii.h", - "lib/nonblock.c", - "lib/nonblock.h", - "lib/nwlib.c", - "lib/nwos.c", - "lib/parsedate.c", - "lib/parsedate.h", - "lib/pingpong.h", - "lib/pipeline.c", - "lib/pipeline.h", - "lib/pop3.h", - "lib/progress.c", - "lib/progress.h", - "lib/rand.c", - "lib/rand.h", - "lib/rtsp.c", - "lib/rtsp.h", - "lib/security.c", - "lib/select.c", - "lib/select.h", - "lib/sendf.c", - "lib/sendf.h", - "lib/setopt.c", - "lib/setopt.h", - "lib/setup-os400.h", - "lib/setup-vms.h", - "lib/sha256.c", - 
"lib/share.c", - "lib/share.h", - "lib/sigpipe.h", - "lib/slist.c", - "lib/slist.h", - "lib/smb.h", - "lib/smtp.h", - "lib/sockaddr.h", - "lib/socks.c", - "lib/socks.h", - "lib/speedcheck.c", - "lib/speedcheck.h", - "lib/splay.c", - "lib/splay.h", - "lib/ssh.h", - "lib/strcase.c", - "lib/strcase.h", - "lib/strdup.c", - "lib/strdup.h", - "lib/strerror.c", - "lib/strerror.h", - "lib/strtok.c", - "lib/strtok.h", - "lib/strtoofft.c", - "lib/strtoofft.h", - "lib/system_win32.h", - "lib/telnet.h", - "lib/tftp.h", - "lib/timeval.c", - "lib/timeval.h", - "lib/transfer.c", - "lib/transfer.h", - "lib/url.c", - "lib/url.h", - "lib/urldata.h", - "lib/vauth/cleartext.c", - "lib/vauth/cram.c", - "lib/vauth/digest.c", - "lib/vauth/digest.h", - "lib/vauth/ntlm.h", - "lib/vauth/oauth2.c", - "lib/vauth/vauth.c", - "lib/vauth/vauth.h", - "lib/version.c", - "lib/vtls/axtls.h", - "lib/vtls/cyassl.h", - "lib/vtls/darwinssl.h", - "lib/vtls/gskit.h", - "lib/vtls/gtls.h", - "lib/vtls/mbedtls.h", - "lib/vtls/nssg.h", - "lib/vtls/openssl.h", - "lib/vtls/polarssl.h", - "lib/vtls/polarssl_threadlock.h", - "lib/vtls/schannel.h", - "lib/vtls/vtls.c", - "lib/vtls/vtls.h", - "lib/warnless.c", - "lib/warnless.h", - "lib/wildcard.c", - "lib/wildcard.h", - "lib/x509asn1.h", - ] + select({ - "//conditions:default": [ - "lib/vtls/openssl.c", - ], - }), - hdrs = [ - "include/curl/curl.h", - "include/curl/curlver.h", - "include/curl/easy.h", - "include/curl/mprintf.h", - "include/curl/multi.h", - "include/curl/stdcheaders.h", - "include/curl/system.h", - "include/curl/typecheck-gcc.h", - ], - copts = select({ - "//conditions:default": [ - "-Iexternal/curl/lib", - "-D_GNU_SOURCE", - "-DBUILDING_LIBCURL", - "-DHAVE_CONFIG_H", - "-DCURL_DISABLE_FTP", - "-DCURL_DISABLE_NTLM", # turning it off in configure is not enough - "-DHAVE_LIBZ", - "-DHAVE_ZLIB_H", - "-Wno-string-plus-int", - ], - }) + select({ - "//conditions:default": [ - "-DCURL_MAX_WRITE_SIZE=65536", - ], - }), - defines = ["CURL_STATICLIB"], - 
includes = [ - "include", - "lib", - ], - linkopts = select({ - "//conditions:default": [ - "-lrt", - ], - }), - visibility = ["//visibility:public"], - deps = [ - "@zlib//:zlib", - ] + select({ - "//conditions:default": [ - "@com_github_openssl_openssl//:openssl", - ], - }), -) - -CURL_BIN_WIN_COPTS = [ - "/Iexternal/curl/lib", - "/DHAVE_CONFIG_H", - "/DCURL_DISABLE_LIBCURL_OPTION", -] - -cc_binary( - name = "curl_bin", - srcs = [ - "lib/config-win32.h", - "src/slist_wc.c", - "src/slist_wc.h", - "src/tool_binmode.c", - "src/tool_binmode.h", - "src/tool_bname.c", - "src/tool_bname.h", - "src/tool_cb_dbg.c", - "src/tool_cb_dbg.h", - "src/tool_cb_hdr.c", - "src/tool_cb_hdr.h", - "src/tool_cb_prg.c", - "src/tool_cb_prg.h", - "src/tool_cb_rea.c", - "src/tool_cb_rea.h", - "src/tool_cb_see.c", - "src/tool_cb_see.h", - "src/tool_cb_wrt.c", - "src/tool_cb_wrt.h", - "src/tool_cfgable.c", - "src/tool_cfgable.h", - "src/tool_convert.c", - "src/tool_convert.h", - "src/tool_dirhie.c", - "src/tool_dirhie.h", - "src/tool_doswin.c", - "src/tool_doswin.h", - "src/tool_easysrc.c", - "src/tool_easysrc.h", - "src/tool_formparse.c", - "src/tool_formparse.h", - "src/tool_getparam.c", - "src/tool_getparam.h", - "src/tool_getpass.c", - "src/tool_getpass.h", - "src/tool_help.c", - "src/tool_help.h", - "src/tool_helpers.c", - "src/tool_helpers.h", - "src/tool_homedir.c", - "src/tool_homedir.h", - "src/tool_hugehelp.c", - "src/tool_hugehelp.h", - "src/tool_libinfo.c", - "src/tool_libinfo.h", - "src/tool_main.c", - "src/tool_main.h", - "src/tool_metalink.c", - "src/tool_metalink.h", - "src/tool_mfiles.c", - "src/tool_mfiles.h", - "src/tool_msgs.c", - "src/tool_msgs.h", - "src/tool_operate.c", - "src/tool_operate.h", - "src/tool_operhlp.c", - "src/tool_operhlp.h", - "src/tool_panykey.c", - "src/tool_panykey.h", - "src/tool_paramhlp.c", - "src/tool_paramhlp.h", - "src/tool_parsecfg.c", - "src/tool_parsecfg.h", - "src/tool_sdecls.h", - "src/tool_setopt.c", - "src/tool_setopt.h", - 
"src/tool_setup.h", - "src/tool_sleep.c", - "src/tool_sleep.h", - "src/tool_strdup.c", - "src/tool_strdup.h", - "src/tool_urlglob.c", - "src/tool_urlglob.h", - "src/tool_util.c", - "src/tool_util.h", - "src/tool_version.h", - "src/tool_vms.c", - "src/tool_vms.h", - "src/tool_writeenv.c", - "src/tool_writeenv.h", - "src/tool_writeout.c", - "src/tool_writeout.h", - "src/tool_xattr.c", - "src/tool_xattr.h", - ], - copts = select({ - "//conditions:default": [ - "-Iexternal/curl/lib", - "-D_GNU_SOURCE", - "-DHAVE_CONFIG_H", - "-DCURL_DISABLE_LIBCURL_OPTION", - "-Wno-string-plus-int", - ], - }), - deps = [":curl"], -) - -genrule( - name = "configure", - outs = ["include/curl_config.h"], - cmd = "\n".join([ - "cat <<'EOF' >$@", - "#ifndef EXTERNAL_CURL_INCLUDE_CURL_CONFIG_H_", - "#define EXTERNAL_CURL_INCLUDE_CURL_CONFIG_H_", - "", - "#if !defined(_WIN32) && !defined(__APPLE__)", - "# include ", - "# if defined(OPENSSL_IS_BORINGSSL)", - "# define HAVE_BORINGSSL 1", - "# endif", - "#endif", - "", - "#if defined(_WIN32)", - "# include \"lib/config-win32.h\"", - "# define BUILDING_LIBCURL 1", - "# define CURL_DISABLE_CRYPTO_AUTH 1", - "# define CURL_DISABLE_DICT 1", - "# define CURL_DISABLE_FILE 1", - "# define CURL_DISABLE_GOPHER 1", - "# define CURL_DISABLE_IMAP 1", - "# define CURL_DISABLE_LDAP 1", - "# define CURL_DISABLE_LDAPS 1", - "# define CURL_DISABLE_POP3 1", - "# define CURL_PULL_WS2TCPIP_H 1", - "# define CURL_DISABLE_SMTP 1", - "# define CURL_DISABLE_TELNET 1", - "# define CURL_DISABLE_TFTP 1", - "# define CURL_PULL_WS2TCPIP_H 1", - "# define USE_WINDOWS_SSPI 1", - "# define USE_WIN32_IDN 1", - "# define USE_SCHANNEL 1", - "# define WANT_IDN_PROTOTYPES 1", - "#elif defined(__APPLE__)", - "# define HAVE_FSETXATTR_6 1", - "# define HAVE_SETMODE 1", - "# define HAVE_SYS_FILIO_H 1", - "# define HAVE_SYS_SOCKIO_H 1", - "# define OS \"x86_64-apple-darwin15.5.0\"", - "# define USE_DARWINSSL 1", - "#else", - "# define CURL_CA_BUNDLE 
\"/etc/ssl/certs/ca-certificates.crt\"", - "# define GETSERVBYPORT_R_ARGS 6", - "# define GETSERVBYPORT_R_BUFSIZE 4096", - "# define HAVE_BORINGSSL 1", - "# define HAVE_CLOCK_GETTIME_MONOTONIC 1", - "# define HAVE_CRYPTO_CLEANUP_ALL_EX_DATA 1", - "# define HAVE_FSETXATTR_5 1", - "# define HAVE_GETHOSTBYADDR_R 1", - "# define HAVE_GETHOSTBYADDR_R_8 1", - "# define HAVE_GETHOSTBYNAME_R 1", - "# define HAVE_GETHOSTBYNAME_R_6 1", - "# define HAVE_GETSERVBYPORT_R 1", - "# define HAVE_LIBSSL 1", - "# define HAVE_MALLOC_H 1", - "# define HAVE_MSG_NOSIGNAL 1", - "# define HAVE_OPENSSL_CRYPTO_H 1", - "# define HAVE_OPENSSL_ERR_H 1", - "# define HAVE_OPENSSL_PEM_H 1", - "# define HAVE_OPENSSL_PKCS12_H 1", - "# define HAVE_OPENSSL_RSA_H 1", - "# define HAVE_OPENSSL_SSL_H 1", - "# define HAVE_OPENSSL_X509_H 1", - "# define HAVE_RAND_add_egd 1", - "# define HAVE_RAND_STATUS 1", - "# define HAVE_SSL_GET_SHUTDOWN 1", - "# define HAVE_TERMIOS_H 1", - "# define OS \"x86_64-pc-linux-gnu\"", - "# define RANDOM_FILE \"/dev/urandom\"", - "# define USE_OPENSSL 1", - "#endif", - "", - "#if !defined(_WIN32)", - "# define CURL_DISABLE_DICT 1", - "# define CURL_DISABLE_FILE 1", - "# define CURL_DISABLE_GOPHER 1", - "# define CURL_DISABLE_IMAP 1", - "# define CURL_DISABLE_LDAP 1", - "# define CURL_DISABLE_LDAPS 1", - "# define CURL_DISABLE_POP3 1", - "# define CURL_DISABLE_SMTP 1", - "# define CURL_DISABLE_TELNET 1", - "# define CURL_DISABLE_TFTP 1", - "# define CURL_EXTERN_SYMBOL __attribute__ ((__visibility__ (\"default\")))", - "# define ENABLE_IPV6 1", - "# define GETHOSTNAME_TYPE_ARG2 size_t", - "# define GETNAMEINFO_QUAL_ARG1 const", - "# define GETNAMEINFO_TYPE_ARG1 struct sockaddr *", - "# define GETNAMEINFO_TYPE_ARG2 socklen_t", - "# define GETNAMEINFO_TYPE_ARG46 socklen_t", - "# define GETNAMEINFO_TYPE_ARG7 int", - "# define HAVE_ALARM 1", - "# define HAVE_ALLOCA_H 1", - "# define HAVE_ARPA_INET_H 1", - "# define HAVE_ARPA_TFTP_H 1", - "# define HAVE_ASSERT_H 1", - "# define 
HAVE_BASENAME 1", - "# define HAVE_BOOL_T 1", - "# define HAVE_CONNECT 1", - "# define HAVE_DLFCN_H 1", - "# define HAVE_ERRNO_H 1", - "# define HAVE_FCNTL 1", - "# define HAVE_FCNTL_H 1", - "# define HAVE_FCNTL_O_NONBLOCK 1", - "# define HAVE_FDOPEN 1", - "# define HAVE_FORK 1", - "# define HAVE_FREEADDRINFO 1", - "# define HAVE_FREEIFADDRS 1", - "# if !defined(__ANDROID__)", - "# define HAVE_FSETXATTR 1", - "# endif", - "# define HAVE_FTRUNCATE 1", - "# define HAVE_GAI_STRERROR 1", - "# define HAVE_GETADDRINFO 1", - "# define HAVE_GETADDRINFO_THREADSAFE 1", - "# define HAVE_GETEUID 1", - "# define HAVE_GETHOSTBYADDR 1", - "# define HAVE_GETHOSTBYNAME 1", - "# define HAVE_GETHOSTNAME 1", - "# if !defined(__ANDROID__)", - "# define HAVE_GETIFADDRS 1", - "# endif", - "# define HAVE_GETNAMEINFO 1", - "# define HAVE_GETPPID 1", - "# define HAVE_GETPROTOBYNAME 1", - "# define HAVE_GETPWUID 1", - "# if !defined(__ANDROID__)", - "# define HAVE_GETPWUID_R 1", - "# endif", - "# define HAVE_GETRLIMIT 1", - "# define HAVE_GETTIMEOFDAY 1", - "# define HAVE_GMTIME_R 1", - "# if !defined(__ANDROID__)", - "# define HAVE_IFADDRS_H 1", - "# endif", - "# define HAVE_IF_NAMETOINDEX 1", - "# define HAVE_INET_ADDR 1", - "# define HAVE_INET_NTOP 1", - "# define HAVE_INET_PTON 1", - "# define HAVE_INTTYPES_H 1", - "# define HAVE_IOCTL 1", - "# define HAVE_IOCTL_FIONBIO 1", - "# define HAVE_IOCTL_SIOCGIFADDR 1", - "# define HAVE_LIBGEN_H 1", - "# define HAVE_LIBZ 1", - "# define HAVE_LIMITS_H 1", - "# define HAVE_LL 1", - "# define HAVE_LOCALE_H 1", - "# define HAVE_LOCALTIME_R 1", - "# define HAVE_LONGLONG 1", - "# define HAVE_MEMORY_H 1", - "# define HAVE_NETDB_H 1", - "# define HAVE_NETINET_IN_H 1", - "# define HAVE_NETINET_TCP_H 1", - "# define HAVE_NET_IF_H 1", - "# define HAVE_PERROR 1", - "# define HAVE_PIPE 1", - "# define HAVE_POLL 1", - "# define HAVE_POLL_FINE 1", - "# define HAVE_POLL_H 1", - "# define HAVE_POSIX_STRERROR_R 1", - "# define HAVE_PWD_H 1", - "# define HAVE_RECV 
1", - "# define HAVE_SELECT 1", - "# define HAVE_SEND 1", - "# define HAVE_SETJMP_H 1", - "# define HAVE_SETLOCALE 1", - "# define HAVE_SETRLIMIT 1", - "# define HAVE_SETSOCKOPT 1", - "# define HAVE_SGTTY_H 1", - "# define HAVE_SIGACTION 1", - "# define HAVE_SIGINTERRUPT 1", - "# define HAVE_SIGNAL 1", - "# define HAVE_SIGNAL_H 1", - "# define HAVE_SIGSETJMP 1", - "# define HAVE_SIG_ATOMIC_T 1", - "# define HAVE_SOCKADDR_IN6_SIN6_SCOPE_ID 1", - "# define HAVE_SOCKET 1", - "# define HAVE_SOCKETPAIR 1", - "# define HAVE_STDBOOL_H 1", - "# define HAVE_STDINT_H 1", - "# define HAVE_STDIO_H 1", - "# define HAVE_STDLIB_H 1", - "# define HAVE_STRCASECMP 1", - "# define HAVE_STRDUP 1", - "# define HAVE_STRERROR_R 1", - "# define HAVE_STRINGS_H 1", - "# define HAVE_STRING_H 1", - "# define HAVE_STRNCASECMP 1", - "# define HAVE_STRSTR 1", - "# define HAVE_STRTOK_R 1", - "# define HAVE_STRTOLL 1", - "# define HAVE_STRUCT_SOCKADDR_STORAGE 1", - "# define HAVE_STRUCT_TIMEVAL 1", - "# define HAVE_SYS_IOCTL_H 1", - "# define HAVE_SYS_PARAM_H 1", - "# define HAVE_SYS_POLL_H 1", - "# define HAVE_SYS_RESOURCE_H 1", - "# define HAVE_SYS_SELECT_H 1", - "# define HAVE_SYS_SOCKET_H 1", - "# define HAVE_SYS_STAT_H 1", - "# define HAVE_SYS_TIME_H 1", - "# define HAVE_SYS_TYPES_H 1", - "# define HAVE_SYS_UIO_H 1", - "# define HAVE_SYS_UN_H 1", - "# define HAVE_SYS_WAIT_H 1", - "# define HAVE_SYS_XATTR_H 1", - "# define HAVE_TIME_H 1", - "# define HAVE_UNAME 1", - "# define HAVE_UNISTD_H 1", - "# define HAVE_UTIME 1", - "# define HAVE_UTIME_H 1", - "# define HAVE_VARIADIC_MACROS_C99 1", - "# define HAVE_VARIADIC_MACROS_GCC 1", - "# define HAVE_WRITABLE_ARGV 1", - "# define HAVE_WRITEV 1", - "# define HAVE_ZLIB_H 1", - "# define LT_OBJDIR \".libs/\"", - "# define PACKAGE \"curl\"", - "# define PACKAGE_BUGREPORT \"a suitable curl mailing list: https://curl.haxx.se/mail/\"", - "# define PACKAGE_NAME \"curl\"", - "# define PACKAGE_STRING \"curl -\"", - "# define PACKAGE_TARNAME \"curl\"", - "# 
define PACKAGE_URL \"\"", - "# define PACKAGE_VERSION \"-\"", - "# define RECV_TYPE_ARG1 int", - "# define RECV_TYPE_ARG2 void *", - "# define RECV_TYPE_ARG3 size_t", - "# define RECV_TYPE_ARG4 int", - "# define RECV_TYPE_RETV ssize_t", - "# define RETSIGTYPE void", - "# define SELECT_QUAL_ARG5", - "# define SELECT_TYPE_ARG1 int", - "# define SELECT_TYPE_ARG234 fd_set *", - "# define SELECT_TYPE_ARG5 struct timeval *", - "# define SELECT_TYPE_RETV int", - "# define SEND_QUAL_ARG2 const", - "# define SEND_TYPE_ARG1 int", - "# define SEND_TYPE_ARG2 void *", - "# define SEND_TYPE_ARG3 size_t", - "# define SEND_TYPE_ARG4 int", - "# define SEND_TYPE_RETV ssize_t", - "# define SIZEOF_INT 4", - "# define SIZEOF_LONG 8", - "# define SIZEOF_OFF_T 8", - "# define SIZEOF_CURL_OFF_T 8", - "# define SIZEOF_SHORT 2", - "# define SIZEOF_SIZE_T 8", - "# define SIZEOF_TIME_T 8", - "# define SIZEOF_VOIDP 8", - "# define STDC_HEADERS 1", - "# define STRERROR_R_TYPE_ARG3 size_t", - "# define TIME_WITH_SYS_TIME 1", - "# define VERSION \"-\"", - "# ifndef _DARWIN_USE_64_BIT_INODE", - "# define _DARWIN_USE_64_BIT_INODE 1", - "# endif", - "#endif", - "", - "#endif // EXTERNAL_CURL_INCLUDE_CURL_CONFIG_H_", - "EOF", - ]), -) diff --git a/bazel/repositories.bzl b/bazel/repositories.bzl index 49027ea..11d73b1 100644 --- a/bazel/repositories.bzl +++ b/bazel/repositories.bzl @@ -26,7 +26,6 @@ def sf_serving_deps(): _com_github_pybind11_bazel() _com_github_pybind11() _com_github_opentelemetry_cpp() - _com_github_curl() # aws s3 _com_aws_c_common() @@ -34,7 +33,7 @@ def sf_serving_deps(): _com_aws_checksums() _com_aws_sdk() - _yacl() + _heu() _dataproxy() _com_github_brpc_brpc() @@ -47,15 +46,15 @@ def _dataproxy(): remote = "https://github.com/secretflow/dataproxy.git", ) -def _yacl(): +def _heu(): maybe( http_archive, - name = "yacl", + name = "com_alipay_sf_heu", urls = [ - "https://github.com/secretflow/yacl/archive/refs/tags/0.4.5b2.tar.gz", + 
"https://github.com/secretflow/heu/archive/refs/tags/v0.5.1b0.tar.gz", ], - strip_prefix = "yacl-0.4.5b2", - sha256 = "b3fb75d41a32b80145a3bb9d36b8c039a262191f1a2f037292c649344289b01b", + strip_prefix = "heu-0.5.1b0", + sha256 = "26e81b55b2d4f734977f8d5e1ba672c20287f2971dcb5d7f3fa46295e2882012", ) def _bazel_rules_pkg(): @@ -77,19 +76,6 @@ def _bazel_platform(): sha256 = "218efe8ee736d26a3572663b374a253c012b716d8af0c07e842e82f238a0a7ee", ) -def _com_github_curl(): - maybe( - http_archive, - name = "com_github_curl", - build_file = "@sf_serving//bazel:curl.BUILD", - sha256 = "e9c37986337743f37fd14fe8737f246e97aec94b39d1b71e8a5973f72a9fc4f5", - strip_prefix = "curl-7.60.0", - urls = [ - "http://mirror.tensorflow.org/curl.haxx.se/download/curl-7.60.0.tar.gz", - "https://curl.haxx.se/download/curl-7.60.0.tar.gz", - ], - ) - def _com_aws_c_common(): maybe( http_archive, diff --git a/bazel/serving.bzl b/bazel/serving.bzl index b62c49b..396664a 100644 --- a/bazel/serving.bzl +++ b/bazel/serving.bzl @@ -52,7 +52,7 @@ def serving_cc_binary( deps = [], **kargs): cc_binary( - linkopts = linkopts, + linkopts = linkopts + ["-ldl"], copts = copts + _serving_copts(), deps = deps, **kargs diff --git a/docs/locales/zh_CN/LC_MESSAGES/reference/config.po b/docs/locales/zh_CN/LC_MESSAGES/reference/config.po index 636fec3..9a69c93 100644 --- a/docs/locales/zh_CN/LC_MESSAGES/reference/config.po +++ b/docs/locales/zh_CN/LC_MESSAGES/reference/config.po @@ -8,7 +8,7 @@ msgid "" msgstr "" "Project-Id-Version: SecretFlow-Serving \n" "Report-Msgid-Bugs-To: \n" -"POT-Creation-Date: 2024-07-31 14:20+0800\n" +"POT-Creation-Date: 2024-08-14 17:42+0800\n" "PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" "Last-Translator: FULL NAME \n" "Language-Team: LANGUAGE \n" @@ -794,7 +794,7 @@ msgstr "" #: ../../source/reference/config.md msgid "" "Optional. Feature name mapping rules. Key: source or predefined feature " -"name Value: model feature name" +"name. Value: model feature name." 
msgstr "" #: ../../source/reference/config.md @@ -830,7 +830,7 @@ msgid "brpc_builtin_service_port" msgstr "" #: ../../source/reference/config.md -msgid "Brpc builtin service listen port Default: disable service" +msgid "Brpc builtin service listen port. Default: disable service" msgstr "" #: ../../source/reference/config.md @@ -838,7 +838,7 @@ msgid "metrics_exposer_port" msgstr "" #: ../../source/reference/config.md -msgid "`/metrics` service listen port Default: disable service" +msgid "`/metrics` service listen port. Default: disable service" msgstr "" #: ../../source/reference/config.md @@ -857,7 +857,7 @@ msgstr "" #: ../../source/reference/config.md msgid "" -"Server-level max number of requests processed in parallel Default: 0 " +"Server-level max number of requests processed in parallel. Default: 0 " "(unlimited)" msgstr "" diff --git a/docs/locales/zh_CN/LC_MESSAGES/reference/model.po b/docs/locales/zh_CN/LC_MESSAGES/reference/model.po index b320a85..2af6862 100644 --- a/docs/locales/zh_CN/LC_MESSAGES/reference/model.po +++ b/docs/locales/zh_CN/LC_MESSAGES/reference/model.po @@ -8,14 +8,14 @@ msgid "" msgstr "" "Project-Id-Version: SecretFlow-Serving \n" "Report-Msgid-Bugs-To: \n" -"POT-Creation-Date: 2024-02-19 12:47+0000\n" +"POT-Creation-Date: 2024-08-06 20:49+0800\n" "PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" "Last-Translator: FULL NAME \n" "Language-Team: LANGUAGE \n" "MIME-Version: 1.0\n" "Content-Type: text/plain; charset=utf-8\n" "Content-Transfer-Encoding: 8bit\n" -"Generated-By: Babel 2.14.0\n" +"Generated-By: Babel 2.15.0\n" #: ../../source/reference/model.md:1 msgid "SecretFlow-Serving Model" @@ -29,7 +29,7 @@ msgstr "" msgid "Services" msgstr "" -#: ../../source/reference/model.md:28 ../../source/reference/model.md:155 +#: ../../source/reference/model.md:28 ../../source/reference/model.md:157 msgid "Messages" msgstr "" @@ -94,86 +94,94 @@ msgid "[GraphView](#graphview)" msgstr "" #: ../../source/reference/model.md:57 -msgid 
"[NodeDef](#nodedef)" +msgid "[HeConfig](#heconfig)" msgstr "" #: ../../source/reference/model.md:58 -msgid "[NodeDef.AttrValuesEntry](#nodedef-attrvaluesentry)" +msgid "[HeInfo](#heinfo)" msgstr "" #: ../../source/reference/model.md:59 -msgid "[NodeView](#nodeview)" +msgid "[NodeDef](#nodedef)" msgstr "" #: ../../source/reference/model.md:60 +msgid "[NodeDef.AttrValuesEntry](#nodedef-attrvaluesentry)" +msgstr "" + +#: ../../source/reference/model.md:61 +msgid "[NodeView](#nodeview)" +msgstr "" + +#: ../../source/reference/model.md:62 msgid "[RuntimeConfig](#runtimeconfig)" msgstr "" -#: ../../source/reference/model.md:66 +#: ../../source/reference/model.md:68 msgid "[ModelBundle](#modelbundle)" msgstr "" -#: ../../source/reference/model.md:67 +#: ../../source/reference/model.md:69 msgid "[ModelInfo](#modelinfo)" msgstr "" -#: ../../source/reference/model.md:68 +#: ../../source/reference/model.md:70 msgid "[ModelManifest](#modelmanifest)" msgstr "" -#: ../../source/reference/model.md:77 ../../source/reference/model.md:87 +#: ../../source/reference/model.md:79 ../../source/reference/model.md:89 msgid "[ComputeTrace](#computetrace)" msgstr "" -#: ../../source/reference/model.md:78 ../../source/reference/model.md:88 +#: ../../source/reference/model.md:80 ../../source/reference/model.md:90 msgid "[FunctionInput](#functioninput)" msgstr "" -#: ../../source/reference/model.md:79 ../../source/reference/model.md:89 +#: ../../source/reference/model.md:81 ../../source/reference/model.md:91 msgid "[FunctionOutput](#functionoutput)" msgstr "" -#: ../../source/reference/model.md:80 ../../source/reference/model.md:90 +#: ../../source/reference/model.md:82 ../../source/reference/model.md:92 msgid "[FunctionTrace](#functiontrace)" msgstr "" -#: ../../source/reference/model.md:81 ../../source/reference/model.md:91 +#: ../../source/reference/model.md:83 ../../source/reference/model.md:93 msgid "[Scalar](#scalar)" msgstr "" -#: ../../source/reference/model.md:96 
../../source/reference/model.md:644 +#: ../../source/reference/model.md:98 ../../source/reference/model.md:677 msgid "Enums" msgstr "" -#: ../../source/reference/model.md:100 +#: ../../source/reference/model.md:102 msgid "[AttrType](#attrtype)" msgstr "" -#: ../../source/reference/model.md:109 +#: ../../source/reference/model.md:111 msgid "[DispatchType](#dispatchtype)" msgstr "" -#: ../../source/reference/model.md:115 +#: ../../source/reference/model.md:117 msgid "[FileFormatType](#fileformattype)" msgstr "" -#: ../../source/reference/model.md:121 +#: ../../source/reference/model.md:123 msgid "[DataType](#datatype)" msgstr "" -#: ../../source/reference/model.md:127 ../../source/reference/model.md:133 +#: ../../source/reference/model.md:129 ../../source/reference/model.md:135 msgid "[ExtendFunctionName](#extendfunctionname)" msgstr "" -#: ../../source/reference/model.md:137 +#: ../../source/reference/model.md:139 msgid "[Scalar Value Types](#scalar-value-types)" msgstr "" -#: ../../source/reference/model.md:160 +#: ../../source/reference/model.md:162 msgid "AttrDef" msgstr "" -#: ../../source/reference/model.md:161 +#: ../../source/reference/model.md:163 msgid "The definition of an attribute." msgstr "" @@ -249,11 +257,11 @@ msgid "" "does not supply a value. If not, the user must supply a value." 
msgstr "" -#: ../../source/reference/model.md:176 +#: ../../source/reference/model.md:178 msgid "AttrValue" msgstr "" -#: ../../source/reference/model.md:177 +#: ../../source/reference/model.md:179 msgid "The value of an attribute" msgstr "" @@ -429,7 +437,7 @@ msgstr "" msgid "BYTESS" msgstr "" -#: ../../source/reference/model.md:201 +#: ../../source/reference/model.md:203 msgid "BoolList" msgstr "" @@ -441,7 +449,7 @@ msgstr "" msgid "[repeated bool](#bool )" msgstr "" -#: ../../source/reference/model.md:213 +#: ../../source/reference/model.md:215 msgid "BytesList" msgstr "" @@ -449,7 +457,7 @@ msgstr "" msgid "[repeated bytes](#bytes )" msgstr "" -#: ../../source/reference/model.md:225 +#: ../../source/reference/model.md:227 msgid "DoubleList" msgstr "" @@ -457,7 +465,7 @@ msgstr "" msgid "[repeated double](#double )" msgstr "" -#: ../../source/reference/model.md:237 +#: ../../source/reference/model.md:239 msgid "FloatList" msgstr "" @@ -465,7 +473,7 @@ msgstr "" msgid "[repeated float](#float )" msgstr "" -#: ../../source/reference/model.md:249 +#: ../../source/reference/model.md:251 msgid "Int32List" msgstr "" @@ -473,7 +481,7 @@ msgstr "" msgid "[repeated int32](#int32 )" msgstr "" -#: ../../source/reference/model.md:261 +#: ../../source/reference/model.md:263 msgid "Int64List" msgstr "" @@ -481,7 +489,7 @@ msgstr "" msgid "[repeated int64](#int64 )" msgstr "" -#: ../../source/reference/model.md:273 +#: ../../source/reference/model.md:275 msgid "StringList" msgstr "" @@ -489,11 +497,11 @@ msgstr "" msgid "[repeated string](#string )" msgstr "" -#: ../../source/reference/model.md:287 +#: ../../source/reference/model.md:289 msgid "IoDef" msgstr "" -#: ../../source/reference/model.md:288 +#: ../../source/reference/model.md:290 msgid "Define an input/output for operator." 
msgstr "" @@ -505,11 +513,11 @@ msgstr "" msgid "Description of the IO" msgstr "" -#: ../../source/reference/model.md:300 +#: ../../source/reference/model.md:302 msgid "OpDef" msgstr "" -#: ../../source/reference/model.md:301 +#: ../../source/reference/model.md:303 msgid "The definition of a operator." msgstr "" @@ -567,11 +575,11 @@ msgstr "" msgid "[repeated AttrDef](#attrdef )" msgstr "" -#: ../../source/reference/model.md:318 +#: ../../source/reference/model.md:320 msgid "OpTag" msgstr "" -#: ../../source/reference/model.md:319 +#: ../../source/reference/model.md:321 msgid "Representation operator property" msgstr "" @@ -598,7 +606,7 @@ msgid "session_run" msgstr "" #: ../../source/reference/model.md -msgid "The operator needs to be executed in session." +msgid "The operator needs to be executed in session. TODO: not supported yet." msgstr "" #: ../../source/reference/model.md @@ -609,11 +617,11 @@ msgstr "" msgid "Whether this op has variable input argument. default `false`." msgstr "" -#: ../../source/reference/model.md:335 +#: ../../source/reference/model.md:337 msgid "ExecutionDef" msgstr "" -#: ../../source/reference/model.md:336 +#: ../../source/reference/model.md:338 msgid "" "The definition of a execution. A execution represents a subgraph within a" " graph that can be scheduled for execution in a specified pattern." @@ -638,15 +646,15 @@ msgstr "" msgid "[ RuntimeConfig](#runtimeconfig )" msgstr "" -#: ../../source/reference/model.md ../../source/reference/model.md:424 +#: ../../source/reference/model.md ../../source/reference/model.md:457 msgid "The runtime config of the execution." msgstr "" -#: ../../source/reference/model.md:349 +#: ../../source/reference/model.md:351 msgid "GraphDef" msgstr "" -#: ../../source/reference/model.md:350 +#: ../../source/reference/model.md:352 msgid "" "The definition of a Graph. A graph consists of a set of nodes carrying " "data and a set of executions that describes the scheduling of the graph." 
@@ -672,11 +680,23 @@ msgstr "" msgid "[repeated ExecutionDef](#executiondef )" msgstr "" -#: ../../source/reference/model.md:364 +#: ../../source/reference/model.md +msgid "he_config" +msgstr "" + +#: ../../source/reference/model.md +msgid "[ HeConfig](#heconfig )" +msgstr "" + +#: ../../source/reference/model.md +msgid "party_num" +msgstr "" + +#: ../../source/reference/model.md:368 msgid "GraphView" msgstr "" -#: ../../source/reference/model.md:365 +#: ../../source/reference/model.md:369 msgid "" "The view of a graph is used to display the structure of the graph, " "containing only structural information and excluding the data components." @@ -686,11 +706,59 @@ msgstr "" msgid "[repeated NodeView](#nodeview )" msgstr "" -#: ../../source/reference/model.md:379 +#: ../../source/reference/model.md +msgid "he_info" +msgstr "" + +#: ../../source/reference/model.md +msgid "[ HeInfo](#heinfo )" +msgstr "" + +#: ../../source/reference/model.md:385 +msgid "HeConfig" +msgstr "" + +#: ../../source/reference/model.md:386 +msgid "The config for HE compute." +msgstr "" + +#: ../../source/reference/model.md +msgid "pk_buf" +msgstr "" + +#: ../../source/reference/model.md +msgid "Serialized public key bytes" +msgstr "" + +#: ../../source/reference/model.md +msgid "sk_buf" +msgstr "" + +#: ../../source/reference/model.md +msgid "Serialized secret key bytes" +msgstr "" + +#: ../../source/reference/model.md +msgid "encode_scale" +msgstr "" + +#: ../../source/reference/model.md +msgid "Encode scale for data" +msgstr "" + +#: ../../source/reference/model.md:399 +msgid "HeInfo" +msgstr "" + +#: ../../source/reference/model.md:400 +msgid "The public info for HE compute." +msgstr "" + +#: ../../source/reference/model.md:412 msgid "NodeDef" msgstr "" -#: ../../source/reference/model.md:380 +#: ../../source/reference/model.md:413 msgid "The definition of a node." msgstr "" @@ -738,7 +806,7 @@ msgstr "" msgid "The operator version." 
msgstr "" -#: ../../source/reference/model.md:395 +#: ../../source/reference/model.md:428 msgid "NodeDef.AttrValuesEntry" msgstr "" @@ -754,15 +822,15 @@ msgstr "" msgid "[ op.AttrValue](#attrvalue )" msgstr "" -#: ../../source/reference/model.md:408 +#: ../../source/reference/model.md:441 msgid "NodeView" msgstr "" -#: ../../source/reference/model.md:409 +#: ../../source/reference/model.md:442 msgid "The view of a node, which could be public to other parties" msgstr "" -#: ../../source/reference/model.md:423 +#: ../../source/reference/model.md:456 msgid "RuntimeConfig" msgstr "" @@ -790,11 +858,11 @@ msgstr "" msgid "if dispatch_type is DP_SPECIFIED, only one party should be true" msgstr "" -#: ../../source/reference/model.md:439 +#: ../../source/reference/model.md:472 msgid "ModelBundle" msgstr "" -#: ../../source/reference/model.md:440 +#: ../../source/reference/model.md:473 msgid "" "Represents an exported secertflow model. It consists of a GraphDef and " "extra metadata required for serving." @@ -808,11 +876,11 @@ msgstr "" msgid "[ GraphDef](#graphdef )" msgstr "" -#: ../../source/reference/model.md:454 +#: ../../source/reference/model.md:487 msgid "ModelInfo" msgstr "" -#: ../../source/reference/model.md:455 +#: ../../source/reference/model.md:488 msgid "Represents a secertflow model without private data." msgstr "" @@ -824,11 +892,11 @@ msgstr "" msgid "[ GraphView](#graphview )" msgstr "" -#: ../../source/reference/model.md:468 +#: ../../source/reference/model.md:501 msgid "ModelManifest" msgstr "" -#: ../../source/reference/model.md:469 +#: ../../source/reference/model.md:502 msgid "" "The manifest of the model package. Package format is as follows: " "model.tar.gz ├ MANIFIEST ├ model_file └ some op meta files MANIFIEST " @@ -855,7 +923,7 @@ msgstr "" msgid "The format type of the model bundle file." 
msgstr "" -#: ../../source/reference/model.md:490 ../../source/reference/model.md:568 +#: ../../source/reference/model.md:523 ../../source/reference/model.md:601 msgid "ComputeTrace" msgstr "" @@ -871,7 +939,7 @@ msgstr "" msgid "[repeated FunctionTrace](#functiontrace )" msgstr "" -#: ../../source/reference/model.md:503 ../../source/reference/model.md:581 +#: ../../source/reference/model.md:536 ../../source/reference/model.md:614 msgid "FunctionInput" msgstr "" @@ -895,7 +963,7 @@ msgstr "" msgid "[ Scalar](#scalar )" msgstr "" -#: ../../source/reference/model.md:516 ../../source/reference/model.md:594 +#: ../../source/reference/model.md:549 ../../source/reference/model.md:627 msgid "FunctionOutput" msgstr "" @@ -903,7 +971,7 @@ msgstr "" msgid "data_id" msgstr "" -#: ../../source/reference/model.md:528 ../../source/reference/model.md:606 +#: ../../source/reference/model.md:561 ../../source/reference/model.md:639 msgid "FunctionTrace" msgstr "" @@ -935,11 +1003,11 @@ msgstr "" msgid "Output of this function." msgstr "" -#: ../../source/reference/model.md:543 ../../source/reference/model.md:621 +#: ../../source/reference/model.md:576 ../../source/reference/model.md:654 msgid "Scalar" msgstr "" -#: ../../source/reference/model.md:544 ../../source/reference/model.md:622 +#: ../../source/reference/model.md:577 ../../source/reference/model.md:655 msgid "Represents a single value with a specific data type." msgstr "" @@ -1023,11 +1091,11 @@ msgstr "" msgid "DOUBLE" msgstr "" -#: ../../source/reference/model.md:649 +#: ../../source/reference/model.md:682 msgid "AttrType" msgstr "" -#: ../../source/reference/model.md:650 +#: ../../source/reference/model.md:683 msgid "Supported attribute types." 
msgstr "" @@ -1191,11 +1259,11 @@ msgstr "" msgid "BYTES LIST" msgstr "" -#: ../../source/reference/model.md:678 +#: ../../source/reference/model.md:711 msgid "DispatchType" msgstr "" -#: ../../source/reference/model.md:679 +#: ../../source/reference/model.md:712 msgid "Supported dispatch type" msgstr "" @@ -1227,11 +1295,27 @@ msgstr "" msgid "Dispatch specified participant." msgstr "" -#: ../../source/reference/model.md:694 +#: ../../source/reference/model.md +msgid "DP_SELF" +msgstr "" + +#: ../../source/reference/model.md +msgid "Dispatch self." +msgstr "" + +#: ../../source/reference/model.md +msgid "DP_PEER" +msgstr "" + +#: ../../source/reference/model.md +msgid "For 2-parties, Dispatch peer participant." +msgstr "" + +#: ../../source/reference/model.md:729 msgid "FileFormatType" msgstr "" -#: ../../source/reference/model.md:695 +#: ../../source/reference/model.md:730 msgid "Support model file format" msgstr "" @@ -1257,11 +1341,11 @@ msgid "" "method to ensure compatibility" msgstr "" -#: ../../source/reference/model.md:709 +#: ../../source/reference/model.md:744 msgid "DataType" msgstr "" -#: ../../source/reference/model.md:710 +#: ../../source/reference/model.md:745 msgid "" "Mapping arrow::DataType " "`https://arrow.apache.org/docs/cpp/api/datatype.html`." 
@@ -1383,7 +1467,7 @@ msgstr "" msgid "Variable-length bytes (no guarantee of UTF8-ness)" msgstr "" -#: ../../source/reference/model.md:736 ../../source/reference/model.md:753 +#: ../../source/reference/model.md:771 ../../source/reference/model.md:788 msgid "ExtendFunctionName" msgstr "" @@ -1435,7 +1519,7 @@ msgid "" "https://arrow.apache.org/docs/cpp/api/table.html#_CPPv4NK5arrow11RecordBatch9SetColumnEiRKNSt10shared_ptrI5FieldEERKNSt10shared_ptrI5ArrayEE" msgstr "" -#: ../../source/reference/model.md:768 +#: ../../source/reference/model.md:803 msgid "Scalar Value Types" msgstr "" diff --git a/docs/locales/zh_CN/LC_MESSAGES/topics/deployment/deployment.po b/docs/locales/zh_CN/LC_MESSAGES/topics/deployment/deployment.po index 9216fec..80409a8 100644 --- a/docs/locales/zh_CN/LC_MESSAGES/topics/deployment/deployment.po +++ b/docs/locales/zh_CN/LC_MESSAGES/topics/deployment/deployment.po @@ -283,7 +283,7 @@ msgstr "步骤 2:启动 Serving 服务" msgid "" "The file your workspace should be as follows, ``trace.config`` is " "optional:" -msgstr "您工作区的文件应如下所示,``trace.config``是可选的:" +msgstr "您工作区的文件应如下所示,``trace.config`` 是可选的:" #: ../../source/topics/deployment/deployment.rst:200 msgid "Then you can start serving service by running docker compose up." 
diff --git a/docs/locales/zh_CN/LC_MESSAGES/topics/deployment/serving_on_kuscia.po b/docs/locales/zh_CN/LC_MESSAGES/topics/deployment/serving_on_kuscia.po index 8d4efbe..32e2e07 100644 --- a/docs/locales/zh_CN/LC_MESSAGES/topics/deployment/serving_on_kuscia.po +++ b/docs/locales/zh_CN/LC_MESSAGES/topics/deployment/serving_on_kuscia.po @@ -8,7 +8,7 @@ msgid "" msgstr "" "Project-Id-Version: SecretFlow-Serving \n" "Report-Msgid-Bugs-To: \n" -"POT-Creation-Date: 2024-07-31 14:26+0800\n" +"POT-Creation-Date: 2024-08-20 11:45+0800\n" "PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" "Last-Translator: FULL NAME \n" "Language-Team: LANGUAGE \n" @@ -48,7 +48,7 @@ msgstr "" "Hans/reference/concepts/appimage_cn#id2>`_ 。SecretFlow-Serving 的 " "`AppImage` 如下所示:" -#: ../../source/topics/deployment/serving_on_kuscia.rst:78 +#: ../../source/topics/deployment/serving_on_kuscia.rst:85 msgid "" "The explanation of the common fields can be found `here " "`_ 。" -#: ../../source/topics/deployment/serving_on_kuscia.rst:80 +#: ../../source/topics/deployment/serving_on_kuscia.rst:87 msgid "Other field explanations are as follows:" msgstr "其他字段说明如下:" -#: ../../source/topics/deployment/serving_on_kuscia.rst:86 -msgid "`configTemplates`:" +#: ../../source/topics/deployment/serving_on_kuscia.rst:99 +msgid "**configTemplates**:" +msgstr "" + +#: ../../source/topics/deployment/serving_on_kuscia.rst:96 +msgid "**serving-config.conf**:" msgstr "" -#: ../../source/topics/deployment/serving_on_kuscia.rst:83 +#: ../../source/topics/deployment/serving_on_kuscia.rst:91 msgid "" -"`serving_id`: Service ID identifier, corresponding to the configuration " -":ref:`ServingConfig.id `. The current content is a " +"**serving_id**: Service ID identifier, corresponding to the configuration" +" :ref:`ServingConfig.id `. The current content is a " "placeholder and will actually be replaced by the content in `Kuscia API " "/api/v1/serving/create `_ at startup." 
msgstr "" -"`serving_id`: 服务ID标识,对应配置 :ref:`ServingConfig.id " +"**serving_id**: 服务ID标识,对应配置 :ref:`ServingConfig.id " "`。当前内容为占位符,实际会在启动时被 `Kuscia API /api/v1/serving/create " "`_ 中的内容替换。" -#: ../../source/topics/deployment/serving_on_kuscia.rst:84 +#: ../../source/topics/deployment/serving_on_kuscia.rst:92 msgid "" -"`input_config`: SecretFlow-Serving startup configuration, details can be " -"seen in the description below. The current content is a placeholder and " -"will actually be replaced by the content in `Kuscia API " +"**input_config**: SecretFlow-Serving startup configuration, details can " +"be seen in the description below. The current content is a placeholder " +"and will actually be replaced by the content in `Kuscia API " "/api/v1/serving/create `_ at startup." msgstr "" -"`input_config`: SecretFlow-Serving 启动配置,详情可见下面的描述。当前内容为占位符,实际会在启动时被 " +"**input_config**: SecretFlow-Serving 启动配置,详情可见下面的描述。当前内容为占位符,实际会在启动时被 " "`Kuscia API /api/v1/serving/create " "`_ 中的内容替换。" -#: ../../source/topics/deployment/serving_on_kuscia.rst:85 +#: ../../source/topics/deployment/serving_on_kuscia.rst:93 msgid "" -"`cluster_def`: See `AppImage explanation " +"**cluster_def**: See `AppImage explanation " "`_." msgstr "" -"`cluster_def`:见 `AppImage 说明 " +"**cluster_def**:见 `AppImage 说明 " "`_。" -#: ../../source/topics/deployment/serving_on_kuscia.rst:86 +#: ../../source/topics/deployment/serving_on_kuscia.rst:94 msgid "" -"`allocated_ports`: See `AppImage explanation " +"**allocated_ports**: See `AppImage explanation " "`_." msgstr "" -"`allocated_ports`: 见 `AppImage 说明 " +"**allocated_ports**: 见 `AppImage 说明 " "`_。" -#: ../../source/topics/deployment/serving_on_kuscia.rst:92 -msgid "`ports`:" +#: ../../source/topics/deployment/serving_on_kuscia.rst:95 +msgid "" +"**oss_meta**: OSS/S3 model source configuration, only effective when " +"using OSS/S3 as the model data source. 
The actual content is in the form " +"of a string-formatted JSON configuration, for example\" " +"``{\\\"access_key\\\":\\\"test_ak\\\", " +"\\\"secret_key\\\":\\\"test_sk\\\", \\\"virtual_hosted\\\":true, " +"\\\"endpoint\\\":\\\"test_endpoint\\\", " +"\\\"bucket\\\":\\\"test_bucket\\\"}``, the definition can be found " +":ref:`here `. This is an optional configuration and can be" +" set up through the ``Kuscia configuration management`` system if needed." +msgstr "" +"**oss_meta**:OSS/S3 " +"配置,只当使用OSS/S3作为模型数据源时生效。其真实内容为字符串格式化的Json配置,例如:``{\\\"access_key\\\":\\\"test_ak\\\", \\\"secret_key\\\":\\\"test_sk\\\"," +" \\\"virtual_hosted\\\":true, \\\"endpoint\\\":\\\"test_endpoint\\\", " +"\\\"bucket\\\":\\\"test_bucket\\\"}``,其具体定义可见 :ref:`这里 ` " +"。该配置为可选配置,若需要使用可通过 ``Kuscia 配置管理系统`` 进行配置。" + +#: ../../source/topics/deployment/serving_on_kuscia.rst:96 +msgid "" +"**spi_tls_config**: The TLS configuration used by SPI. The actual content" +" is in the form of a string-formatted JSON configuration, for example: " +"``{\\\"certificate_path\\\":\\\"abc\\\", " +"\\\"private_key_path\\\":\\\"def\\\",\\\"ca_file_path\\\":\\\"gkh\\\"}``," +" the definition can be found :ref:`here `. This is an optional" +" configuration and can be set up through the ``Kuscia configuration " +"management`` system if needed." +msgstr "" +"**spi_tls_config**:SPI链路使用的TLS配置。其真实内容为字符串格式化的Json配置,例如:``{\\\"certificate_path\\\":\\\"abc\\\", \\\"private_key_path\\\":\\\"def\\\",\\\"ca_file_path\\\":\\\"gkh\\\"}``,具体定义可见" +" :ref:`这里 ` 。该配置为可选配置,若需要使用可通过 ``Kuscia 配置管理系统`` 进行配置。" + +#: ../../source/topics/deployment/serving_on_kuscia.rst:97 +msgid "" +"**http_source_meta**: HTTP model source configuration, only effective " +"when using HTTP as the model data source.
The actual content is in the " +"form of a string-formatted JSON configuration, for example\" " +"``{\\\"connectTimeoutMs\\\":60000,\\\"timeoutMs\\\":120000,\\\"tlsConfig\\\":{\\\"certificatePath\\\":\\\"abc\\\"," +" \\\"privateKeyPath\\\":\\\"def\\\",\\\"caFilePath\\\":\\\"gkh\\\"}}``, " +"the definition can be found :ref:`here `. This is an " +"optional configuration and can be set up through the ``Kuscia " +"configuration management`` system if needed." +msgstr "" +"**http_source_meta**:HTTP 数据源" +"配置,只当使用HTTP服务作为模型数据源时生效。其真实内容为字符串格式化的Json配置,例如:" +"``{\\\"connectTimeoutMs\\\":60000,\\\"timeoutMs\\\":120000,\\\"tlsConfig\\\":{\\\"certificatePath\\\":\\\"abc\\\"," +" \\\"privateKeyPath\\\":\\\"def\\\",\\\"caFilePath\\\":\\\"gkh\\\"}}``" +",其具体定义可见 :ref:`这里 ` " +"。该配置为可选配置,若需要使用可通过 ``Kuscia 配置管理系统`` 进行配置。" + +#: ../../source/topics/deployment/serving_on_kuscia.rst:98 +msgid "" +"**logging-config.conf**: SecretFlow-Serving Log Configuration. The actual" +" content is in the form of a string-formatted JSON configuration, for " +"example\" " +"``{\\\"systemLogPath\\\":\\\"/tmp/alice/serving.log\\\",\\\"logLevel\\\":\\\"INFO_LOG_LEVEL\\\",\\\"maxLogFileSize\\\":4194304,\\\"maxLogFileCount\\\":10}``," +" the definition can be found :ref:`here `. This is an " +"optional configuration and can be set up through the ``Kuscia " +"configuration management`` system if needed." +msgstr "" +"**logging-" +"config.conf**:日志配置。其真实内容为字符串格式化的Json配置,例如:``{\\\"systemLogPath\\\":\\\"/tmp/alice/serving.log\\\",\\\"logLevel\\\":\\\"INFO_LOG_LEVEL\\\",\\\"maxLogFileSize\\\":4194304,\\\"maxLogFileCount\\\":10}``,具体定义可见" +" :ref:`这里 ` 。该配置为可选配置,若需要使用可通过 ``Kuscia 配置管理系统`` 进行配置。" + +#: ../../source/topics/deployment/serving_on_kuscia.rst:99 +msgid "" +"**trace-config.conf**: SecretFlow-Serving Trace Configuration.
The actual" +" content is in the form of a string-formatted JSON configuration, for " +"example\" " +"``{\\\"traceLogEnable\\\":true,\\\"traceLogConf\\\":{\\\"traceLogPath\\\":\\\"/tmp/trace.log\\\"}}``," +" the definition can be found :ref:`here `. This is an " +"optional configuration and can be set up through the ``Kuscia " +"configuration management`` system if needed." msgstr "" +"**trace-" +"config.conf**:Trace配置。其真实内容为字符串格式化的Json配置,例如:``{\\\"traceLogEnable\\\":true,\\\"traceLogConf\\\":{\\\"traceLogPath\\\":\\\"/tmp/trace.log\\\"}}``,具体定义可见" +" :ref:`这里 ` 。该配置为可选配置,若需要使用可通过 ``Kuscia 配置管理系统`` 进行配置。" -#: ../../source/topics/deployment/serving_on_kuscia.rst:89 -msgid "`service`: The :ref:`ServerConfig.service_port `" -msgstr "`service`: 即 :ref:`ServerConfig.service_port `" +#: ../../source/topics/deployment/serving_on_kuscia.rst:105 +msgid "**ports**:" +msgstr "" -#: ../../source/topics/deployment/serving_on_kuscia.rst:90 -msgid "`communication`: The :ref:`ServerConfig.communication_port `" -msgstr "`communication`:即 :ref:`ServerConfig.communication_port `" +#: ../../source/topics/deployment/serving_on_kuscia.rst:102 +msgid "**service**: The :ref:`ServerConfig.service_port `" +msgstr "**service**: 即 :ref:`ServerConfig.service_port `" -#: ../../source/topics/deployment/serving_on_kuscia.rst:91 -msgid "`internal`: The :ref:`ServerConfig.metrics_exposer_port `" -msgstr "`internal`:即 :ref:`ServerConfig.metrics_exposer_port `" +#: ../../source/topics/deployment/serving_on_kuscia.rst:103 +msgid "" +"**communication**: The :ref:`ServerConfig.communication_port " +"`" +msgstr "**communication**:即 :ref:`ServerConfig.communication_port `" -#: ../../source/topics/deployment/serving_on_kuscia.rst:92 +#: ../../source/topics/deployment/serving_on_kuscia.rst:104 +msgid "**internal**: The :ref:`ServerConfig.metrics_exposer_port `" +msgstr "**internal**:即 :ref:`ServerConfig.metrics_exposer_port `" + +#: ../../source/topics/deployment/serving_on_kuscia.rst:105 msgid ""
-"`brpc-builtin`: The :ref:`ServerConfig.brpc_builtin_service_port " +"**brpc-builtin**: The :ref:`ServerConfig.brpc_builtin_service_port " "`" msgstr "" -"`brpc-builtin`:即 :ref:`ServerConfig.brpc_builtin_service_port " +"**brpc-builtin**:即 :ref:`ServerConfig.brpc_builtin_service_port " "`" -#: ../../source/topics/deployment/serving_on_kuscia.rst:95 +#: ../../source/topics/deployment/serving_on_kuscia.rst:108 msgid "Configuration description" msgstr "配置说明" -#: ../../source/topics/deployment/serving_on_kuscia.rst:98 +#: ../../source/topics/deployment/serving_on_kuscia.rst:111 msgid "serving_input_config" msgstr "" -#: ../../source/topics/deployment/serving_on_kuscia.rst:100 +#: ../../source/topics/deployment/serving_on_kuscia.rst:113 msgid "" "The launch and management of SecretFlow-Serving can be performed using " "the `Kuscia Serving API " "`_. In this section, we will " -"explain the contents of the `serving_input_config` field within the " -"`/api/v1/serving/create` request." +"explain the contents of the ``serving_input_config`` field within the " +"``/api/v1/serving/create`` request." 
msgstr "" "Kuscia 场景下,SecretFlow-Serving 的启动管理可通过 `Kuscia Serving API " "`_ 进行,这里我们将对 " -"`/api/v1/serving/create` 请求中的 `serving_input_config` 字段内容进行说明。" +"``/api/v1/serving/create`` 请求中的 ``serving_input_config`` 字段内容进行说明。" -#: ../../source/topics/deployment/serving_on_kuscia.rst:155 +#: ../../source/topics/deployment/serving_on_kuscia.rst:168 msgid "**Field description**:" msgstr "**字段说明**:" -#: ../../source/topics/deployment/serving_on_kuscia.rst:158 +#: ../../source/topics/deployment/serving_on_kuscia.rst:171 msgid "Name" msgstr "名称" -#: ../../source/topics/deployment/serving_on_kuscia.rst:158 +#: ../../source/topics/deployment/serving_on_kuscia.rst:171 msgid "Type" msgstr "类型" -#: ../../source/topics/deployment/serving_on_kuscia.rst:158 +#: ../../source/topics/deployment/serving_on_kuscia.rst:171 msgid "Description" msgstr "描述" -#: ../../source/topics/deployment/serving_on_kuscia.rst:158 +#: ../../source/topics/deployment/serving_on_kuscia.rst:171 msgid "Required" msgstr "必选" -#: ../../source/topics/deployment/serving_on_kuscia.rst:160 +#: ../../source/topics/deployment/serving_on_kuscia.rst:173 msgid "partyConfigs" msgstr "" -#: ../../source/topics/deployment/serving_on_kuscia.rst:160 +#: ../../source/topics/deployment/serving_on_kuscia.rst:173 msgid "map" msgstr "" -#: ../../source/topics/deployment/serving_on_kuscia.rst:160 +#: ../../source/topics/deployment/serving_on_kuscia.rst:173 msgid "" "Dictionary of startup parameters for each participant. Key: Participant " "Unique ID; Value: PartyConfig (Json Object)." 
msgstr "各参与方启动参数字典。Key: 参与方id; Value: PartyConfig(Json Object)" -#: ../../source/topics/deployment/serving_on_kuscia.rst:160 -#: ../../source/topics/deployment/serving_on_kuscia.rst:162 -#: ../../source/topics/deployment/serving_on_kuscia.rst:166 -#: ../../source/topics/deployment/serving_on_kuscia.rst:168 -#: ../../source/topics/deployment/serving_on_kuscia.rst:170 -#: ../../source/topics/deployment/serving_on_kuscia.rst:172 -#: ../../source/topics/deployment/serving_on_kuscia.rst:176 +#: ../../source/topics/deployment/serving_on_kuscia.rst:173 +#: ../../source/topics/deployment/serving_on_kuscia.rst:175 +#: ../../source/topics/deployment/serving_on_kuscia.rst:179 +#: ../../source/topics/deployment/serving_on_kuscia.rst:181 #: ../../source/topics/deployment/serving_on_kuscia.rst:183 -#: ../../source/topics/deployment/serving_on_kuscia.rst:191 -#: ../../source/topics/deployment/serving_on_kuscia.rst:201 -#: ../../source/topics/deployment/serving_on_kuscia.rst:203 -#: ../../source/topics/deployment/serving_on_kuscia.rst:205 -#: ../../source/topics/deployment/serving_on_kuscia.rst:207 +#: ../../source/topics/deployment/serving_on_kuscia.rst:185 +#: ../../source/topics/deployment/serving_on_kuscia.rst:190 +#: ../../source/topics/deployment/serving_on_kuscia.rst:198 +#: ../../source/topics/deployment/serving_on_kuscia.rst:206 +#: ../../source/topics/deployment/serving_on_kuscia.rst:216 +#: ../../source/topics/deployment/serving_on_kuscia.rst:218 +#: ../../source/topics/deployment/serving_on_kuscia.rst:220 +#: ../../source/topics/deployment/serving_on_kuscia.rst:222 msgid "Yes" msgstr "是" -#: ../../source/topics/deployment/serving_on_kuscia.rst:162 +#: ../../source/topics/deployment/serving_on_kuscia.rst:175 msgid "PartyConfig.serverConfig" msgstr "" -#: ../../source/topics/deployment/serving_on_kuscia.rst:162 -#: ../../source/topics/deployment/serving_on_kuscia.rst:168 -#: ../../source/topics/deployment/serving_on_kuscia.rst:170 -#: 
../../source/topics/deployment/serving_on_kuscia.rst:172 -#: ../../source/topics/deployment/serving_on_kuscia.rst:174 -#: ../../source/topics/deployment/serving_on_kuscia.rst:176 +#: ../../source/topics/deployment/serving_on_kuscia.rst:175 #: ../../source/topics/deployment/serving_on_kuscia.rst:181 +#: ../../source/topics/deployment/serving_on_kuscia.rst:183 +#: ../../source/topics/deployment/serving_on_kuscia.rst:185 #: ../../source/topics/deployment/serving_on_kuscia.rst:187 -#: ../../source/topics/deployment/serving_on_kuscia.rst:191 -#: ../../source/topics/deployment/serving_on_kuscia.rst:207 +#: ../../source/topics/deployment/serving_on_kuscia.rst:189 +#: ../../source/topics/deployment/serving_on_kuscia.rst:196 +#: ../../source/topics/deployment/serving_on_kuscia.rst:202 +#: ../../source/topics/deployment/serving_on_kuscia.rst:206 +#: ../../source/topics/deployment/serving_on_kuscia.rst:222 msgid "str" msgstr "" -#: ../../source/topics/deployment/serving_on_kuscia.rst:162 +#: ../../source/topics/deployment/serving_on_kuscia.rst:175 msgid ":ref:`ServerConfig `" msgstr "" -#: ../../source/topics/deployment/serving_on_kuscia.rst:164 +#: ../../source/topics/deployment/serving_on_kuscia.rst:177 msgid "PartyConfig.serverConfig.featureMapping" msgstr "" -#: ../../source/topics/deployment/serving_on_kuscia.rst:164 +#: ../../source/topics/deployment/serving_on_kuscia.rst:177 msgid "map" msgstr "" -#: ../../source/topics/deployment/serving_on_kuscia.rst:164 +#: ../../source/topics/deployment/serving_on_kuscia.rst:177 msgid "" "Feature name mapping rules. 
Key: source or predefined feature name; " "Value: model feature name" msgstr "特征名称映射规则。Key: 数据源或请求定义中的特征名称;Value: 模型中使用的特征名称。" -#: ../../source/topics/deployment/serving_on_kuscia.rst:164 -#: ../../source/topics/deployment/serving_on_kuscia.rst:174 -#: ../../source/topics/deployment/serving_on_kuscia.rst:181 +#: ../../source/topics/deployment/serving_on_kuscia.rst:177 #: ../../source/topics/deployment/serving_on_kuscia.rst:187 -#: ../../source/topics/deployment/serving_on_kuscia.rst:193 -#: ../../source/topics/deployment/serving_on_kuscia.rst:195 -#: ../../source/topics/deployment/serving_on_kuscia.rst:197 -#: ../../source/topics/deployment/serving_on_kuscia.rst:209 -#: ../../source/topics/deployment/serving_on_kuscia.rst:211 +#: ../../source/topics/deployment/serving_on_kuscia.rst:196 +#: ../../source/topics/deployment/serving_on_kuscia.rst:202 +#: ../../source/topics/deployment/serving_on_kuscia.rst:208 +#: ../../source/topics/deployment/serving_on_kuscia.rst:210 +#: ../../source/topics/deployment/serving_on_kuscia.rst:212 +#: ../../source/topics/deployment/serving_on_kuscia.rst:224 +#: ../../source/topics/deployment/serving_on_kuscia.rst:226 msgid "No" msgstr "否" -#: ../../source/topics/deployment/serving_on_kuscia.rst:166 +#: ../../source/topics/deployment/serving_on_kuscia.rst:179 msgid "PartyConfig.modelConfig" msgstr "" -#: ../../source/topics/deployment/serving_on_kuscia.rst:166 #: ../../source/topics/deployment/serving_on_kuscia.rst:179 -#: ../../source/topics/deployment/serving_on_kuscia.rst:183 -#: ../../source/topics/deployment/serving_on_kuscia.rst:185 -#: ../../source/topics/deployment/serving_on_kuscia.rst:189 -#: ../../source/topics/deployment/serving_on_kuscia.rst:199 -#: ../../source/topics/deployment/serving_on_kuscia.rst:201 -#: ../../source/topics/deployment/serving_on_kuscia.rst:203 -#: ../../source/topics/deployment/serving_on_kuscia.rst:205 +#: ../../source/topics/deployment/serving_on_kuscia.rst:194 +#: 
../../source/topics/deployment/serving_on_kuscia.rst:198 +#: ../../source/topics/deployment/serving_on_kuscia.rst:200 +#: ../../source/topics/deployment/serving_on_kuscia.rst:204 +#: ../../source/topics/deployment/serving_on_kuscia.rst:214 +#: ../../source/topics/deployment/serving_on_kuscia.rst:216 +#: ../../source/topics/deployment/serving_on_kuscia.rst:218 +#: ../../source/topics/deployment/serving_on_kuscia.rst:220 msgid "Object" msgstr "" -#: ../../source/topics/deployment/serving_on_kuscia.rst:166 +#: ../../source/topics/deployment/serving_on_kuscia.rst:179 msgid ":ref:`ModelConfig `" msgstr "" -#: ../../source/topics/deployment/serving_on_kuscia.rst:168 +#: ../../source/topics/deployment/serving_on_kuscia.rst:181 msgid "PartyConfig.modelConfig.modelId" msgstr "" -#: ../../source/topics/deployment/serving_on_kuscia.rst:168 +#: ../../source/topics/deployment/serving_on_kuscia.rst:181 msgid "Unique id of the model package" msgstr "模型包标识ID" -#: ../../source/topics/deployment/serving_on_kuscia.rst:170 +#: ../../source/topics/deployment/serving_on_kuscia.rst:183 msgid "PartyConfig.modelConfig.basePath" msgstr "" -#: ../../source/topics/deployment/serving_on_kuscia.rst:170 +#: ../../source/topics/deployment/serving_on_kuscia.rst:183 msgid "The local path used to cache and load model package" msgstr "本地缓存路径,用于缓存模型包数据" -#: ../../source/topics/deployment/serving_on_kuscia.rst:172 +#: ../../source/topics/deployment/serving_on_kuscia.rst:185 msgid "PartyConfig.modelConfig.sourcePath" msgstr "" -#: ../../source/topics/deployment/serving_on_kuscia.rst:172 +#: ../../source/topics/deployment/serving_on_kuscia.rst:185 msgid "" "The path to the model package in the data source, where the content " -"format may vary depending on the `sourceType`." -msgstr "模型包路径,其具体内容格式取决于参数`sourceType`" +"format may vary depending on the ``sourceType``." 
+msgstr "模型包路径,其具体内容格式取决于参数 ``sourceType``" -#: ../../source/topics/deployment/serving_on_kuscia.rst:174 -msgid "PartyConfig.modelConfig.source_sha256" +#: ../../source/topics/deployment/serving_on_kuscia.rst:187 +msgid "PartyConfig.modelConfig.sourceSha256" msgstr "" -#: ../../source/topics/deployment/serving_on_kuscia.rst:174 +#: ../../source/topics/deployment/serving_on_kuscia.rst:187 msgid "" "The expected SHA256 hash of the model package. When provided, the fetched" " model package will be verified against it." msgstr "期望的模型包SHA256哈希值。提供时,会被用于校验获取的模型包是否匹配。" -#: ../../source/topics/deployment/serving_on_kuscia.rst:176 +#: ../../source/topics/deployment/serving_on_kuscia.rst:189 msgid "PartyConfig.modelConfig.sourceType" msgstr "" -#: ../../source/topics/deployment/serving_on_kuscia.rst:176 +#: ../../source/topics/deployment/serving_on_kuscia.rst:189 msgid "" -"Model data source type, options include: ST_FILE: In this case, the " -"sourcePath should be a file path accessible to Serving. ST_DP: In this " -"case, the sourcePath should be DomainData ID in DataMesh from Kuscia. and" -" dpSourceMeta needs to be configured." -msgstr "" -"模型数据源类型,可选内容: ST_FILE: 此时`sourcePath`应为文件系统路径。ST_DP: " -"此时`sourcePath`应为Kuscia DataMesh管理的 DomainData ID, 同时需要配置`dpSourceMeta`" - -#: ../../source/topics/deployment/serving_on_kuscia.rst:179 +"Model data source type, options include: ``ST_FILE``: In this case, the " +"``sourcePath`` should be a file path accessible to Serving. ``ST_DP``: In" +" this case, the ``sourcePath`` should be DomainData ID in DataMesh from " +"Kuscia. and dpSourceMeta needs to be configured. ``ST_OSS``: In this " +"case, the ``sourcePath`` should be the file path within the bucket. " +"``ST_HTTP``: In this case, the ``sourcePath`` should be the download URL " +"for the model package." 
+msgstr "" +"模型数据源类型,可选内容: ``ST_FILE``: 此时 ``sourcePath`` 应为文件系统路径。 ``ST_DP``: 此时 " +"``sourcePath`` 应为 Kuscia DataMesh 管理的 DomainData ID, 同时需要配置 " +"``dpSourceMeta``。 ``ST_OSS``: 此时 ``sourcePath`` 应为 Bucket 下的文件路径。 " +"``ST_HTTP``: 此时 ``sourcePath`` 应为模型包的下载链接。" + +#: ../../source/topics/deployment/serving_on_kuscia.rst:194 msgid "PartyConfig.modelConfig.dpSourceMeta" msgstr "" -#: ../../source/topics/deployment/serving_on_kuscia.rst:179 +#: ../../source/topics/deployment/serving_on_kuscia.rst:194 msgid ":ref:`DPSourceMeta `" msgstr "" -#: ../../source/topics/deployment/serving_on_kuscia.rst:179 -msgid "No(If `sourceType` is `DT_DP`, `dpSourceMeta` needs to be configured)" -msgstr "" +#: ../../source/topics/deployment/serving_on_kuscia.rst:194 +msgid "" +"No(If ``sourceType`` is ``DT_DP``, ``dpSourceMeta`` needs to be " +"configured)" +msgstr "否(只当 ``sourceType`` 为 ``ST_DP`` 时, ``dpSourceMeta`` 需要被配置)" -#: ../../source/topics/deployment/serving_on_kuscia.rst:181 +#: ../../source/topics/deployment/serving_on_kuscia.rst:196 msgid "PartyConfig.modelConfig.dpSourceMeta.dmHost" msgstr "" -#: ../../source/topics/deployment/serving_on_kuscia.rst:181 +#: ../../source/topics/deployment/serving_on_kuscia.rst:196 msgid "The address of DataMesh in Kuscia.
Default: datamesh:8071" msgstr "Kusica DataMesh地址, 默认值: datamesh:8071" -#: ../../source/topics/deployment/serving_on_kuscia.rst:183 +#: ../../source/topics/deployment/serving_on_kuscia.rst:198 msgid "PartyConfig.featureSourceConfig" msgstr "" -#: ../../source/topics/deployment/serving_on_kuscia.rst:183 +#: ../../source/topics/deployment/serving_on_kuscia.rst:198 msgid ":ref:`FeatureSourceConfig `" msgstr "" -#: ../../source/topics/deployment/serving_on_kuscia.rst:185 +#: ../../source/topics/deployment/serving_on_kuscia.rst:200 msgid "PartyConfig.featureSourceConfig.mockOpts" msgstr "" -#: ../../source/topics/deployment/serving_on_kuscia.rst:185 +#: ../../source/topics/deployment/serving_on_kuscia.rst:200 msgid ":ref:`MockOptions `" msgstr "" -#: ../../source/topics/deployment/serving_on_kuscia.rst:185 -#: ../../source/topics/deployment/serving_on_kuscia.rst:189 -#: ../../source/topics/deployment/serving_on_kuscia.rst:199 -msgid "No(One of `csvOpts`, `mockOpts`, or `httpOpts` needs to be configured)" -msgstr "否(`csvOpts`、`mockOpts`、`httpOpts`中的之一需要被配置)" +#: ../../source/topics/deployment/serving_on_kuscia.rst:200 +#: ../../source/topics/deployment/serving_on_kuscia.rst:204 +#: ../../source/topics/deployment/serving_on_kuscia.rst:214 +msgid "" +"No(One of ``csvOpts``, ``mockOpts``, or ``httpOpts`` needs to be " +"configured)" +msgstr "否(``csvOpts``, ``mockOpts`` 或者 ``httpOpts`` 中的之一需要被配置)" -#: ../../source/topics/deployment/serving_on_kuscia.rst:187 +#: ../../source/topics/deployment/serving_on_kuscia.rst:202 msgid "PartyConfig.featureSourceConfig.mockOpts.type" msgstr "" -#: ../../source/topics/deployment/serving_on_kuscia.rst:187 +#: ../../source/topics/deployment/serving_on_kuscia.rst:202 msgid "" "The method for generating mock feature values, options: \"MDT_RANDOM\" " "for random values, and \"MDT_FIXED\" for fixed values. Default: " "\"MDT_FIXED\"." 
msgstr "mock特征数据生成方法类型,可选:\"MDT_RANDOM\"用于生成随机值;\"MDT_FIXED\"返回固定值。默认设置为:\"MDT_FIXED\"。" -#: ../../source/topics/deployment/serving_on_kuscia.rst:189 +#: ../../source/topics/deployment/serving_on_kuscia.rst:204 msgid "PartyConfig.featureSourceConfig.httpOpts" msgstr "" -#: ../../source/topics/deployment/serving_on_kuscia.rst:189 +#: ../../source/topics/deployment/serving_on_kuscia.rst:204 msgid ":ref:`HttpOptions `" msgstr "" -#: ../../source/topics/deployment/serving_on_kuscia.rst:191 +#: ../../source/topics/deployment/serving_on_kuscia.rst:206 msgid "PartyConfig.featureSourceConfig.httpOpts.endpoint" msgstr "" -#: ../../source/topics/deployment/serving_on_kuscia.rst:191 +#: ../../source/topics/deployment/serving_on_kuscia.rst:206 msgid "Feature service address" -msgstr "" +msgstr "特征服务地址" -#: ../../source/topics/deployment/serving_on_kuscia.rst:193 +#: ../../source/topics/deployment/serving_on_kuscia.rst:208 msgid "PartyConfig.featureSourceConfig.httpOpts.enableLb" msgstr "" -#: ../../source/topics/deployment/serving_on_kuscia.rst:193 +#: ../../source/topics/deployment/serving_on_kuscia.rst:208 msgid "bool" msgstr "" -#: ../../source/topics/deployment/serving_on_kuscia.rst:193 +#: ../../source/topics/deployment/serving_on_kuscia.rst:208 msgid "Whether to enable round robin load balancer, Default: False" msgstr "是否开启轮询负载均衡访问,默认值:False。" -#: ../../source/topics/deployment/serving_on_kuscia.rst:195 +#: ../../source/topics/deployment/serving_on_kuscia.rst:210 msgid "PartyConfig.featureSourceConfig.httpOpts.connectTimeoutMs" msgstr "" -#: ../../source/topics/deployment/serving_on_kuscia.rst:195 -#: ../../source/topics/deployment/serving_on_kuscia.rst:197 -#: ../../source/topics/deployment/serving_on_kuscia.rst:209 -#: ../../source/topics/deployment/serving_on_kuscia.rst:211 +#: ../../source/topics/deployment/serving_on_kuscia.rst:210 +#: ../../source/topics/deployment/serving_on_kuscia.rst:212 +#: ../../source/topics/deployment/serving_on_kuscia.rst:224 +#: 
../../source/topics/deployment/serving_on_kuscia.rst:226 msgid "int32" msgstr "" -#: ../../source/topics/deployment/serving_on_kuscia.rst:195 -#: ../../source/topics/deployment/serving_on_kuscia.rst:211 +#: ../../source/topics/deployment/serving_on_kuscia.rst:210 +#: ../../source/topics/deployment/serving_on_kuscia.rst:226 msgid "Max duration for a connect. -1 means wait indefinitely. Default: 500 (ms)" msgstr "连接超时时间,-1 即无限时间,默认值:500 (ms)" -#: ../../source/topics/deployment/serving_on_kuscia.rst:197 +#: ../../source/topics/deployment/serving_on_kuscia.rst:212 msgid "PartyConfig.featureSourceConfig.httpOpts.timeoutMs" msgstr "" -#: ../../source/topics/deployment/serving_on_kuscia.rst:197 +#: ../../source/topics/deployment/serving_on_kuscia.rst:212 msgid "" "Max duration of http request. -1 means wait indefinitely. Default: 1000 " "(ms)" msgstr "请求超时时间,-1 即无限时间,默认值:1000 (ms)" -#: ../../source/topics/deployment/serving_on_kuscia.rst:199 +#: ../../source/topics/deployment/serving_on_kuscia.rst:214 msgid "PartyConfig.featureSourceConfig.csvOpts" msgstr "" -#: ../../source/topics/deployment/serving_on_kuscia.rst:199 +#: ../../source/topics/deployment/serving_on_kuscia.rst:214 msgid ":ref:`CsvOptions `" msgstr "" -#: ../../source/topics/deployment/serving_on_kuscia.rst:201 -msgid "PartyConfig.featureSourceConfig.csvOpts.file_path" +#: ../../source/topics/deployment/serving_on_kuscia.rst:216 +msgid "PartyConfig.featureSourceConfig.csvOpts.filePath" msgstr "" -#: ../../source/topics/deployment/serving_on_kuscia.rst:201 +#: ../../source/topics/deployment/serving_on_kuscia.rst:216 msgid "" "Input file path, specifies where to load data. 
Note that this will load " "all of the data into memory at once" msgstr "文件路径,注意:整个文件会被全部加载到内存中,不建议生产系统使用。" -#: ../../source/topics/deployment/serving_on_kuscia.rst:203 -msgid "PartyConfig.featureSourceConfig.csvOpts.id_name" +#: ../../source/topics/deployment/serving_on_kuscia.rst:218 +msgid "PartyConfig.featureSourceConfig.csvOpts.idName" msgstr "" -#: ../../source/topics/deployment/serving_on_kuscia.rst:203 +#: ../../source/topics/deployment/serving_on_kuscia.rst:218 msgid "" -"Id column name, associated with `FeatureParam::query_datas`. " -"`query_datas` is a subset of id column" -msgstr "ID列名称,系统此时认为预测请求中`FeatureParam::query_datas`的内容为ID数据,将以此进行数据查询。" +"Id column name, associated with ``FeatureParam::query_datas``. " +"``query_datas`` is a subset of id column" +msgstr "ID列名称,系统此时认为预测请求中 ``FeatureParam::query_datas`` 的内容为ID数据,将以此进行数据查询。" -#: ../../source/topics/deployment/serving_on_kuscia.rst:205 +#: ../../source/topics/deployment/serving_on_kuscia.rst:220 msgid "PartyConfig.channelDesc" msgstr "" -#: ../../source/topics/deployment/serving_on_kuscia.rst:205 +#: ../../source/topics/deployment/serving_on_kuscia.rst:220 msgid ":ref:`ChannelDesc `" msgstr "" -#: ../../source/topics/deployment/serving_on_kuscia.rst:207 +#: ../../source/topics/deployment/serving_on_kuscia.rst:222 msgid "PartyConfig.channelDesc.protocol" msgstr "" -#: ../../source/topics/deployment/serving_on_kuscia.rst:207 +#: ../../source/topics/deployment/serving_on_kuscia.rst:222 msgid "" "Communication protocol, for optional value, see `here " "`_" @@ -479,14 +571,14 @@ msgstr "" "通信协议,可选值可以参考 `这里 " "`_" -#: ../../source/topics/deployment/serving_on_kuscia.rst:209 +#: ../../source/topics/deployment/serving_on_kuscia.rst:224 msgid "PartyConfig.channelDesc.rpcTimeoutMs" msgstr "" -#: ../../source/topics/deployment/serving_on_kuscia.rst:209 +#: ../../source/topics/deployment/serving_on_kuscia.rst:224 msgid "Max duration of RPC. -1 means wait indefinitely. 
Default: 2000 (ms)" msgstr "RPC超时时间, -1 即无穷时间。默认值:2000(ms)" -#: ../../source/topics/deployment/serving_on_kuscia.rst:211 +#: ../../source/topics/deployment/serving_on_kuscia.rst:226 msgid "PartyConfig.channelDesc.connectTimeoutMs" msgstr "" diff --git a/docs/locales/zh_CN/LC_MESSAGES/topics/graph/intro_to_graph.po b/docs/locales/zh_CN/LC_MESSAGES/topics/graph/intro_to_graph.po index a3e5287..f196f97 100644 --- a/docs/locales/zh_CN/LC_MESSAGES/topics/graph/intro_to_graph.po +++ b/docs/locales/zh_CN/LC_MESSAGES/topics/graph/intro_to_graph.po @@ -8,14 +8,14 @@ msgid "" msgstr "" "Project-Id-Version: SecretFlow-Serving \n" "Report-Msgid-Bugs-To: \n" -"POT-Creation-Date: 2024-06-28 09:40+0000\n" +"POT-Creation-Date: 2024-08-14 20:59+0800\n" "PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" "Last-Translator: FULL NAME \n" "Language-Team: LANGUAGE \n" "MIME-Version: 1.0\n" "Content-Type: text/plain; charset=utf-8\n" "Content-Transfer-Encoding: 8bit\n" -"Generated-By: Babel 2.14.0\n" +"Generated-By: Babel 2.15.0\n" #: ../../source/topics/graph/intro_to_graph.rst:4 msgid "Introduction to Graph" @@ -66,8 +66,8 @@ msgid "tag: Some properties of the operator." msgstr "tag: 算子的特定属性描述。" #: ../../source/topics/graph/intro_to_graph.rst:22 -msgid "attributes: Please check `Attributes` part below." -msgstr "attributes: 请参考下面 `属性` 的内容" +msgid "attributes: Please check ``Attributes`` part below." +msgstr "attributes: 请参考下面 ``属性`` 的内容" #: ../../source/topics/graph/intro_to_graph.rst:23 msgid "inputs and output: The info of the inputs or output of the operator." @@ -102,15 +102,15 @@ msgstr "type: 请参考 :doc:`AttrType ` 获取详细信息。 #: ../../source/topics/graph/intro_to_graph.rst:35 msgid "" -"is_optional: If True, when AttrValue is not provided, `default_value` " +"is_optional: If True, when AttrValue is not provided, ``default_value`` " "would be used. Else, AttrValue must be provided." 
msgstr "" -"is_optional: 为 True 时,如果对应的 AttrValue 没有设置,`default_value` 的值将会被使用。否则,对应的" +"is_optional: 为 True 时,如果对应的 AttrValue 没有设置,``default_value`` 的值将会被使用。否则,对应的" " AttrValue 必需提供。" #: ../../source/topics/graph/intro_to_graph.rst:36 -msgid "default_value: Please check :ref:`AttrValue `." -msgstr "default_value: 请参考 :doc:`AttrValue ` 获取详细信息。" +msgid "default_value: Please check :ref:`AttrValue `." +msgstr "default_value: 请参考 :doc:`AttrValue ` 获取详细信息。" #: ../../source/topics/graph/intro_to_graph.rst:39 msgid "Nodes" @@ -119,8 +119,8 @@ msgstr "节点" #: ../../source/topics/graph/intro_to_graph.rst:40 msgid "" "Nodes are instances of operators. They store the attribute values " -"(`AttrValue`) of the operators." -msgstr "节点是算子的实例。节点内包含算子的属性对应的属性值(`AttrValue`)" +"(``AttrValue``) of the operators." +msgstr "节点是算子的实例。节点内包含算子的属性对应的属性值(``AttrValue``)" #: ../../source/topics/graph/intro_to_graph.rst:43 msgid "NodeDef" @@ -181,8 +181,8 @@ msgid "node_list: The node list of the graph." msgstr "node_list: 图拥有的节点列表。" #: ../../source/topics/graph/intro_to_graph.rst:62 -msgid "execution_list: Please check `Executions` part below." -msgstr "execution_list: 请参考下面 `执行体` 的部分。" +msgid "execution_list: Please check ``Executions`` part below." +msgstr "execution_list: 请参考下面 ``执行体`` 的部分。" #: ../../source/topics/graph/intro_to_graph.rst:65 msgid "Executions" @@ -214,7 +214,7 @@ msgstr "nodes: 执行体中包含的节点列表。注意,这些节点应该 msgid "" "config: The runtime config of the execution. It describes the scheduling " "logic and session-related states of this execution unit. for more " -"details, please check :ref:`RuntimeConfig `." +"details, please check :ref:`RuntimeConfig `." msgstr "" "config: 执行体的运行配置。其描述执行体的调度逻辑以及会话状态。请查看 :ref:`RuntimeConfig " "` 获取更多信息。" @@ -232,7 +232,7 @@ msgid "" "lib/>`_ is a python library that provides interfaces to obtain " "Secretflow-Serving operators and export model files that Secretflow-" "Serving can load. 
For more details, please check :doc:`secretflow-" -"serving-lib docs `." +"serving-lib docs `." msgstr "" "因为Secretflow-Serving需要加载 `Secretflow " "`_ 训练后导出的模型,Secretflow-" diff --git a/docs/locales/zh_CN/LC_MESSAGES/topics/graph/operator_list.po b/docs/locales/zh_CN/LC_MESSAGES/topics/graph/operator_list.po index 572373e..72b5046 100644 --- a/docs/locales/zh_CN/LC_MESSAGES/topics/graph/operator_list.po +++ b/docs/locales/zh_CN/LC_MESSAGES/topics/graph/operator_list.po @@ -8,7 +8,7 @@ msgid "" msgstr "" "Project-Id-Version: SecretFlow-Serving \n" "Report-Msgid-Bugs-To: \n" -"POT-Creation-Date: 2024-05-29 20:16+0800\n" +"POT-Creation-Date: 2024-08-12 14:42+0800\n" "PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" "Last-Translator: FULL NAME \n" "Language-Team: LANGUAGE \n" @@ -22,27 +22,39 @@ msgid "SecretFlow-Serving Operator List" msgstr "SecretFlow-Serving 算子列表" #: ../../source/topics/graph/operator_list.md:9 -msgid "Last update: Wed May 29 20:14:58 2024" +msgid "Last update: Mon Aug 12 14:41:29 2024" msgstr "" #: ../../source/topics/graph/operator_list.md:10 -msgid "MERGE_Y" +msgid "PHE_2P_REDUCE" msgstr "" #: ../../source/topics/graph/operator_list.md:13 -msgid "Operator version: 0.0.3" +#: ../../source/topics/graph/operator_list.md:49 +#: ../../source/topics/graph/operator_list.md:89 +#: ../../source/topics/graph/operator_list.md:187 +#: ../../source/topics/graph/operator_list.md:220 +#: ../../source/topics/graph/operator_list.md:257 +#: ../../source/topics/graph/operator_list.md:293 +msgid "Operator version: 0.0.1" msgstr "" #: ../../source/topics/graph/operator_list.md:15 -msgid "Merge all partial y(score) and apply link function" +msgid "" +"Two-party computation operator. Select data encrypted by either our side " +"or the peer party according to the configuration." 
msgstr "" #: ../../source/topics/graph/operator_list.md:16 -#: ../../source/topics/graph/operator_list.md:55 -#: ../../source/topics/graph/operator_list.md:86 -#: ../../source/topics/graph/operator_list.md:123 +#: ../../source/topics/graph/operator_list.md:52 +#: ../../source/topics/graph/operator_list.md:92 +#: ../../source/topics/graph/operator_list.md:120 #: ../../source/topics/graph/operator_list.md:159 -#: ../../source/topics/graph/operator_list.md:195 +#: ../../source/topics/graph/operator_list.md:190 +#: ../../source/topics/graph/operator_list.md:223 +#: ../../source/topics/graph/operator_list.md:260 +#: ../../source/topics/graph/operator_list.md:296 +#: ../../source/topics/graph/operator_list.md:332 msgid "Attrs" msgstr "" @@ -67,41 +79,135 @@ msgid "Notes" msgstr "" #: ../../source/topics/graph/operator_list.md -msgid "exp_iters" +msgid "select_crypted_for_peer" msgstr "" #: ../../source/topics/graph/operator_list.md msgid "" -"Number of iterations of `exp` approximation, valid when `link_function` " -"set `LF_EXP_TAYLOR`" +"If `True`, select the data can be decrypted by peer, including self " +"calculated partial_y and peer's rand, otherwise select selfs." msgstr "" #: ../../source/topics/graph/operator_list.md -msgid "Integer32" +msgid "Boolean" msgstr "" #: ../../source/topics/graph/operator_list.md -msgid "N" +msgid "Y" msgstr "" #: ../../source/topics/graph/operator_list.md -msgid "Default: 0." 
+msgid "rand_number_col_name" msgstr "" #: ../../source/topics/graph/operator_list.md -msgid "output_col_name" +msgid "The name of the rand number column in the input and output" msgstr "" #: ../../source/topics/graph/operator_list.md -msgid "The column name of merged score" +msgid "String" msgstr "" #: ../../source/topics/graph/operator_list.md -msgid "String" +msgid "partial_y_col_name" msgstr "" #: ../../source/topics/graph/operator_list.md -msgid "Y" +msgid "The name of the partial_y column in the input and output" +msgstr "" + +#: ../../source/topics/graph/operator_list.md:25 +#: ../../source/topics/graph/operator_list.md:64 +#: ../../source/topics/graph/operator_list.md:131 +#: ../../source/topics/graph/operator_list.md:233 +#: ../../source/topics/graph/operator_list.md:305 +#: ../../source/topics/graph/operator_list.md:343 +msgid "Tags" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "mergeable" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "" +"The operator accept the output of operators with different participants " +"and will somehow merge them." 
+msgstr "" + +#: ../../source/topics/graph/operator_list.md:32 +#: ../../source/topics/graph/operator_list.md:71 +#: ../../source/topics/graph/operator_list.md:100 +#: ../../source/topics/graph/operator_list.md:139 +#: ../../source/topics/graph/operator_list.md:170 +#: ../../source/topics/graph/operator_list.md:203 +#: ../../source/topics/graph/operator_list.md:240 +#: ../../source/topics/graph/operator_list.md:276 +#: ../../source/topics/graph/operator_list.md:312 +#: ../../source/topics/graph/operator_list.md:350 +msgid "Inputs" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "compute results" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "The compute results from both self and peer's" +msgstr "" + +#: ../../source/topics/graph/operator_list.md:39 +#: ../../source/topics/graph/operator_list.md:79 +#: ../../source/topics/graph/operator_list.md:107 +#: ../../source/topics/graph/operator_list.md:146 +#: ../../source/topics/graph/operator_list.md:177 +#: ../../source/topics/graph/operator_list.md:210 +#: ../../source/topics/graph/operator_list.md:247 +#: ../../source/topics/graph/operator_list.md:283 +#: ../../source/topics/graph/operator_list.md:319 +#: ../../source/topics/graph/operator_list.md:357 +msgid "Output" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "selected results" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "The selected data" +msgstr "" + +#: ../../source/topics/graph/operator_list.md:46 +msgid "PHE_2P_MERGE_Y" +msgstr "" + +#: ../../source/topics/graph/operator_list.md:51 +msgid "" +"Two-party computation operator. Merge the obfuscated partial_y decrypted " +"by the peer party with the partial_y based on self own key to obtain the " +"final prediction score." 
+msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "exp_iters" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "" +"Number of iterations of `exp` approximation, valid when `link_function` " +"set `LF_EXP_TAYLOR`" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "Integer32" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "N" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "Default: 0." msgstr "" #: ../../source/topics/graph/operator_list.md @@ -119,14 +225,6 @@ msgid "" "LF_SIGMOID_SR, LF_SIGMOID_SEGLS" msgstr "" -#: ../../source/topics/graph/operator_list.md -msgid "input_col_name" -msgstr "" - -#: ../../source/topics/graph/operator_list.md -msgid "The column name of partial_y" -msgstr "" - #: ../../source/topics/graph/operator_list.md msgid "yhat_scale" msgstr "" @@ -146,11 +244,28 @@ msgstr "" msgid "Default: 1.0." msgstr "" -#: ../../source/topics/graph/operator_list.md:27 -#: ../../source/topics/graph/operator_list.md:96 -#: ../../source/topics/graph/operator_list.md:168 -#: ../../source/topics/graph/operator_list.md:206 -msgid "Tags" +#: ../../source/topics/graph/operator_list.md +msgid "score_col_name" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "The name of the score column in the output" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "crypted_y_col_name" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "The name of the crypted partial_y column in the second input" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "decrypted_y_col_name" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "The name of the decrypted partial_y column in the first input" msgstr "" #: ../../source/topics/graph/operator_list.md @@ -162,39 +277,95 @@ msgid "The operator's output can be the final result" msgstr "" #: ../../source/topics/graph/operator_list.md -msgid "mergeable" +msgid "crypted_data" 
+msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "The crypted data selected by `PHE_2P_REDUCE`" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "decrypted_data" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "The decrypted data output by `PHE_2P_DECRYPT_PEER_Y`" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "score" msgstr "" #: ../../source/topics/graph/operator_list.md +msgid "The final linear predict score." +msgstr "" + +#: ../../source/topics/graph/operator_list.md:86 +msgid "PHE_2P_DECRYPT_PEER_Y" +msgstr "" + +#: ../../source/topics/graph/operator_list.md:91 msgid "" -"The operator accept the output of operators with different participants " -"and will somehow merge them." +"Two-party computation operator. Decrypt the obfuscated partial_y and add " +"a random number." msgstr "" -#: ../../source/topics/graph/operator_list.md:35 -#: ../../source/topics/graph/operator_list.md:66 -#: ../../source/topics/graph/operator_list.md:103 -#: ../../source/topics/graph/operator_list.md:139 -#: ../../source/topics/graph/operator_list.md:175 -#: ../../source/topics/graph/operator_list.md:213 -msgid "Inputs" +#: ../../source/topics/graph/operator_list.md +msgid "decrypted_col_name" msgstr "" #: ../../source/topics/graph/operator_list.md -msgid "partial_ys" +msgid "The name of the decrypted result column in the output" msgstr "" #: ../../source/topics/graph/operator_list.md -msgid "The list of partial y, data type: `double`" +msgid "" +"The name of the partial_y(which can be decrypt by self) column in the " +"input" msgstr "" -#: ../../source/topics/graph/operator_list.md:42 -#: ../../source/topics/graph/operator_list.md:73 -#: ../../source/topics/graph/operator_list.md:110 -#: ../../source/topics/graph/operator_list.md:146 -#: ../../source/topics/graph/operator_list.md:182 -#: ../../source/topics/graph/operator_list.md:220 -msgid "Output" +#: ../../source/topics/graph/operator_list.md +msgid "Input 
feature table" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "Decrypted partial_y with the added random number." +msgstr "" + +#: ../../source/topics/graph/operator_list.md:114 +msgid "MERGE_Y" +msgstr "" + +#: ../../source/topics/graph/operator_list.md:117 +msgid "Operator version: 0.0.3" +msgstr "" + +#: ../../source/topics/graph/operator_list.md:119 +msgid "Merge all partial y(score) and apply link function" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "output_col_name" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "The column name of merged score" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "input_col_name" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "The column name of partial_y" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "partial_ys" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "The list of partial y, data type: `double`" msgstr "" #: ../../source/topics/graph/operator_list.md @@ -205,16 +376,16 @@ msgstr "" msgid "The merge result of `partial_ys`, data type: `double`" msgstr "" -#: ../../source/topics/graph/operator_list.md:49 +#: ../../source/topics/graph/operator_list.md:153 msgid "DOT_PRODUCT" msgstr "" -#: ../../source/topics/graph/operator_list.md:52 -#: ../../source/topics/graph/operator_list.md:192 +#: ../../source/topics/graph/operator_list.md:156 +#: ../../source/topics/graph/operator_list.md:329 msgid "Operator version: 0.0.2" msgstr "" -#: ../../source/topics/graph/operator_list.md:54 +#: ../../source/topics/graph/operator_list.md:158 msgid "Calculate the dot product of feature weights and values" msgstr "" @@ -274,55 +445,126 @@ msgid "features" msgstr "" #: ../../source/topics/graph/operator_list.md -msgid "Input feature table" +msgid "The calculation results, they have a data type of `double`." 
+msgstr "" + +#: ../../source/topics/graph/operator_list.md:184 +msgid "PHE_2P_DOT_PRODUCT" +msgstr "" + +#: ../../source/topics/graph/operator_list.md:189 +msgid "" +"Two-party computation operator. Load the encrypted feature weights, " +"compute their dot product with the feature values, and add random noise " +"to the result for obfuscation. Only supports computation between two " +"parties, with the weights being encrypted using the other party's key." msgstr "" #: ../../source/topics/graph/operator_list.md -msgid "The calculation results, they have a data type of `double`." +msgid "result_col_name" msgstr "" -#: ../../source/topics/graph/operator_list.md:80 -msgid "ARROW_PROCESSING" +#: ../../source/topics/graph/operator_list.md +msgid "The name of the calculation result(partial_y) column in the output" msgstr "" -#: ../../source/topics/graph/operator_list.md:83 -#: ../../source/topics/graph/operator_list.md:120 -#: ../../source/topics/graph/operator_list.md:156 -msgid "Operator version: 0.0.1" +#: ../../source/topics/graph/operator_list.md +msgid "offset_col_name" msgstr "" -#: ../../source/topics/graph/operator_list.md:85 -msgid "Replay secretflow compute functions" +#: ../../source/topics/graph/operator_list.md +msgid "The name of the offset column(feature) in the input" msgstr "" #: ../../source/topics/graph/operator_list.md -msgid "content_json_flag" +msgid "Default: ." msgstr "" #: ../../source/topics/graph/operator_list.md -msgid "Whether `trace_content` is serialized json" +msgid "The name of the generated rand number column in the output" msgstr "" #: ../../source/topics/graph/operator_list.md -msgid "Boolean" +msgid "feature_types" msgstr "" #: ../../source/topics/graph/operator_list.md -msgid "Default: False." +msgid "" +"List of input feature data types. 
Optional value: DT_UINT8, DT_INT8, " +"DT_UINT16, DT_INT16, DT_UINT32, DT_INT32, DT_UINT64, DT_INT64, DT_FLOAT, " +"DT_DOUBLE" msgstr "" #: ../../source/topics/graph/operator_list.md -msgid "trace_content" +msgid "Default: []." msgstr "" #: ../../source/topics/graph/operator_list.md -msgid "Serialized data of secretflow compute trace" +msgid "feature_weights_ciphertext" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "feature weight ciphertext matrix bytes" msgstr "" #: ../../source/topics/graph/operator_list.md msgid "Bytes" msgstr "" +#: ../../source/topics/graph/operator_list.md +msgid "intercept_ciphertext" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "Intercept ciphertext bytes or matrix bytes" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "" +"List of feature names. Note that if there is an offset column, it needs " +"to be the last one in the list" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "Input features" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "partial_y" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "Calculation results" +msgstr "" + +#: ../../source/topics/graph/operator_list.md:217 +msgid "ARROW_PROCESSING" +msgstr "" + +#: ../../source/topics/graph/operator_list.md:222 +msgid "Replay secretflow compute functions" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "content_json_flag" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "Whether `trace_content` is serialized json" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "Default: False." 
+msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "trace_content" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "Serialized data of secretflow compute trace" +msgstr "" + #: ../../source/topics/graph/operator_list.md msgid "output_schema_bytes" msgstr "" @@ -347,11 +589,11 @@ msgstr "" msgid "output" msgstr "" -#: ../../source/topics/graph/operator_list.md:117 +#: ../../source/topics/graph/operator_list.md:254 msgid "TREE_SELECT" msgstr "" -#: ../../source/topics/graph/operator_list.md:122 +#: ../../source/topics/graph/operator_list.md:259 msgid "" "Obtaining the local prediction path information of the decision tree " "using input features." @@ -423,13 +665,6 @@ msgstr "" msgid "input_feature_types" msgstr "" -#: ../../source/topics/graph/operator_list.md -msgid "" -"List of input feature data types. Optional value: DT_UINT8, DT_INT8, " -"DT_UINT16, DT_INT16, DT_UINT32, DT_INT32, DT_UINT64, DT_INT64, DT_FLOAT, " -"DT_DOUBLE" -msgstr "" - #: ../../source/topics/graph/operator_list.md msgid "node_ids" msgstr "" @@ -450,11 +685,11 @@ msgstr "" msgid "The local prediction path information of the decision tree." msgstr "" -#: ../../source/topics/graph/operator_list.md:153 +#: ../../source/topics/graph/operator_list.md:290 msgid "TREE_MERGE" msgstr "" -#: ../../source/topics/graph/operator_list.md:158 +#: ../../source/topics/graph/operator_list.md:295 msgid "" "Merge the `TREE_SELECT` output from multiple parties to obtain a unique " "prediction path and return the result weights." @@ -470,10 +705,6 @@ msgid "" "attr can be omitted." msgstr "" -#: ../../source/topics/graph/operator_list.md -msgid "Default: []." 
-msgstr "" - #: ../../source/topics/graph/operator_list.md msgid "The column name of tree predict score" msgstr "" @@ -490,19 +721,15 @@ msgstr "" msgid "Input tree selects" msgstr "" -#: ../../source/topics/graph/operator_list.md -msgid "score" -msgstr "" - #: ../../source/topics/graph/operator_list.md msgid "The prediction result of tree." msgstr "" -#: ../../source/topics/graph/operator_list.md:189 +#: ../../source/topics/graph/operator_list.md:326 msgid "TREE_ENSEMBLE_PREDICT" msgstr "" -#: ../../source/topics/graph/operator_list.md:194 +#: ../../source/topics/graph/operator_list.md:331 msgid "" "Accept the weighted results from multiple trees (`TREE_SELECT` + " "`TREE_MERGE`), merge them, and obtain the final prediction result of the " diff --git a/docs/locales/zh_CN/LC_MESSAGES/topics/system/feature_service.po b/docs/locales/zh_CN/LC_MESSAGES/topics/system/feature_service.po index f00ece6..4e6cbde 100644 --- a/docs/locales/zh_CN/LC_MESSAGES/topics/system/feature_service.po +++ b/docs/locales/zh_CN/LC_MESSAGES/topics/system/feature_service.po @@ -8,14 +8,14 @@ msgid "" msgstr "" "Project-Id-Version: SecretFlow-Serving \n" "Report-Msgid-Bugs-To: \n" -"POT-Creation-Date: 2024-04-16 20:33+0800\n" +"POT-Creation-Date: 2024-08-14 20:59+0800\n" "PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" "Last-Translator: FULL NAME \n" "Language-Team: LANGUAGE \n" "MIME-Version: 1.0\n" "Content-Type: text/plain; charset=utf-8\n" "Content-Transfer-Encoding: 8bit\n" -"Generated-By: Babel 2.14.0\n" +"Generated-By: Babel 2.15.0\n" #: ../../source/topics/system/feature_service.rst:3 msgid "Feature Service" @@ -271,7 +271,7 @@ msgid "enum(ErrorCode)" msgstr "" #: ../../source/topics/system/feature_service.rst:128 -msgid "Value reference `ErrorCode`" +msgid "Value reference ``ErrorCode``" msgstr "" #: ../../source/topics/system/feature_service.rst:130 @@ -346,7 +346,7 @@ msgstr "" #: ../../source/topics/system/feature_service.rst:148 #: ../../source/topics/system/feature_service.rst:150 #: 
../../source/topics/system/feature_service.rst:152 -msgid "No(assign the corresponding value list based on `field.type`)" +msgid "No(assign the corresponding value list based on ``field.type``)" msgstr "" #: ../../source/topics/system/feature_service.rst:144 diff --git a/docs/source/reference/model.md b/docs/source/reference/model.md index da712c1..77099f8 100644 --- a/docs/source/reference/model.md +++ b/docs/source/reference/model.md @@ -54,6 +54,8 @@ - [ExecutionDef](#executiondef) - [GraphDef](#graphdef) - [GraphView](#graphview) + - [HeConfig](#heconfig) + - [HeInfo](#heinfo) - [NodeDef](#nodedef) - [NodeDef.AttrValuesEntry](#nodedef-attrvaluesentry) - [NodeView](#nodeview) @@ -323,7 +325,7 @@ Representation operator property | ----- | ---- | ----------- | | returnable | [ bool](#bool ) | The operator's output can be the final result | | mergeable | [ bool](#bool ) | The operator accept the output of operators with different participants and will somehow merge them. | -| session_run | [ bool](#bool ) | The operator needs to be executed in session. | +| session_run | [ bool](#bool ) | The operator needs to be executed in session. TODO: not supported yet. | | variable_inputs | [ bool](#bool ) | Whether this op has variable input argument. default `false`. | @@ -356,6 +358,8 @@ and a set of executions that describes the scheduling of the graph. | version | [ string](#string ) | Version of the graph | | node_list | [repeated NodeDef](#nodedef ) | none | | execution_list | [repeated ExecutionDef](#executiondef ) | none | +| he_config | [ HeConfig](#heconfig ) | none | +| party_num | [ int32](#int32 ) | none | @@ -371,6 +375,35 @@ only structural information and excluding the data components. 
| version | [ string](#string ) | Version of the graph | | node_list | [repeated NodeView](#nodeview ) | none | | execution_list | [repeated ExecutionDef](#executiondef ) | none | +| he_info | [ HeInfo](#heinfo ) | none | +| party_num | [ int32](#int32 ) | none | + + + + + +### HeConfig +The config for HE compute. + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| pk_buf | [ bytes](#bytes ) | Serialized public key bytes | +| sk_buf | [ bytes](#bytes ) | Serialized secret key bytes | +| encode_scale | [ int64](#int64 ) | Encode scale for data | + + + + + +### HeInfo +The public info for HE compute. + + +| Field | Type | Description | +| ----- | ---- | ----------- | +| pk_buf | [ bytes](#bytes ) | Serialized public key bytes | +| encode_scale | [ int64](#int64 ) | Encode scale for data | @@ -684,6 +717,8 @@ Supported dispatch type | DP_ALL | 1 | Dispatch all participants. | | DP_ANYONE | 2 | Dispatch any participant. | | DP_SPECIFIED | 3 | Dispatch specified participant. | +| DP_SELF | 4 | Dispatch self. | +| DP_PEER | 12 | For 2-parties, Dispatch peer participant. 
| diff --git a/docs/source/topics/deployment/serving_on_kuscia.rst b/docs/source/topics/deployment/serving_on_kuscia.rst index 2b0b7ba..5bf9511 100644 --- a/docs/source/topics/deployment/serving_on_kuscia.rst +++ b/docs/source/topics/deployment/serving_on_kuscia.rst @@ -22,8 +22,15 @@ To deploy SecretFlow-Serving in Kusica, you first need to register the template "serving_id": "{{.SERVING_ID}}", "input_config": "{{.INPUT_CONFIG}}", "cluster_def": "{{.CLUSTER_DEFINE}}", - "allocated_ports": "{{.ALLOCATED_PORTS}}" + "allocated_ports": "{{.ALLOCATED_PORTS}}", + "oss_meta": "{{.SERVING_OSS_META}}", + "spi_tls_config": "{{.SERVING_SPI_TLS_CONFIG}}", + "http_source_meta": "{{.SERVING_HTTP_SOURCE_META}}" } + logging-config.conf: | + {{.SERVING_LOGGING_CONFIG}} + trace-config.conf: | + {{.SERVING_TRACE_CONFIG}} deployTemplates: - name: secretflow replicas: 1 @@ -32,10 +39,14 @@ To deploy SecretFlow-Serving in Kusica, you first need to register the template - command: - sh - -c - - ./secretflow_serving --flagfile=conf/gflags.conf --config_mode=kuscia --serving_config_file=/etc/kuscia/serving-config.conf + - ./secretflow_serving --flagfile=conf/gflags.conf --config_mode=kuscia --serving_config_file=/etc/kuscia/serving-config.conf --logging_config_file=/etc/kuscia/logging-config.conf --trace_config_file=/etc/kuscia/trace-config.conf configVolumeMounts: - mountPath: /etc/kuscia/serving-config.conf subPath: serving-config.conf + - mountPath: /etc/kuscia/logging-config.conf + subPath: logging-config.conf + - mountPath: /etc/kuscia/trace-config.conf + subPath: trace-config.conf name: secretflow ports: - name: service @@ -75,17 +86,23 @@ The explanation of the common fields can be found `here `. The current content is a placeholder and will actually be replaced by the content in `Kuscia API /api/v1/serving/create `_ at startup. - * `input_config`: SecretFlow-Serving startup configuration, details can be seen in the description below. 
The current content is a placeholder and will actually be replaced by the content in `Kuscia API /api/v1/serving/create `_ at startup. - * `cluster_def`: See `AppImage explanation `_. - * `allocated_ports`: See `AppImage explanation `_. - -* `ports`: - * `service`: The :ref:`ServerConfig.service_port ` - * `communication`: The :ref:`ServerConfig.communication_port ` - * `internal`: The :ref:`ServerConfig.metrics_exposer_port ` - * `brpc-builtin`: The :ref:`ServerConfig.brpc_builtin_service_port ` +* **configTemplates**: + * **serving-config.conf**: + * **serving_id**: Service ID identifier, corresponding to the configuration :ref:`ServingConfig.id `. The current content is a placeholder and will actually be replaced by the content in `Kuscia API /api/v1/serving/create `_ at startup. + * **input_config**: SecretFlow-Serving startup configuration, details can be seen in the description below. The current content is a placeholder and will actually be replaced by the content in `Kuscia API /api/v1/serving/create `_ at startup. + * **cluster_def**: See `AppImage explanation `_. + * **allocated_ports**: See `AppImage explanation `_. + * **oss_meta**: OSS/S3 model source configuration, only effective when using OSS/S3 as the model data source. The actual content is in the form of a string-formatted JSON configuration, for example: ``{\"access_key\":\"test_ak\", \"secret_key\":\"test_sk\", \"virtual_hosted\":true, \"endpoint\":\"test_endpoint\", \"bucket\":\"test_bucket\"}``, the definition can be found :ref:`here `. This is an optional configuration and can be set up through the ``Kuscia configuration management`` system if needed. + * **spi_tls_config**: The TLS configuration used by SPI. The actual content is in the form of a string-formatted JSON configuration, for example: ``{\"certificate_path\":\"abc\", \"private_key_path\":\"def\",\"ca_file_path\":\"gkh\"}``, the definition can be found :ref:`here `.
This is an optional configuration and can be set up through the ``Kuscia configuration management`` system if needed. + * **http_source_meta**: HTTP model source configuration, only effective when using HTTP as the model data source. The actual content is in the form of a string-formatted JSON configuration, for example: ``{\"connectTimeoutMs\":60000,\"timeoutMs\":120000,\"tlsConfig\":{\"certificatePath\":\"abc\", \"privateKeyPath\":\"def\",\"caFilePath\":\"gkh\"}}``, the definition can be found :ref:`here `. This is an optional configuration and can be set up through the ``Kuscia configuration management`` system if needed. + * **logging-config.conf**: SecretFlow-Serving Log Configuration. The actual content is in the form of a string-formatted JSON configuration, for example: ``{\"systemLogPath\":\"/tmp/alice/serving.log\",\"logLevel\":\"INFO_LOG_LEVEL\",\"maxLogFileSize\":4194304,\"maxLogFileCount\":10}``, the definition can be found :ref:`here `. This is an optional configuration and can be set up through the ``Kuscia configuration management`` system if needed. + * **trace-config.conf**: SecretFlow-Serving Trace Configuration. The actual content is in the form of a string-formatted JSON configuration, for example: ``{\"traceLogEnable\":true,\"traceLogConf\":{\"traceLogPath\":\"/tmp/trace.log\"}}``, the definition can be found :ref:`here `. This is an optional configuration and can be set up through the ``Kuscia configuration management`` system if needed. + +* **ports**: + * **service**: The :ref:`ServerConfig.service_port ` + * **communication**: The :ref:`ServerConfig.communication_port ` + * **internal**: The :ref:`ServerConfig.metrics_exposer_port ` + * **brpc-builtin**: The :ref:`ServerConfig.brpc_builtin_service_port ` Configuration description ========================= @@ -93,7 +110,7 @@ Configuration description serving_input_config -------------------- -The launch and management of SecretFlow-Serving can be performed using the `Kuscia Serving API `_.
In this section, we will explain the contents of the `serving_input_config` field within the `/api/v1/serving/create` request. +The launch and management of SecretFlow-Serving can be performed using the `Kuscia Serving API `_. In this section, we will explain the contents of the ``serving_input_config`` field within the ``/api/v1/serving/create`` request. .. code-block:: json @@ -118,7 +135,7 @@ The launch and management of SecretFlow-Serving can be performed using the `Kusc "featureSourceConfig": { "mockOpts": {} }, - "channel_desc": { + "channelDesc": { "protocol": "http" } }, @@ -141,7 +158,7 @@ The launch and management of SecretFlow-Serving can be performed using the `Kusc "featureSourceConfig": { "mockOpts": {} }, - "channel_desc": { + "channelDesc": { "protocol": "http" } } @@ -150,59 +167,61 @@ The launch and management of SecretFlow-Serving can be performed using the `Kusc **Field description**: -+-----------------------------------------------------------+-----------------------+-------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ -| Name | Type | Description | Required | -+===========================================================+=======================+=================================================================================================================================================+========================================================================+ -| partyConfigs | map | Dictionary of startup parameters for each participant. Key: Participant Unique ID; Value: PartyConfig (Json Object). 
| Yes | -+-----------------------------------------------------------+-----------------------+-------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ -| PartyConfig.serverConfig | str | :ref:`ServerConfig ` | Yes | -+-----------------------------------------------------------+-----------------------+-------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ -| PartyConfig.serverConfig.featureMapping | map | Feature name mapping rules. Key: source or predefined feature name; Value: model feature name | No | -+-----------------------------------------------------------+-----------------------+-------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ -| PartyConfig.modelConfig | Object | :ref:`ModelConfig ` | Yes | -+-----------------------------------------------------------+-----------------------+-------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ -| PartyConfig.modelConfig.modelId | str | Unique id of the model package | Yes | -+-----------------------------------------------------------+-----------------------+-------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ -| PartyConfig.modelConfig.basePath | str | The local path used to cache and load model package 
| Yes | -+-----------------------------------------------------------+-----------------------+-------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ -| PartyConfig.modelConfig.sourcePath | str | The path to the model package in the data source, where the content format may vary depending on the `sourceType`. | Yes | -+-----------------------------------------------------------+-----------------------+-------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ -| PartyConfig.modelConfig.source_sha256 | str | The expected SHA256 hash of the model package. When provided, the fetched model package will be verified against it. | No | -+-----------------------------------------------------------+-----------------------+-------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ -| PartyConfig.modelConfig.sourceType | str | Model data source type, options include: ST_FILE: In this case, the sourcePath should be a file path accessible to Serving. | Yes | -| | | ST_DP: In this case, the sourcePath should be DomainData ID in DataMesh from Kuscia. and dpSourceMeta needs to be configured. 
| | -+-----------------------------------------------------------+-----------------------+-------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ -| PartyConfig.modelConfig.dpSourceMeta | Object | :ref:`DPSourceMeta ` | No(If `sourceType` is `DT_DP`, `dpSourceMeta` needs to be configured) | -+-----------------------------------------------------------+-----------------------+-------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ -| PartyConfig.modelConfig.dpSourceMeta.dmHost | str | The address of DataMesh in Kuscia. Default: datamesh:8071 | No | -+-----------------------------------------------------------+-----------------------+-------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ -| PartyConfig.featureSourceConfig | Object | :ref:`FeatureSourceConfig ` | Yes | -+-----------------------------------------------------------+-----------------------+-------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ -| PartyConfig.featureSourceConfig.mockOpts | Object | :ref:`MockOptions ` | No(One of `csvOpts`, `mockOpts`, or `httpOpts` needs to be configured) | 
-+-----------------------------------------------------------+-----------------------+-------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ -| PartyConfig.featureSourceConfig.mockOpts.type | str | The method for generating mock feature values, options: "MDT_RANDOM" for random values, and "MDT_FIXED" for fixed values. Default: "MDT_FIXED". | No | -+-----------------------------------------------------------+-----------------------+-------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ -| PartyConfig.featureSourceConfig.httpOpts | Object | :ref:`HttpOptions ` | No(One of `csvOpts`, `mockOpts`, or `httpOpts` needs to be configured) | -+-----------------------------------------------------------+-----------------------+-------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ -| PartyConfig.featureSourceConfig.httpOpts.endpoint | str | Feature service address | Yes | -+-----------------------------------------------------------+-----------------------+-------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ -| PartyConfig.featureSourceConfig.httpOpts.enableLb | bool | Whether to enable round robin load balancer, Default: False | No | 
-+-----------------------------------------------------------+-----------------------+-------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ -| PartyConfig.featureSourceConfig.httpOpts.connectTimeoutMs | int32 | Max duration for a connect. -1 means wait indefinitely. Default: 500 (ms) | No | -+-----------------------------------------------------------+-----------------------+-------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ -| PartyConfig.featureSourceConfig.httpOpts.timeoutMs | int32 | Max duration of http request. -1 means wait indefinitely. Default: 1000 (ms) | No | -+-----------------------------------------------------------+-----------------------+-------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ -| PartyConfig.featureSourceConfig.csvOpts | Object | :ref:`CsvOptions ` | No(One of `csvOpts`, `mockOpts`, or `httpOpts` needs to be configured) | -+-----------------------------------------------------------+-----------------------+-------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ -| PartyConfig.featureSourceConfig.csvOpts.file_path | Object | Input file path, specifies where to load data. 
Note that this will load all of the data into memory at once | Yes | -+-----------------------------------------------------------+-----------------------+-------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ -| PartyConfig.featureSourceConfig.csvOpts.id_name | Object | Id column name, associated with `FeatureParam::query_datas`. `query_datas` is a subset of id column | Yes | -+-----------------------------------------------------------+-----------------------+-------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ -| PartyConfig.channelDesc | Object | :ref:`ChannelDesc ` | Yes | -+-----------------------------------------------------------+-----------------------+-------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ -| PartyConfig.channelDesc.protocol | str | Communication protocol, for optional value, see `here `_ | Yes | -+-----------------------------------------------------------+-----------------------+-------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ -| PartyConfig.channelDesc.rpcTimeoutMs | int32 | Max duration of RPC. -1 means wait indefinitely. 
Default: 2000 (ms) | No | -+-----------------------------------------------------------+-----------------------+-------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ -| PartyConfig.channelDesc.connectTimeoutMs | int32 | Max duration for a connect. -1 means wait indefinitely. Default: 500 (ms) | No | -+-----------------------------------------------------------+-----------------------+-------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ ++-----------------------------------------------------------+-----------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------------+ +| Name | Type | Description | Required | ++===========================================================+=======================+=========================================================================================================================================================================================+==============================================================================+ +| partyConfigs | map | Dictionary of startup parameters for each participant. Key: Participant Unique ID; Value: PartyConfig (Json Object). 
| Yes | ++-----------------------------------------------------------+-----------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------------+ +| PartyConfig.serverConfig | str | :ref:`ServerConfig ` | Yes | ++-----------------------------------------------------------+-----------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------------+ +| PartyConfig.serverConfig.featureMapping | map | Feature name mapping rules. Key: source or predefined feature name; Value: model feature name | No | ++-----------------------------------------------------------+-----------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------------+ +| PartyConfig.modelConfig | Object | :ref:`ModelConfig ` | Yes | ++-----------------------------------------------------------+-----------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------------+ +| PartyConfig.modelConfig.modelId | str | Unique id of the model package | Yes | 
++-----------------------------------------------------------+-----------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------------+ +| PartyConfig.modelConfig.basePath | str | The local path used to cache and load model package | Yes | ++-----------------------------------------------------------+-----------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------------+ +| PartyConfig.modelConfig.sourcePath | str | The path to the model package in the data source, where the content format may vary depending on the ``sourceType``. | Yes | ++-----------------------------------------------------------+-----------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------------+ +| PartyConfig.modelConfig.sourceSha256 | str | The expected SHA256 hash of the model package. When provided, the fetched model package will be verified against it. 
| No | ++-----------------------------------------------------------+-----------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------------+ +| PartyConfig.modelConfig.sourceType | str | Model data source type, options include: | | +| | | ``ST_FILE``: In this case, the ``sourcePath`` should be a file path accessible to Serving. | Yes | +| | | ``ST_DP``: In this case, the ``sourcePath`` should be DomainData ID in DataMesh from Kuscia, and dpSourceMeta needs to be configured. | | +| | | ``ST_OSS``: In this case, the ``sourcePath`` should be the file path within the bucket. ``ST_HTTP``: In this case, the ``sourcePath`` should be the download URL for the model package. | | ++-----------------------------------------------------------+-----------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------------+ +| PartyConfig.modelConfig.dpSourceMeta | Object | :ref:`DPSourceMeta ` | No(If ``sourceType`` is ``ST_DP``, ``dpSourceMeta`` needs to be configured) | ++-----------------------------------------------------------+-----------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------------+ +| PartyConfig.modelConfig.dpSourceMeta.dmHost | str | The address of DataMesh in Kuscia.
Default: datamesh:8071 | No | ++-----------------------------------------------------------+-----------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------------+ +| PartyConfig.featureSourceConfig | Object | :ref:`FeatureSourceConfig ` | Yes | ++-----------------------------------------------------------+-----------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------------+ +| PartyConfig.featureSourceConfig.mockOpts | Object | :ref:`MockOptions ` | No(One of ``csvOpts``, ``mockOpts``, or ``httpOpts`` needs to be configured) | ++-----------------------------------------------------------+-----------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------------+ +| PartyConfig.featureSourceConfig.mockOpts.type | str | The method for generating mock feature values, options: "MDT_RANDOM" for random values, and "MDT_FIXED" for fixed values. Default: "MDT_FIXED". 
| No | ++-----------------------------------------------------------+-----------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------------+ +| PartyConfig.featureSourceConfig.httpOpts | Object | :ref:`HttpOptions ` | No(One of ``csvOpts``, ``mockOpts``, or ``httpOpts`` needs to be configured) | ++-----------------------------------------------------------+-----------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------------+ +| PartyConfig.featureSourceConfig.httpOpts.endpoint | str | Feature service address | Yes | ++-----------------------------------------------------------+-----------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------------+ +| PartyConfig.featureSourceConfig.httpOpts.enableLb | bool | Whether to enable round robin load balancer, Default: False | No | ++-----------------------------------------------------------+-----------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------------+ +| PartyConfig.featureSourceConfig.httpOpts.connectTimeoutMs | int32 | Max duration for a connect. -1 means wait indefinitely. 
Default: 500 (ms) | No | ++-----------------------------------------------------------+-----------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------------+ +| PartyConfig.featureSourceConfig.httpOpts.timeoutMs | int32 | Max duration of http request. -1 means wait indefinitely. Default: 1000 (ms) | No | ++-----------------------------------------------------------+-----------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------------+ +| PartyConfig.featureSourceConfig.csvOpts | Object | :ref:`CsvOptions ` | No(One of ``csvOpts``, ``mockOpts``, or ``httpOpts`` needs to be configured) | ++-----------------------------------------------------------+-----------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------------+ +| PartyConfig.featureSourceConfig.csvOpts.filePath | Object | Input file path, specifies where to load data. 
Note that this will load all of the data into memory at once | Yes | ++-----------------------------------------------------------+-----------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------------+ +| PartyConfig.featureSourceConfig.csvOpts.idName | Object | Id column name, associated with ``FeatureParam::query_datas``. ``query_datas`` is a subset of id column | Yes | ++-----------------------------------------------------------+-----------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------------+ +| PartyConfig.channelDesc | Object | :ref:`ChannelDesc ` | Yes | ++-----------------------------------------------------------+-----------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------------+ +| PartyConfig.channelDesc.protocol | str | Communication protocol, for optional value, see `here `_ | Yes | ++-----------------------------------------------------------+-----------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------------+ +| PartyConfig.channelDesc.rpcTimeoutMs | int32 | Max duration of RPC. -1 means wait indefinitely. 
Default: 2000 (ms) | No | ++-----------------------------------------------------------+-----------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------------+ +| PartyConfig.channelDesc.connectTimeoutMs | int32 | Max duration for a connect. -1 means wait indefinitely. Default: 500 (ms) | No | ++-----------------------------------------------------------+-----------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------------+ diff --git a/docs/source/topics/graph/intro_to_graph.rst b/docs/source/topics/graph/intro_to_graph.rst index e9e7d31..aedfa15 100644 --- a/docs/source/topics/graph/intro_to_graph.rst +++ b/docs/source/topics/graph/intro_to_graph.rst @@ -19,7 +19,7 @@ OpDef * desc: Description of the operator. * version: The version of the operator. * tag: Some properties of the operator. -* attributes: Please check `Attributes` part below. +* attributes: Please check ``Attributes`` part below. * inputs and output: The info of the inputs or output of the operator. Attributes @@ -32,12 +32,12 @@ AttrDef * name: Must be unique among all attrs of the operator. * desc: Description of the attribute. * type: Please check :ref:`AttrType `. -* is_optional: If True, when AttrValue is not provided, `default_value` would be used. Else, AttrValue must be provided. -* default_value: Please check :ref:`AttrValue `. +* is_optional: If True, when AttrValue is not provided, ``default_value`` would be used. Else, AttrValue must be provided. +* default_value: Please check :ref:`AttrValue `. Nodes ----- -Nodes are instances of operators. 
They store the attribute values (`AttrValue`) of the operators. +Nodes are instances of operators. They store the attribute values (``AttrValue``) of the operators. NodeDef ^^^^^^^ @@ -59,7 +59,7 @@ GraphDef * version: Version of the graph. * node_list: The node list of the graph. -* execution_list: Please check `Executions` part below. +* execution_list: Please check ``Executions`` part below. Executions ---------- @@ -72,7 +72,7 @@ ExecutionDef ^^^^^^^^^^^^ * nodes: Represents the nodes contained in this execution. Note that these node names should be findable and unique within the node definitions. One node can only exist in one execution and must exist in one. -* config: The runtime config of the execution. It describes the scheduling logic and session-related states of this execution unit. for more details, please check :ref:`RuntimeConfig `. +* config: The runtime config of the execution. It describes the scheduling logic and session-related states of this execution unit. for more details, please check :ref:`RuntimeConfig `. Secretflow Serving Library ^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -81,4 +81,4 @@ Since Secretflow-Serving Load models trained by `Secretflow `. So `secretflow-serving-lib `_ is a python library that provides interfaces to obtain Secretflow-Serving operators and export model files that Secretflow-Serving can load. -For more details, please check :doc:`secretflow-serving-lib docs `. +For more details, please check :doc:`secretflow-serving-lib docs `. diff --git a/docs/source/topics/graph/operator_list.md b/docs/source/topics/graph/operator_list.md index 95c91a9..cb02321 100644 --- a/docs/source/topics/graph/operator_list.md +++ b/docs/source/topics/graph/operator_list.md @@ -6,7 +6,111 @@ SecretFlow-Serving Operator List ================================ -Last update: Wed May 29 20:14:58 2024 +Last update: Mon Aug 12 14:41:29 2024 +## PHE_2P_REDUCE + + +Operator version: 0.0.1 + +Two-party computation operator. 
Select data encrypted by either our side or the peer party according to the configuration. +### Attrs + + +|Name|Description|Type|Required|Notes| +| :--- | :--- | :--- | :--- | :--- | +|select_crypted_for_peer|If `True`, select the data that can be decrypted by the peer, including the self-calculated partial_y and the peer's rand; otherwise select self's data.|Boolean|Y|| +|rand_number_col_name|The name of the rand number column in the input and output|String|Y|| +|partial_y_col_name|The name of the partial_y column in the input and output|String|Y|| + +### Tags + + +|Name|Description| +| :--- | :--- | +|mergeable|The operator accepts the output of operators with different participants and will somehow merge them.| + +### Inputs + + +|Name|Description| +| :--- | :--- | +|compute results|The compute results from both self and peer| + +### Output + + +|Name|Description| +| :--- | :--- | +|selected results|The selected data| + +## PHE_2P_MERGE_Y + + +Operator version: 0.0.1 + +Two-party computation operator. Merge the obfuscated partial_y decrypted by the peer party with the partial_y based on self's own key to obtain the final prediction score. +### Attrs + + +|Name|Description|Type|Required|Notes| +| :--- | :--- | :--- | :--- | :--- | +|exp_iters|Number of iterations of `exp` approximation, valid when `link_function` is set to `LF_EXP_TAYLOR`|Integer32|N|Default: 0.| +|link_function|Type of link function, defined in `secretflow_serving/protos/link_function.proto`. Optional value: LF_EXP, LF_EXP_TAYLOR, LF_RECIPROCAL, LF_IDENTITY, LF_SIGMOID_RAW, LF_SIGMOID_MM1, LF_SIGMOID_MM3, LF_SIGMOID_GA, LF_SIGMOID_T1, LF_SIGMOID_T3, LF_SIGMOID_T5, LF_SIGMOID_T7, LF_SIGMOID_T9, LF_SIGMOID_LS7, LF_SIGMOID_SEG3, LF_SIGMOID_SEG5, LF_SIGMOID_DF, LF_SIGMOID_SR, LF_SIGMOID_SEGLS|String|Y|| +|yhat_scale|In order to prevent value overflow, GLM training is performed on the scaled y label.
So in the prediction process, you need to enlarge yhat back to get the real predicted value, `yhat = yhat_scale * link(X * W)`|Double|N|Default: 1.0.| +|score_col_name|The name of the score column in the output|String|Y|| +|crypted_y_col_name|The name of the crypted partial_y column in the second input|String|Y|| +|decrypted_y_col_name|The name of the decrypted partial_y column in the first input|String|Y|| + +### Tags + + +|Name|Description| +| :--- | :--- | +|returnable|The operator's output can be the final result| + +### Inputs + + +|Name|Description| +| :--- | :--- | +|crypted_data|The crypted data selected by `PHE_2P_REDUCE`| +|decrypted_data|The decrypted data output by `PHE_2P_DECRYPT_PEER_Y`| + +### Output + + +|Name|Description| +| :--- | :--- | +|score|The final linear predict score.| + +## PHE_2P_DECRYPT_PEER_Y + + +Operator version: 0.0.1 + +Two-party computation operator. Decrypt the obfuscated partial_y and add a random number. +### Attrs + + +|Name|Description|Type|Required|Notes| +| :--- | :--- | :--- | :--- | :--- | +|decrypted_col_name|The name of the decrypted result column in the output|String|Y|| +|partial_y_col_name|The name of the partial_y (which can be decrypted by self) column in the input|String|Y|| + +### Inputs + + +|Name|Description| +| :--- | :--- | +|crypted_data|Input feature table| + +### Output + + +|Name|Description| +| :--- | :--- | +|decrypted_data|Decrypted partial_y with the added random number.| + ## MERGE_Y @@ -77,6 +181,39 @@ Calculate the dot product of feature weights and values | :--- | :--- | |partial_ys|The calculation results, they have a data type of `double`.| +## PHE_2P_DOT_PRODUCT + + +Operator version: 0.0.1 + +Two-party computation operator. Load the encrypted feature weights, compute their dot product with the feature values, and add random noise to the result for obfuscation. Only supports computation between two parties, with the weights being encrypted using the other party's key.
+### Attrs + + +|Name|Description|Type|Required|Notes| +| :--- | :--- | :--- | :--- | :--- | +|result_col_name|The name of the calculation result(partial_y) column in the output|String|Y|| +|offset_col_name|The name of the offset column(feature) in the input|String|N|Default: .| +|rand_number_col_name|The name of the generated rand number column in the output|String|Y|| +|feature_types|List of input feature data types. Optional value: DT_UINT8, DT_INT8, DT_UINT16, DT_INT16, DT_UINT32, DT_INT32, DT_UINT64, DT_INT64, DT_FLOAT, DT_DOUBLE|String List|N|Default: [].| +|feature_weights_ciphertext|feature weight ciphertext matrix bytes|Bytes|N|| +|intercept_ciphertext|Intercept ciphertext bytes or matrix bytes|Bytes|N|| +|feature_names|List of feature names. Note that if there is an offset column, it needs to be the last one in the list|String List|N|Default: [].| + +### Inputs + + +|Name|Description| +| :--- | :--- | +|features|Input features| + +### Output + + +|Name|Description| +| :--- | :--- | +|partial_y|Calculation results| + ## ARROW_PROCESSING diff --git a/docs/source/topics/system/feature_service.rst b/docs/source/topics/system/feature_service.rst index 29a3a6f..e3e8daf 100644 --- a/docs/source/topics/system/feature_service.rst +++ b/docs/source/topics/system/feature_service.rst @@ -116,41 +116,41 @@ Example of response body: **Field description**: -+-------------------------+----------------------+-----------------------------------+---------------------------------------------------------------+ -| Name | Type | Description | Required | -+=========================+======================+===================================+===============================================================+ -| header | Object(Header) | Custom data | No | -+-------------------------+----------------------+-----------------------------------+---------------------------------------------------------------+ -| header.data | Map | Key:str, Value:str | No | 
-+-------------------------+----------------------+-----------------------------------+---------------------------------------------------------------+ -| status | Object(Status) | The Status of this response | Yes | -+-------------------------+----------------------+-----------------------------------+---------------------------------------------------------------+ -| status.code | enum(ErrorCode) | Value reference `ErrorCode` | Yes | -+-------------------------+----------------------+-----------------------------------+---------------------------------------------------------------+ -| status.msg | str | The detail message of the status | Yes | -+-------------------------+----------------------+-----------------------------------+---------------------------------------------------------------+ -| features | List | The Request feature data list | Yes | -+-------------------------+----------------------+-----------------------------------+---------------------------------------------------------------+ -| features[].field | Object(FeatureField) | The definition of a feature field | Yes | -+-------------------------+----------------------+-----------------------------------+---------------------------------------------------------------+ -| features[].field.name | str | Unique name of the feature | Yes | -+-------------------------+----------------------+-----------------------------------+---------------------------------------------------------------+ -| features[].field.type | enum(FieldType) | Field type of the feature | Yes | -+-------------------------+----------------------+-----------------------------------+---------------------------------------------------------------+ -| features[].value | Object(FeatureValue) | The definition of a feature value | Yes | -+-------------------------+----------------------+-----------------------------------+---------------------------------------------------------------+ -| features[].value.i32s[] | List | int32 feature 
value data list | No(assign the corresponding value list based on `field.type`) | -+-------------------------+----------------------+-----------------------------------+---------------------------------------------------------------+ -| features[].value.i64s[] | List | int64 feature value data list | No(assign the corresponding value list based on `field.type`) | -+-------------------------+----------------------+-----------------------------------+---------------------------------------------------------------+ -| features[].value.fs[] | List | float feature value data list | No(assign the corresponding value list based on `field.type`) | -+-------------------------+----------------------+-----------------------------------+---------------------------------------------------------------+ -| features[].value.ds[] | List | double feature value data list | No(assign the corresponding value list based on `field.type`) | -+-------------------------+----------------------+-----------------------------------+---------------------------------------------------------------+ -| features[].value.ss[] | List | string feature value data list | No(assign the corresponding value list based on `field.type`) | -+-------------------------+----------------------+-----------------------------------+---------------------------------------------------------------+ -| features[].value.bs[] | List | bool feature value data list | No(assign the corresponding value list based on `field.type`) | -+-------------------------+----------------------+-----------------------------------+---------------------------------------------------------------+ ++-------------------------+----------------------+-----------------------------------+-----------------------------------------------------------------+ +| Name | Type | Description | Required | ++=========================+======================+===================================+=================================================================+ +| 
header | Object(Header) | Custom data | No | ++-------------------------+----------------------+-----------------------------------+-----------------------------------------------------------------+ +| header.data | Map | Key:str, Value:str | No | ++-------------------------+----------------------+-----------------------------------+-----------------------------------------------------------------+ +| status | Object(Status) | The Status of this response | Yes | ++-------------------------+----------------------+-----------------------------------+-----------------------------------------------------------------+ +| status.code | enum(ErrorCode) | Value reference ``ErrorCode`` | Yes | ++-------------------------+----------------------+-----------------------------------+-----------------------------------------------------------------+ +| status.msg | str | The detail message of the status | Yes | ++-------------------------+----------------------+-----------------------------------+-----------------------------------------------------------------+ +| features | List | The Request feature data list | Yes | ++-------------------------+----------------------+-----------------------------------+-----------------------------------------------------------------+ +| features[].field | Object(FeatureField) | The definition of a feature field | Yes | ++-------------------------+----------------------+-----------------------------------+-----------------------------------------------------------------+ +| features[].field.name | str | Unique name of the feature | Yes | ++-------------------------+----------------------+-----------------------------------+-----------------------------------------------------------------+ +| features[].field.type | enum(FieldType) | Field type of the feature | Yes | ++-------------------------+----------------------+-----------------------------------+-----------------------------------------------------------------+ +| features[].value | 
Object(FeatureValue) | The definition of a feature value | Yes | ++-------------------------+----------------------+-----------------------------------+-----------------------------------------------------------------+ +| features[].value.i32s[] | List | int32 feature value data list | No(assign the corresponding value list based on ``field.type``) | ++-------------------------+----------------------+-----------------------------------+-----------------------------------------------------------------+ +| features[].value.i64s[] | List | int64 feature value data list | No(assign the corresponding value list based on ``field.type``) | ++-------------------------+----------------------+-----------------------------------+-----------------------------------------------------------------+ +| features[].value.fs[] | List | float feature value data list | No(assign the corresponding value list based on ``field.type``) | ++-------------------------+----------------------+-----------------------------------+-----------------------------------------------------------------+ +| features[].value.ds[] | List | double feature value data list | No(assign the corresponding value list based on ``field.type``) | ++-------------------------+----------------------+-----------------------------------+-----------------------------------------------------------------+ +| features[].value.ss[] | List | string feature value data list | No(assign the corresponding value list based on ``field.type``) | ++-------------------------+----------------------+-----------------------------------+-----------------------------------------------------------------+ +| features[].value.bs[] | List | bool feature value data list | No(assign the corresponding value list based on ``field.type``) | ++-------------------------+----------------------+-----------------------------------+-----------------------------------------------------------------+ Step 2: Configure startup config diff --git 
a/python_lib/secretflow_serving_lib/BUILD.bazel b/python_lib/secretflow_serving_lib/BUILD.bazel index 24e2dd9..9182b24 100644 --- a/python_lib/secretflow_serving_lib/BUILD.bazel +++ b/python_lib/secretflow_serving_lib/BUILD.bazel @@ -89,5 +89,6 @@ py_library( "version.py", ":api", ":protos", + "//python_lib/secretflow_serving_lib/config", ], ) diff --git a/python_lib/secretflow_serving_lib/__init__.py b/python_lib/secretflow_serving_lib/__init__.py index 7bf948c..ab9b964 100644 --- a/python_lib/secretflow_serving_lib/__init__.py +++ b/python_lib/secretflow_serving_lib/__init__.py @@ -22,6 +22,7 @@ from . import bundle_pb2 from . import data_type_pb2 from . import link_function_pb2 +from . import config from .api import get_all_ops, get_op, get_graph_version from .graph_builder import GraphBuilder, check_graph_views, build_serving_tar @@ -42,4 +43,5 @@ "link_function_pb2", "build_serving_tar", "GraphBuilder", + "config", ] diff --git a/python_lib/secretflow_serving_lib/config/BUILD.bazel b/python_lib/secretflow_serving_lib/config/BUILD.bazel new file mode 100644 index 0000000..e8a6a5a --- /dev/null +++ b/python_lib/secretflow_serving_lib/config/BUILD.bazel @@ -0,0 +1,25 @@ +# Copyright 2023 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+load("@rules_python//python:defs.bzl", "py_library") + +package(default_visibility = ["//visibility:public"]) + +py_library( + name = "config", + srcs = [ + "__init__.py", + "//secretflow_serving/config:config_py_proto", + "//secretflow_serving/tools/inferencer:inference_config_py_proto", + ], +) diff --git a/python_lib/secretflow_serving_lib/config/__init__.py b/python_lib/secretflow_serving_lib/config/__init__.py new file mode 100644 index 0000000..8b593b8 --- /dev/null +++ b/python_lib/secretflow_serving_lib/config/__init__.py @@ -0,0 +1,23 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from secretflow_serving.config.serving_config_pb2 import * +from secretflow_serving.config.cluster_config_pb2 import * +from secretflow_serving.config.feature_config_pb2 import * +from secretflow_serving.config.model_config_pb2 import * +from secretflow_serving.config.server_config_pb2 import * +from secretflow_serving.config.tls_config_pb2 import * +from secretflow_serving.config.logging_config_pb2 import * +from secretflow_serving.config.trace_config_pb2 import * +from secretflow_serving.tools.inferencer.config_pb2 import * diff --git a/python_lib/secretflow_serving_lib/graph_builder.py b/python_lib/secretflow_serving_lib/graph_builder.py index ef1b440..088a2e0 100644 --- a/python_lib/secretflow_serving_lib/graph_builder.py +++ b/python_lib/secretflow_serving_lib/graph_builder.py @@ -31,6 +31,7 @@ GraphView, NodeDef, RuntimeConfig, + HeConfig, ) from .op_pb2 import OpDef @@ -149,6 +150,9 @@ class _GraphProtoWrapper: def __init__(self, version: str): self.version = version self.executions = [] + self.pk_bytes = None + self.sk_bytes = None + self.scale = None def add_execution(self, exec: _ExecutionProtoWrapper): self.executions.append(exec) @@ -159,9 +163,17 @@ def get_execution(self, idx: int) -> _ExecutionProtoWrapper: def get_execution_count(self): return len(self.executions) + def set_he_config(self, pk_bytes, sk_bytes, scale): + self.pk_bytes = pk_bytes + self.sk_bytes = sk_bytes + self.scale = scale + def proto(self) -> GraphDef: return GraphDef( version=self.version, + he_config=HeConfig( + pk_buf=self.pk_bytes, sk_buf=self.sk_bytes, encode_scale=self.scale + ), execution_list=[exec.proto() for exec in self.executions], node_list=list( chain.from_iterable( @@ -228,6 +240,9 @@ def begin_new_execution( ) ) + def set_he_config(self, pk_bytes, sk_bytes, scale): + self.graph.set_he_config(pk_bytes, sk_bytes, scale) + def build_proto(self) -> GraphDef: '''Get the GraphDef include all nodes and executions''' graph_def_str = libserving.graph_validator_impl( diff 
--git a/python_lib/setup.py b/python_lib/setup.py index d9c50f0..6c28913 100644 --- a/python_lib/setup.py +++ b/python_lib/setup.py @@ -97,6 +97,8 @@ def get_packages(self): generated_python_directories = [ "../bazel-bin/python_lib", "../bazel-bin/secretflow_serving/protos", + "../bazel-bin/secretflow_serving/config", + "../bazel-bin/secretflow_serving/tools/inferencer", ] setup_spec.install_requires = read_requirements("requirements.txt") files_to_remove = [] @@ -107,6 +109,8 @@ def get_packages(self): "../bazel-bin/python_lib/secretflow_serving_lib/libserving" + pyd_suffix ] +serving_tools_files = ["../bazel-bin/secretflow_serving/tools/inferencer/inferencer"] + # Calls Bazel in PATH def bazel_invoke(invoker, cmdline, *args, **kwargs): @@ -117,7 +121,7 @@ def bazel_invoke(invoker, cmdline, *args, **kwargs): raise -def build(build_python, build_cpp): +def build(): if tuple(sys.version_info[:2]) not in SUPPORTED_PYTHONS: msg = ( "Detected Python version {}, which is not supported. " @@ -140,11 +144,12 @@ def build(build_python, build_cpp): bazel_precmd_flags = [] - bazel_targets = [] - bazel_targets += ( - ["//python_lib/secretflow_serving_lib:init"] if build_python else [] - ) - bazel_targets += ["//python_lib/secretflow_serving_lib:api"] if build_cpp else [] + bazel_targets = [ + "//python_lib/secretflow_serving_lib:init", + "//python_lib/secretflow_serving_lib:api", + "//python_lib/secretflow_serving_lib/config", + "//secretflow_serving/tools/inferencer", + ] bazel_flags.extend(["-c", "opt"]) @@ -191,9 +196,10 @@ def remove_file(target_dir, filename): def pip_run(build_ext): - build(True, True) + build() setup_spec.files_to_include += serving_ops_lib_files + setup_spec.files_to_include += serving_tools_files # Copy over the autogenerated protobuf Python bindings. 
for directory in generated_python_directories: diff --git a/secretflow_serving/config/BUILD.bazel b/secretflow_serving/config/BUILD.bazel index 09869fa..4613954 100644 --- a/secretflow_serving/config/BUILD.bazel +++ b/secretflow_serving/config/BUILD.bazel @@ -14,6 +14,7 @@ load("@rules_cc//cc:defs.bzl", "cc_proto_library") load("@rules_proto//proto:defs.bzl", "proto_library") +load("@rules_proto_grpc//python:defs.bzl", "python_proto_compile") package(default_visibility = ["//visibility:public"]) @@ -126,3 +127,20 @@ cc_proto_library( name = "serving_config_cc_proto", deps = [":serving_config_proto"], ) + +python_proto_compile( + name = "config_py_proto", + output_mode = "NO_PREFIX", + prefix_path = "../..", + protos = [ + ":cluster_config_proto", + ":feature_config_proto", + ":logging_config_proto", + ":model_config_proto", + ":retry_policy_config_proto", + ":server_config_proto", + ":serving_config_proto", + ":tls_config_proto", + ":trace_config_proto", + ], +) diff --git a/secretflow_serving/config/feature_config.proto b/secretflow_serving/config/feature_config.proto index 34814e4..abda2db 100644 --- a/secretflow_serving/config/feature_config.proto +++ b/secretflow_serving/config/feature_config.proto @@ -26,6 +26,7 @@ message FeatureSourceConfig { MockOptions mock_opts = 1; HttpOptions http_opts = 2; CsvOptions csv_opts = 3; + StreamingOptions streaming_opts = 4; } } @@ -80,4 +81,30 @@ message CsvOptions { // Id column name, associated with FeatureParam::query_datas // Query datas is a subset of id column string id_name = 2; + + // Optional. Only for Inferencer tool use. + // Default: false + bool streaming_mode = 11; + // Optional. Valid only if `streaming_mode=true`. + // This determines the size(byte) of each read batch. + int32 block_size = 12; +} + +// Only for Inferencer tool use. +message StreamingOptions { + // Input file path, specifies where to load data + string file_path = 1; + + // Input file format. 
+ // Optional value: CSV + // Default: CSV + string file_format = 2; + + // Id column name, associated with FeatureParam::query_datas + // Query datas is a subset of id column + string id_name = 3; + + // Optional. + // This determines the size(byte) of each read batch. + int32 block_size = 12; } diff --git a/secretflow_serving/config/server_config.proto b/secretflow_serving/config/server_config.proto index 4ea1236..dbcbd3e 100644 --- a/secretflow_serving/config/server_config.proto +++ b/secretflow_serving/config/server_config.proto @@ -32,7 +32,8 @@ message ServerConfig { // e.g. 192.168.2.51 string host = 3; - // The port used for model inference. + // The port used for model predict service. + // Default: disable service int32 service_port = 4; // The port used for communication between parties serving. diff --git a/secretflow_serving/core/BUILD.bazel b/secretflow_serving/core/BUILD.bazel index 28228a4..5860e10 100644 --- a/secretflow_serving/core/BUILD.bazel +++ b/secretflow_serving/core/BUILD.bazel @@ -50,7 +50,8 @@ serving_cc_library( name = "types", hdrs = ["types.h"], deps = [ - "@com_github_eigenteam_eigen//:eigen3", + #"@com_github_eigenteam_eigen//:eigen3", + "@com_alipay_sf_heu//heu/library/numpy:matrix", ], ) diff --git a/secretflow_serving/core/types.h b/secretflow_serving/core/types.h index ce7fa52..1c397ac 100644 --- a/secretflow_serving/core/types.h +++ b/secretflow_serving/core/types.h @@ -14,7 +14,7 @@ #pragma once -#include "Eigen/Core" +#include "heu/library/numpy/eigen_traits.h" namespace secretflow::serving { diff --git a/secretflow_serving/feature_adapter/BUILD.bazel b/secretflow_serving/feature_adapter/BUILD.bazel index 8239a0b..1a5b7f4 100644 --- a/secretflow_serving/feature_adapter/BUILD.bazel +++ b/secretflow_serving/feature_adapter/BUILD.bazel @@ -22,6 +22,7 @@ serving_cc_library( ":file_adapter", ":http_adapter", ":mock_adapter", + ":streaming_adapter", ], ) @@ -86,6 +87,18 @@ serving_cc_library( alwayslink = True, ) 
+serving_cc_library( + name = "streaming_adapter", + srcs = ["streaming_adapter.cc"], + hdrs = ["streaming_adapter.h"], + deps = [ + ":feature_adapter_factory", + "//secretflow_serving/util:csv_util", + "@com_github_gflags_gflags//:gflags", + ], + alwayslink = True, +) + serving_cc_test( name = "file_adapter_test", srcs = ["file_adapter_test.cc"], @@ -110,3 +123,12 @@ serving_cc_test( ":mock_adapter", ], ) + +serving_cc_test( + name = "streaming_adapter_test", + srcs = ["streaming_adapter_test.cc"], + deps = [ + ":streaming_adapter", + "@com_github_brpc_brpc//:butil", + ], +) diff --git a/secretflow_serving/feature_adapter/feature_adapter.cc b/secretflow_serving/feature_adapter/feature_adapter.cc index c8b77f3..56fa544 100644 --- a/secretflow_serving/feature_adapter/feature_adapter.cc +++ b/secretflow_serving/feature_adapter/feature_adapter.cc @@ -41,15 +41,12 @@ void FeatureAdapter::CheckFeatureValid( const std::shared_ptr& features) { const auto& schema = features->schema(); if (feature_schema_->num_fields() > 0) { - SERVING_ENFORCE(schema->Equals(*feature_schema_), - errors::ErrorCode::NOT_FOUND, - "result schema does not match the request expect."); + CheckReferenceFields( + schema, std::const_pointer_cast(feature_schema_), + "result schema does not match the request expect."); } - SERVING_ENFORCE( - request.fs_param->query_datas().size() == features->num_rows(), - errors::ErrorCode::LOGIC_ERROR, - "query row_num {} should be equal to fetched row_num {}", - request.fs_param->query_datas().size(), features->num_rows()); + SERVING_ENFORCE_EQ(features->num_rows(), request.fs_param->query_datas_size(), + "fetched feature row num should be equal to query num"); } } // namespace secretflow::serving::feature diff --git a/secretflow_serving/feature_adapter/file_adapter.cc b/secretflow_serving/feature_adapter/file_adapter.cc index ae940ac..8b184cf 100644 --- a/secretflow_serving/feature_adapter/file_adapter.cc +++ b/secretflow_serving/feature_adapter/file_adapter.cc @@ 
-14,12 +14,9 @@ #include "secretflow_serving/feature_adapter/file_adapter.h" -#include "arrow/compute/api.h" -#include "arrow/csv/api.h" -#include "arrow/io/api.h" - #include "secretflow_serving/feature_adapter/feature_adapter_factory.h" #include "secretflow_serving/util/arrow_helper.h" +#include "secretflow_serving/util/csv_util.h" namespace secretflow::serving::feature { @@ -29,12 +26,10 @@ FileAdapter::FileAdapter( const std::shared_ptr& feature_schema) : FeatureAdapter(spec, service_id, party_id, feature_schema), extractor_(feature_schema, spec_.csv_opts().file_path(), - spec_.csv_opts().id_name()) { - SERVING_ENFORCE(spec_.has_csv_opts(), errors::ErrorCode::INVALID_ARGUMENT, - "invalid mock options"); -} + spec_.csv_opts().id_name()) {} void FileAdapter::OnFetchFeature(const Request& request, Response* response) { + SERVING_ENFORCE_GT(request.fs_param->query_datas_size(), 0); response->features = extractor_.ExtractRows(feature_schema_, request.fs_param->query_datas()); } diff --git a/secretflow_serving/feature_adapter/file_adapter.h b/secretflow_serving/feature_adapter/file_adapter.h index 7c7efca..e9a55ce 100644 --- a/secretflow_serving/feature_adapter/file_adapter.h +++ b/secretflow_serving/feature_adapter/file_adapter.h @@ -14,9 +14,6 @@ #pragma once -#include -#include - #include "secretflow_serving/feature_adapter/feature_adapter.h" #include "secretflow_serving/util/csv_extractor.h" @@ -34,7 +31,7 @@ class FileAdapter : public FeatureAdapter { void OnFetchFeature(const Request& request, Response* response) override; private: - CSVExtractor extractor_; + csv::CSVExtractor extractor_; }; } // namespace secretflow::serving::feature diff --git a/secretflow_serving/feature_adapter/file_adapter_test.cc b/secretflow_serving/feature_adapter/file_adapter_test.cc index 558f270..0d62641 100644 --- a/secretflow_serving/feature_adapter/file_adapter_test.cc +++ b/secretflow_serving/feature_adapter/file_adapter_test.cc @@ -50,7 +50,7 @@ id,x1,x2,x3,x4 )TEXT"); 
FeatureSourceConfig config; - auto csv_opts = config.mutable_csv_opts(); + auto* csv_opts = config.mutable_csv_opts(); csv_opts->set_file_path(tmpfile.fname()); csv_opts->set_id_name("id"); diff --git a/secretflow_serving/feature_adapter/http_adapter.cc b/secretflow_serving/feature_adapter/http_adapter.cc index 99ae979..77e75a5 100644 --- a/secretflow_serving/feature_adapter/http_adapter.cc +++ b/secretflow_serving/feature_adapter/http_adapter.cc @@ -87,15 +87,14 @@ HttpFeatureAdapter::HttpFeatureAdapter( http_opts.timeout_ms() > 0 ? http_opts.timeout_ms() : kTimeoutMs, http_opts.connect_timeout_ms() > 0 ? http_opts.connect_timeout_ms() : kConnectTimeoutMs, - http_opts.has_tls_config() ? &http_opts.tls_config() : nullptr); - retry_count_ = - RetryPolicyFactory::GetInstance()->GetMaxRetryCount(channel_name); + http_opts.has_tls_config() ? &http_opts.tls_config() : nullptr, + http_opts.has_retry_policy_config() ? &http_opts.retry_policy_config() + : nullptr); } void HttpFeatureAdapter::OnFetchFeature(const Request& request, Response* response) { brpc::Controller cntl; - cntl.set_max_retry(retry_count_); cntl.http_request().uri() = spec_.http_opts().endpoint(); cntl.http_request().set_method(brpc::HTTP_METHOD_POST); cntl.http_request().set_content_type("application/json"); @@ -134,14 +133,12 @@ void HttpFeatureAdapter::OnFetchFeature(const Request& request, span_option.msg = fmt::format( "deserialize response context failed: request: {}, error: {}", spi_request, status.ToString()); - } else if (spi_response.status().code() != spis::ErrorCode::OK) { span_option.code = MappingErrorCode(spi_response.status().code()); span_option.msg = fmt::format( "fetch features response error, request: {}, msg: {}, code: {}", spi_request, spi_response.status().msg(), spi_response.status().code()); - } else if (spi_response.features().empty()) { span_option.code = errors::ErrorCode::IO_ERROR; span_option.msg = diff --git a/secretflow_serving/feature_adapter/http_adapter.h 
b/secretflow_serving/feature_adapter/http_adapter.h index 08eedd3..e63a7bc 100644 --- a/secretflow_serving/feature_adapter/http_adapter.h +++ b/secretflow_serving/feature_adapter/http_adapter.h @@ -38,7 +38,6 @@ class HttpFeatureAdapter : public FeatureAdapter { const Request& request); protected: - int retry_count_; std::unique_ptr channel_; std::vector feature_fields_; }; diff --git a/secretflow_serving/feature_adapter/streaming_adapter.cc b/secretflow_serving/feature_adapter/streaming_adapter.cc new file mode 100644 index 0000000..9612011 --- /dev/null +++ b/secretflow_serving/feature_adapter/streaming_adapter.cc @@ -0,0 +1,156 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "secretflow_serving/feature_adapter/streaming_adapter.h" + +#include "gflags/gflags.h" +#include "spdlog/spdlog.h" + +#include "secretflow_serving/feature_adapter/feature_adapter_factory.h" +#include "secretflow_serving/util/arrow_helper.h" +#include "secretflow_serving/util/csv_util.h" + +DEFINE_bool(inferencer_mode, false, + "streaming adpater is allowed only if inferencer_mode is true"); + +namespace secretflow::serving::feature { + +namespace { + +const char* kUnSetMagicToken = "__UNSET_MAGIC_TOKEN__"; + +void CheckIdsEqual(const FeatureAdapter::Request& request, + const std::shared_ptr& id_array) { + SERVING_ENFORCE_EQ(id_array->length(), request.fs_param->query_datas_size(), + "id_array length mismatch with query_datas_size."); + auto str_array = std::static_pointer_cast(id_array); + for (int i = 0; i < request.fs_param->query_datas_size(); ++i) { + SERVING_ENFORCE_EQ(request.fs_param->query_datas(i), str_array->Value(i)); + } +} +} // namespace + +StreamingAdapter::StreamingAdapter( + const FeatureSourceConfig& spec, const std::string& service_id, + const std::string& party_id, + const std::shared_ptr& feature_schema) + : FeatureAdapter(spec, service_id, party_id, feature_schema), + last_context_token_(kUnSetMagicToken) { + SERVING_ENFORCE( + FLAGS_inferencer_mode, errors::ErrorCode::LOGIC_ERROR, + "streaming adpater is allowed only when using the inferencer tool"); + + SERVING_ENFORCE(spec_.has_streaming_opts(), + errors::ErrorCode::INVALID_ARGUMENT, + "invalid feature source streaming options"); + if (spec_.streaming_opts().file_format().empty()) { + spec_.mutable_streaming_opts()->set_file_format("CSV"); + } + SERVING_ENFORCE_EQ(spec_.streaming_opts().file_format(), "CSV"); + + std::unordered_map> col_types; + for (const auto& f : feature_schema->fields()) { + col_types.emplace(f->name(), f->type()); + } + col_types.emplace(spec_.streaming_opts().id_name(), arrow::utf8()); + + arrow::csv::ReadOptions read_opts = 
arrow::csv::ReadOptions::Defaults(); + if (spec_.streaming_opts().block_size() > 0) { + read_opts.block_size = spec_.streaming_opts().block_size(); + } + streaming_reader_ = csv::BuildStreamingReader( + spec_.streaming_opts().file_path(), std::move(col_types), read_opts); +} + +void StreamingAdapter::OnFetchFeature(const Request& request, + Response* response) { + SERVING_ENFORCE_GT(request.fs_param->query_datas_size(), 0); + + std::lock_guard lock(mux_); + + if (!request.fs_param->query_context().empty() && + request.fs_param->query_context() == last_context_token_) { + // The retry request attempts to retrieve previously read data, possibly + // due to an exception in some other logic. + response->features = last_batch_; + CheckIdsEqual(request, response->features->GetColumnByName( + spec_.streaming_opts().id_name())); + return; + } + // cleanup cache + last_batch_ = nullptr; + last_context_token_ = kUnSetMagicToken; + + int64_t idx = 0; + std::vector> batches; + if (cur_batch_) { + std::shared_ptr slice_batch; + if (request.fs_param->query_datas_size() < + cur_batch_->num_rows() - cur_offset_) { + idx += request.fs_param->query_datas_size(); + slice_batch = + cur_batch_->Slice(cur_offset_, request.fs_param->query_datas_size()); + cur_offset_ += request.fs_param->query_datas_size(); + } else { + idx += cur_batch_->num_rows() - cur_offset_; + slice_batch = cur_batch_->Slice(cur_offset_); + cur_offset_ = 0; + cur_batch_ = nullptr; + } + batches.emplace_back(std::move(slice_batch)); + } + + std::shared_ptr batch; + while (request.fs_param->query_datas_size() > idx) { + SERVING_CHECK_ARROW_STATUS(streaming_reader_->ReadNext(&batch)); + SERVING_ENFORCE(batch, errors::LOGIC_ERROR); + SERVING_ENFORCE_GT( + batch->num_rows(), 0, + "may be because `block_size` is configured too small: {}", + spec_.streaming_opts().block_size()); + + if (request.fs_param->query_datas_size() - idx < batch->num_rows()) { + auto slice_batch = + batch->Slice(0, 
request.fs_param->query_datas_size() - idx); + cur_batch_ = batch; + cur_offset_ = slice_batch->num_rows(); + idx += slice_batch->num_rows(); + batches.emplace_back(std::move(slice_batch)); + break; + } else { + idx += batch->num_rows(); + cur_batch_ = nullptr; + cur_offset_ = 0; + batches.emplace_back(std::move(batch)); + } + } + + if (batches.size() > 1) { + std::shared_ptr table; + SERVING_GET_ARROW_RESULT(arrow::Table::FromRecordBatches(batches), table); + SERVING_GET_ARROW_RESULT(table->CombineChunksToBatch(), response->features); + } else { + response->features = std::move(batches[0]); + } + CheckIdsEqual(request, response->features->GetColumnByName( + spec_.streaming_opts().id_name())); + // cache current result. + last_batch_ = response->features; + last_context_token_ = request.fs_param->query_context(); +} + +REGISTER_ADAPTER(FeatureSourceConfig::OptionsCase::kStreamingOpts, + StreamingAdapter); + +} // namespace secretflow::serving::feature diff --git a/secretflow_serving/feature_adapter/streaming_adapter.h b/secretflow_serving/feature_adapter/streaming_adapter.h new file mode 100644 index 0000000..0a335fc --- /dev/null +++ b/secretflow_serving/feature_adapter/streaming_adapter.h @@ -0,0 +1,45 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "arrow/csv/api.h" + +#include "secretflow_serving/feature_adapter/feature_adapter.h" + +namespace secretflow::serving::feature { + +class StreamingAdapter : public FeatureAdapter { + public: + StreamingAdapter(const FeatureSourceConfig& spec, + const std::string& service_id, const std::string& party_id, + const std::shared_ptr& feature_schema); + + ~StreamingAdapter() override = default; + + protected: + void OnFetchFeature(const Request& request, Response* response) override; + + private: + std::shared_ptr streaming_reader_; + std::shared_ptr cur_batch_; + int64_t cur_offset_; + + std::shared_ptr last_batch_; + std::string last_context_token_; + + std::mutex mux_; +}; + +} // namespace secretflow::serving::feature diff --git a/secretflow_serving/feature_adapter/streaming_adapter_test.cc b/secretflow_serving/feature_adapter/streaming_adapter_test.cc new file mode 100644 index 0000000..b7627f1 --- /dev/null +++ b/secretflow_serving/feature_adapter/streaming_adapter_test.cc @@ -0,0 +1,289 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "secretflow_serving/feature_adapter/streaming_adapter.h" + +#include "arrow/api.h" +#include "arrow/ipc/api.h" +#include "butil/files/temp_file.h" +#include "gflags/gflags.h" +#include "gtest/gtest.h" + +#include "secretflow_serving/feature_adapter/feature_adapter_factory.h" +#include "secretflow_serving/util/arrow_helper.h" +#include "secretflow_serving/util/csv_util.h" + +#include "secretflow_serving/protos/feature.pb.h" + +DECLARE_bool(inferencer_mode); + +namespace secretflow::serving::feature { + +namespace { + +const std::string kTestModelServiceId = "test_service_id"; +const std::string kTestPartyId = "alice"; + +void CheckRecordBatchEqual(const std::shared_ptr& src, + const std::shared_ptr& dst) { + SERVING_ENFORCE_EQ(src->num_columns(), dst->num_columns()); + CheckReferenceFields(src->schema(), dst->schema()); + + for (int i = 0; i < dst->num_columns(); ++i) { + auto dst_array = dst->column(i); + auto src_array = src->GetColumnByName(dst->schema()->field_names()[i]); + SERVING_ENFORCE(src_array, errors::ErrorCode::UNEXPECTED_ERROR); + SERVING_ENFORCE(src_array->Equals(dst_array), + errors::ErrorCode::UNEXPECTED_ERROR); + } +} + +} // namespace + +class StreamingAdapterTest : public ::testing::Test { + protected: + void SetUp() override { FLAGS_inferencer_mode = true; } + void TearDown() override {} +}; + +TEST_F(StreamingAdapterTest, Works) { + butil::TempFile tmpfile; + tmpfile.save(1 + R"TEXT( +id,x1,x2,x3,x4 +0,0,1.0,alice,dummy +1,1,11.0,bob,dummy +2,2,21.0,carol,dummy +3,3,31.0,dave,dummy +)TEXT"); + + FeatureSourceConfig config; + auto* streaming_opts = config.mutable_streaming_opts(); + streaming_opts->set_file_path(tmpfile.fname()); + streaming_opts->set_id_name("id"); + + std::vector> fields = { + arrow::field("x1", arrow::int32()), arrow::field("x2", arrow::float32()), + arrow::field("x3", arrow::utf8())}; + auto model_schema = arrow::schema(fields); + + fields.emplace_back(arrow::field("id", arrow::utf8())); + auto 
model_schema_with_id = arrow::schema(fields); + + auto adapter = FeatureAdapterFactory::GetInstance()->Create( + config, kTestModelServiceId, kTestPartyId, model_schema); + + // fist request. + FeatureParam fs_params; + fs_params.add_query_datas("0"); + fs_params.set_query_context("test_contex_0"); + + FeatureAdapter::Request request; + request.fs_param = &fs_params; + FeatureAdapter::Response response; + + ASSERT_NO_THROW(adapter->FetchFeature(request, &response)); + ASSERT_TRUE(response.features); + ASSERT_EQ(1, response.features->num_rows()); + CheckReferenceFields(response.features->schema(), model_schema); + { + std::shared_ptr x1; + std::shared_ptr x2; + std::shared_ptr x3; + std::shared_ptr x4; + using arrow::ipc::internal::json::ArrayFromJSON; + SERVING_GET_ARROW_RESULT(ArrayFromJSON(arrow::int32(), "[0]"), x1); + SERVING_GET_ARROW_RESULT(ArrayFromJSON(arrow::float32(), "[1.0]"), x2); + SERVING_GET_ARROW_RESULT(ArrayFromJSON(arrow::utf8(), R"(["alice"])"), x3); + SERVING_GET_ARROW_RESULT(ArrayFromJSON(arrow::utf8(), R"(["0"])"), x4); + const std::shared_ptr expect_features = + MakeRecordBatch(model_schema_with_id, 1, {x1, x2, x3, x4}); + CheckRecordBatchEqual(expect_features, response.features); + } + + // second request + fs_params.clear_query_datas(); + fs_params.add_query_datas("1"); + fs_params.set_query_context("test_contex_1"); + ASSERT_NO_THROW(adapter->FetchFeature(request, &response)); + ASSERT_TRUE(response.features); + ASSERT_EQ(1, response.features->num_rows()); + CheckReferenceFields(response.features->schema(), model_schema); + { + std::shared_ptr x1; + std::shared_ptr x2; + std::shared_ptr x3; + std::shared_ptr x4; + using arrow::ipc::internal::json::ArrayFromJSON; + SERVING_GET_ARROW_RESULT(ArrayFromJSON(arrow::int32(), "[1]"), x1); + SERVING_GET_ARROW_RESULT(ArrayFromJSON(arrow::float32(), "[11.0]"), x2); + SERVING_GET_ARROW_RESULT(ArrayFromJSON(arrow::utf8(), R"(["bob"])"), x3); + SERVING_GET_ARROW_RESULT(ArrayFromJSON(arrow::utf8(), 
R"(["1"])"), x4); + const std::shared_ptr expect_features = + MakeRecordBatch(model_schema_with_id, 1, {x1, x2, x3, x4}); + CheckRecordBatchEqual(expect_features, response.features); + } + + // third request same as second request + fs_params.clear_query_datas(); + fs_params.add_query_datas("1"); + fs_params.set_query_context("test_contex_1"); + ASSERT_NO_THROW(adapter->FetchFeature(request, &response)); + ASSERT_TRUE(response.features); + ASSERT_EQ(1, response.features->num_rows()); + CheckReferenceFields(response.features->schema(), model_schema); + { + std::shared_ptr x1; + std::shared_ptr x2; + std::shared_ptr x3; + std::shared_ptr x4; + using arrow::ipc::internal::json::ArrayFromJSON; + SERVING_GET_ARROW_RESULT(ArrayFromJSON(arrow::int32(), "[1]"), x1); + SERVING_GET_ARROW_RESULT(ArrayFromJSON(arrow::float32(), "[11.0]"), x2); + SERVING_GET_ARROW_RESULT(ArrayFromJSON(arrow::utf8(), R"(["bob"])"), x3); + SERVING_GET_ARROW_RESULT(ArrayFromJSON(arrow::utf8(), R"(["1"])"), x4); + const std::shared_ptr expect_features = + MakeRecordBatch(model_schema_with_id, 1, {x1, x2, x3, x4}); + CheckRecordBatchEqual(expect_features, response.features); + } + + // exception request + fs_params.clear_query_datas(); + fs_params.add_query_datas("0"); + fs_params.set_query_context("test_contex_1"); + ASSERT_THROW(adapter->FetchFeature(request, &response), Exception); + + fs_params.clear_query_datas(); + fs_params.add_query_datas("3"); + fs_params.set_query_context("test_contex_2"); + ASSERT_THROW(adapter->FetchFeature(request, &response), Exception); +} + +struct StreamingModeParam { + std::string content; + int32_t adapter_block_size; + int32_t request_block_size; + std::shared_ptr schema; +}; + +class StreamingAdapterStreamingModeTest + : public ::testing::TestWithParam { + protected: + void SetUp() override { FLAGS_inferencer_mode = true; } + void TearDown() override {} +}; + +TEST_P(StreamingAdapterStreamingModeTest, Works) { + auto param = GetParam(); + + butil::TempFile 
tmpfile; + tmpfile.save(param.content.data()); + + FeatureSourceConfig config; + auto* streaming_opts = config.mutable_streaming_opts(); + streaming_opts->set_file_path(tmpfile.fname()); + streaming_opts->set_id_name("id"); + streaming_opts->set_block_size(param.adapter_block_size); + + auto adapter = FeatureAdapterFactory::GetInstance()->Create( + config, kTestModelServiceId, kTestPartyId, param.schema); + + auto req_reader_opts = arrow::csv::ReadOptions::Defaults(); + if (param.request_block_size > 0) { + req_reader_opts.block_size = param.request_block_size; + } + auto req_reader = csv::BuildStreamingReader( + tmpfile.fname(), {{"id", arrow::utf8()}}, req_reader_opts); + std::shared_ptr batch; + size_t idx = 0; + while (true) { + ++idx; + SERVING_CHECK_ARROW_STATUS(req_reader->ReadNext(&batch)); + if (batch == nullptr) { + break; + } + std::cout << "req batch length: " << batch->num_rows() << std::endl; + + auto id_array = + std::static_pointer_cast(batch->column(0)); + + FeatureParam fs_params; + // fs_params.set_query_context(std::to_string(idx)); + for (int64_t i = 0; i < id_array->length(); ++i) { + auto item = id_array->Value(i); + fs_params.add_query_datas(item.data(), item.length()); + } + FeatureAdapter::Request request; + request.fs_param = &fs_params; + FeatureAdapter::Response response; + ASSERT_NO_THROW(adapter->FetchFeature(request, &response)); + ASSERT_TRUE(response.features); + ASSERT_EQ(batch->num_rows(), response.features->num_rows()); + CheckReferenceFields(response.features->schema(), param.schema); + } +} + +INSTANTIATE_TEST_SUITE_P( + StreamingAdapterStreamingModeTestSuit, StreamingAdapterStreamingModeTest, + ::testing::Values( + StreamingModeParam{1 + R"TEXT( +id,x1,x2,x3,x4 +0,0,1.0,alice,dummy +1,1,11.0,bob,dummy +2,2,21.0,carol,dummy +3,3,31.0,dave,dummy +)TEXT", + 0, 0, + arrow::schema({arrow::field("x1", arrow::int32()), + arrow::field("x2", arrow::float32()), + arrow::field("x3", arrow::utf8())})}, + StreamingModeParam{1 + R"TEXT( 
+id,x1,x2,x3,x4 +0,0,1.0,alice,dummy +1,1,11.0,bob,dummy +2,2,21.0,carol,dummy +)TEXT", + 15 /*read one row once*/, 0 /*request all*/, + arrow::schema({arrow::field("x1", arrow::int32()), + arrow::field("x2", arrow::float32()), + arrow::field("x3", arrow::utf8())})}, + StreamingModeParam{1 + R"TEXT( +id,x1,x2,x3,x4 +0,0,1.0,alice,dummy +1,1,11.0,bob,dummy +2,2,21.0,carol,dummy +3,3,31.0,dave,dummy +3,3,31.0,dave,dummy +3,3,31.0,dave,dummy +)TEXT", + 30 /*read two row once*/, + 20 /*request one row once*/, + arrow::schema({arrow::field("x1", arrow::int32()), + arrow::field("x2", arrow::float32()), + arrow::field("x3", arrow::utf8())})}, + StreamingModeParam{ + 1 + R"TEXT( +id,x1,x2,x3,x4 +0,0,1.0,alice,dummy +1,1,11.0,bob,dummy +2,2,21.0,carol,dummy +3,3,31.0,dave,dummy +3,3,31.0,dave,dummy +3,3,31.0,dave,dummy +)TEXT", + 0 /*read all once*/, 20 /*request one row once*/, + arrow::schema({arrow::field("x1", arrow::int32()), + arrow::field("x2", arrow::float32()), + arrow::field("x3", arrow::utf8())})})); + +} // namespace secretflow::serving::feature diff --git a/secretflow_serving/framework/BUILD.bazel b/secretflow_serving/framework/BUILD.bazel index 75db827..19253d2 100644 --- a/secretflow_serving/framework/BUILD.bazel +++ b/secretflow_serving/framework/BUILD.bazel @@ -65,6 +65,7 @@ serving_cc_library( deps = [ "//secretflow_serving/apis:model_service_cc_proto", "//secretflow_serving/core:exception", + "//secretflow_serving/util:he_mgm", ], ) diff --git a/secretflow_serving/framework/executable.cc b/secretflow_serving/framework/executable.cc index 291805e..a67f270 100644 --- a/secretflow_serving/framework/executable.cc +++ b/secretflow_serving/framework/executable.cc @@ -27,11 +27,11 @@ void Executable::Run(Task& task) { SERVING_ENFORCE(task.id < executors_.size(), errors::ErrorCode::LOGIC_ERROR); auto executor = executors_[task.id]; if (task.features) { - task.outputs = executor.Run(task.features); + task.outputs = executor.Run(task.requester_id, task.features); } 
else { SERVING_ENFORCE(!task.prev_node_outputs.empty(), errors::ErrorCode::LOGIC_ERROR); - task.outputs = executor.Run(task.prev_node_outputs); + task.outputs = executor.Run(task.requester_id, task.prev_node_outputs); } SPDLOG_DEBUG("Executable::Run end, task.outputs.size:{}", diff --git a/secretflow_serving/framework/executable.h b/secretflow_serving/framework/executable.h index 85404e9..8d9271d 100644 --- a/secretflow_serving/framework/executable.h +++ b/secretflow_serving/framework/executable.h @@ -23,6 +23,8 @@ class Executable { struct Task { size_t id; + std::string requester_id; + // input // `features` or `prev_node_outputs` should be set std::shared_ptr features; diff --git a/secretflow_serving/framework/execute_context.cc b/secretflow_serving/framework/execute_context.cc index 03ee60a..f3f113f 100644 --- a/secretflow_serving/framework/execute_context.cc +++ b/secretflow_serving/framework/execute_context.cc @@ -109,8 +109,6 @@ RemoteExecute::RemoteExecute(ExecuteContext ctx, const std::string& target_id, ::google::protobuf::RpcChannel* channel) : ExecuteBase{std::move(ctx)}, channel_(channel) { SetTarget(target_id); - cntl_.set_max_retry( - RetryPolicyFactory::GetInstance()->GetMaxRetryCount(target_id)); span_option_.cntl = &cntl_; span_option_.is_client = true; span_option_.party_id = ctx.local_id; diff --git a/secretflow_serving/framework/executor.cc b/secretflow_serving/framework/executor.cc index 21811c1..50dd472 100644 --- a/secretflow_serving/framework/executor.cc +++ b/secretflow_serving/framework/executor.cc @@ -88,12 +88,16 @@ class ExecuteScheduler : public std::enable_shared_from_this { ExecuteScheduler( const std::shared_ptr>& node_items, - std::shared_ptr execution, ThreadPool* thread_pool) + std::shared_ptr execution, ThreadPool* thread_pool, + const std::string& self_party_id, + const std::vector& party_ids, + const std::string& requester_id) : node_items_(node_items), execution_(std::move(execution)), 
context_(execution_->GetOutputNodeNum()), thread_pool_(thread_pool), - propagator_(execution_->nodes()), + propagator_(execution_->nodes(), self_party_id, party_ids, + requester_id), sched_count_(0) {} void ExecuteNode(const std::string& node_name) { @@ -218,8 +222,12 @@ class ExecuteScheduler : public std::enable_shared_from_this { std::exception_ptr task_exception_; }; -Executor::Executor(const std::shared_ptr& execution) - : execution_(execution) { +Executor::Executor(const std::shared_ptr& execution, + const std::string& self_party_id, + const std::vector& party_ids) + : execution_(execution), + self_party_id_(self_party_id), + party_ids_(party_ids) { // create op_kernel node_items_ = std::make_shared>(); @@ -267,9 +275,11 @@ Executor::Executor(const std::shared_ptr& execution) } std::vector Executor::Run( + const std::string& requester_id, std::shared_ptr& features) { - auto sched = std::make_shared(node_items_, execution_, - ThreadPool::GetInstance()); + auto sched = std::make_shared( + node_items_, execution_, ThreadPool::GetInstance(), self_party_id_, + party_ids_, requester_id); sched->Schedule(features); auto task_exception = sched->GetTaskException(); @@ -282,12 +292,15 @@ std::vector Executor::Run( } std::vector Executor::Run( + const std::string& requester_id, std::unordered_map>>& prev_node_outputs) { SERVING_ENFORCE(!execution_->IsGraphEntry(), errors::ErrorCode::LOGIC_ERROR); - auto sched = std::make_shared(node_items_, execution_, - ThreadPool::GetInstance()); + auto sched = std::make_shared( + node_items_, execution_, ThreadPool::GetInstance(), self_party_id_, + party_ids_, requester_id); + sched->Schedule(prev_node_outputs); auto task_exception = sched->GetTaskException(); diff --git a/secretflow_serving/framework/executor.h b/secretflow_serving/framework/executor.h index f01e083..4f0f099 100644 --- a/secretflow_serving/framework/executor.h +++ b/secretflow_serving/framework/executor.h @@ -34,7 +34,9 @@ struct NodeItem { class Executor { public: 
- explicit Executor(const std::shared_ptr& execution); + explicit Executor(const std::shared_ptr& execution, + const std::string& self_party_id, + const std::vector& party_ids); ~Executor() = default; const std::shared_ptr& GetInputFeatureSchema() const { @@ -43,16 +45,23 @@ class Executor { // for middle execution std::vector Run( + const std::string& requester_id, std::unordered_map>>& prev_node_outputs); // for entry execution - std::vector Run(std::shared_ptr& features); + std::vector Run(const std::string& requester_id, + std::shared_ptr& features); private: std::shared_ptr execution_; + std::string self_party_id_; + std::vector party_ids_; + + std::vector entry_node_names_; + std::shared_ptr> node_items_; std::unordered_map(0, std::move(executino_def), std::move(nodes), true, true); - auto executor = std::make_shared(execution); + auto executor = std::make_shared(execution, "alice", + std::vector{"bob"}); // mock input std::shared_ptr inputs; @@ -293,7 +294,7 @@ TEST_F(ExecutorTest, MassiveWorks) { inputs = MakeRecordBatch(input_schema, 4, {array_0}); } // run - auto output = executor->Run(inputs); + auto output = executor->Run("test_requester_id", inputs); // build expect auto expect_output_schema = @@ -378,7 +379,8 @@ TEST_F(ExecutorTest, ComplexMassiveWorks) { auto execution = std::make_shared(0, std::move(executino_def), std::move(nodes), true, true); - auto executor = std::make_shared(execution); + auto executor = std::make_shared(execution, "alice", + std::vector{"bob"}); // mock input std::shared_ptr inputs; @@ -394,7 +396,7 @@ TEST_F(ExecutorTest, ComplexMassiveWorks) { } // run - auto output = executor->Run(inputs); + auto output = executor->Run("test_requester_id", inputs); // build expect auto expect_output_schema = @@ -482,7 +484,8 @@ TEST_F(ExecutorTest, FeatureInput) { auto execution = std::make_shared(0, std::move(executino_def), std::move(nodes), true, true); - auto executor = std::make_shared(execution); + auto executor = 
std::make_shared(execution, "alice", + std::vector{"bob"}); // mock input std::shared_ptr inputs; @@ -497,7 +500,7 @@ TEST_F(ExecutorTest, FeatureInput) { inputs = MakeRecordBatch(input_schema, 4, {array_0}); } // run - auto output = executor->Run(inputs); + auto output = executor->Run("test_requester_id", inputs); // build expect auto expect_output_schema = @@ -583,7 +586,8 @@ TEST_F(ExecutorTest, ExceptionWorks) { auto execution = std::make_shared(0, std::move(executino_def), std::move(nodes), true, true); - auto executor = std::make_shared(execution); + auto executor = std::make_shared(execution, "alice", + std::vector{"bob"}); // mock input std::shared_ptr inputs; @@ -598,7 +602,8 @@ TEST_F(ExecutorTest, ExceptionWorks) { inputs = MakeRecordBatch(input_schema, 4, {array_0}); } // run - EXPECT_THROW(executor->Run(inputs), ::secretflow::serving::Exception); + EXPECT_THROW(executor->Run("test_requester_id", inputs), + ::secretflow::serving::Exception); // expect EXPECT_EQ(ThreadPool::GetInstance()->GetTaskSize(), 0); @@ -673,7 +678,8 @@ TEST_F(ExecutorTest, ExceptionComplexMassiveWorks) { auto execution = std::make_shared(0, std::move(executino_def), std::move(nodes), true, true); - auto executor = std::make_shared(execution); + auto executor = std::make_shared(execution, "alice", + std::vector{"bob"}); // mock input std::shared_ptr inputs; @@ -688,7 +694,8 @@ TEST_F(ExecutorTest, ExceptionComplexMassiveWorks) { inputs = MakeRecordBatch(input_schema, 4, {array_0}); } // run - EXPECT_THROW(executor->Run(inputs), ::secretflow::serving::Exception); + EXPECT_THROW(executor->Run("test_requester_id", inputs), + ::secretflow::serving::Exception); // wait for thread pool to pop remain tasks executor.reset(); @@ -766,7 +773,8 @@ TEST_F(ExecutorTest, PrevNodeDataInput) { std::unordered_map>{ {"node_c", nodes["node_c"]}, {"node_d", nodes["node_d"]}}, false, true); - auto executor = std::make_shared(execution); + auto executor = std::make_shared(execution, "alice", + 
std::vector{"bob"}); // mock input auto input_schema = @@ -791,7 +799,7 @@ TEST_F(ExecutorTest, PrevNodeDataInput) { "node_b", std::vector>{output_b}); // run - auto output = executor->Run(prev_node_io); + auto output = executor->Run("test_requester_id", prev_node_io); // build expect auto expect_output_schema = diff --git a/secretflow_serving/framework/model_info_collector.cc b/secretflow_serving/framework/model_info_collector.cc index 503e5e0..a978ddc 100644 --- a/secretflow_serving/framework/model_info_collector.cc +++ b/secretflow_serving/framework/model_info_collector.cc @@ -15,58 +15,39 @@ #include "secretflow_serving/framework/model_info_collector.h" #include +#include #include #include "spdlog/spdlog.h" #include "secretflow_serving/core/exception.h" #include "secretflow_serving/server/trace/trace.h" +#include "secretflow_serving/util/he_mgm.h" #include "secretflow_serving/util/utils.h" #include "secretflow_serving/apis/model_service.pb.h" namespace secretflow::serving { -namespace { - -using std::invoke_result_t; - -class RetryRunner { - public: - RetryRunner(uint32_t retry_counts, uint32_t retry_interval_ms) - : retry_counts_(retry_counts), retry_interval_ms_(retry_interval_ms) {} - - template >>> - bool Run(Func&& f, Args&&... 
args) const { - auto runner_func = [&] { - return std::invoke(std::forward(f), std::forward(args)...); - }; - for (uint32_t i = 0; i != retry_counts_; ++i) { - if (!runner_func()) { - std::this_thread::sleep_for( - std::chrono::milliseconds(retry_interval_ms_)); - } else { - return true; - } - } - return false; +ModelInfoCollector::ModelInfoCollector(Options opts) : opts_(std::move(opts)) { + if (opts_.model_bundle->graph().party_num() > 0) { + SERVING_ENFORCE_EQ( + opts_.model_bundle->graph().party_num(), + static_cast(opts_.remote_channel_map->size()), + "serving party num mishmatch with graph party num, {} vs {}", + opts_.remote_channel_map->size(), + opts_.model_bundle->graph().party_num()); } - private: - uint32_t retry_counts_; - uint32_t retry_interval_ms_; -}; - -} // namespace - -ModelInfoCollector::ModelInfoCollector(Options opts) : opts_(std::move(opts)) { // build model_info_ model_info_.set_name(opts_.model_bundle->name()); model_info_.set_desc(opts_.model_bundle->desc()); + auto* graph_view = model_info_.mutable_graph_view(); graph_view->set_version(opts_.model_bundle->graph().version()); + graph_view->set_party_num(opts_.model_bundle->graph().party_num()); + graph_view->mutable_he_info()->set_pk_buf( + opts_.model_bundle->graph().he_config().pk_buf()); for (const auto& node : opts_.model_bundle->graph().node_list()) { NodeView view; view.set_name(node.name()); @@ -81,6 +62,14 @@ ModelInfoCollector::ModelInfoCollector(Options opts) : opts_(std::move(opts)) { SPDLOG_INFO("local model info: party: {} : {}", opts_.self_party_id, PbToJson(&model_info_)); + + if (opts_.model_bundle->graph().has_he_config() && + !opts_.model_bundle->graph().he_config().sk_buf().empty()) { + he::HeKitMgm::GetInstance()->InitLocalKit( + opts_.model_bundle->graph().he_config().pk_buf(), + opts_.model_bundle->graph().he_config().sk_buf(), + opts_.model_bundle->graph().he_config().encode_scale()); + } } void ModelInfoCollector::DoCollect() { @@ -105,6 +94,9 @@ bool 
ModelInfoCollector::TryCollect( const std::string& remote_party_id, const std::unique_ptr<::google::protobuf::RpcChannel>& channel) { brpc::Controller cntl; + // close brpc retry to make action controlled by us + cntl.set_max_retry(0); + apis::GetModelInfoResponse response; apis::GetModelInfoRequest request; request.mutable_service_spec()->set_id(opts_.service_id); diff --git a/secretflow_serving/framework/model_info_processor.cc b/secretflow_serving/framework/model_info_processor.cc index 3c7cd62..85b5270 100644 --- a/secretflow_serving/framework/model_info_processor.cc +++ b/secretflow_serving/framework/model_info_processor.cc @@ -20,6 +20,7 @@ #include "spdlog/spdlog.h" #include "secretflow_serving/core/exception.h" +#include "secretflow_serving/util/he_mgm.h" #include "secretflow_serving/apis/model_service.pb.h" @@ -60,7 +61,7 @@ void ModelInfoProcessor::CheckAndSetSpecificMap() { const auto& local_graph_view = local_model_info_->graph_view(); - for (auto& [remote_party_id, model_info] : *remote_model_info_) { + for (const auto& [remote_party_id, model_info] : *remote_model_info_) { SERVING_ENFORCE_EQ(model_info.name(), local_model_info_->name(), "model name mismatch with {}: {}, local: {}: {}", remote_party_id, model_info.name(), local_party_id_, @@ -78,6 +79,18 @@ void ModelInfoProcessor::CheckAndSetSpecificMap() { remote_party_id, graph_view.execution_list_size(), local_party_id_, local_graph_view.execution_list_size()); + if (model_info.graph_view().has_he_info() && + !model_info.graph_view().he_info().pk_buf().empty()) { + SERVING_ENFORCE_EQ( + model_info.graph_view().he_info().encode_scale(), + local_model_info_->graph_view().he_info().encode_scale(), + "he encode scale mismatch, {}: {}, local: {}", remote_party_id, + model_info.graph_view().he_info().encode_scale(), + local_model_info_->graph_view().he_info().encode_scale()); + he::HeKitMgm::GetInstance()->InitDstKit( + remote_party_id, model_info.graph_view().he_info().pk_buf()); + } + 
CheckNodeViewList(graph_view.node_list(), remote_party_id); for (int i = 0; i != local_graph_view.execution_list_size(); ++i) { diff --git a/secretflow_serving/framework/predictor.cc b/secretflow_serving/framework/predictor.cc index 7b71a08..c488161 100644 --- a/secretflow_serving/framework/predictor.cc +++ b/secretflow_serving/framework/predictor.cc @@ -25,13 +25,26 @@ namespace secretflow::serving { namespace { -void DealFinalResult(apis::NodeIo& node_io, apis::PredictResponse* response) { +void DealFinalResult(apis::NodeIo& node_io, const apis::PredictRequest* request, + const std::string& party_id, + apis::PredictResponse* response) { SERVING_ENFORCE(node_io.ios_size() == 1, errors::ErrorCode::LOGIC_ERROR); const auto& ios = node_io.ios(0); SERVING_ENFORCE(ios.datas_size() == 1, errors::ErrorCode::LOGIC_ERROR); + std::shared_ptr record_batch = DeserializeRecordBatch(ios.datas(0)); + int64_t sample_num = 0; + if (request->predefined_features_size() > 0) { + sample_num = CountSampleNum(request->predefined_features()); + } else { + sample_num = request->fs_params().find(party_id)->second.query_datas_size(); + } + SERVING_ENFORCE_EQ( + record_batch->num_rows(), sample_num, + "The number of calculated results does not match the number of requests"); + std::vector results(record_batch->num_rows()); for (int64_t i = 0; i != record_batch->num_rows(); ++i) { results[i] = response->add_results(); @@ -95,7 +108,8 @@ void Predictor::Predict(const apis::PredictRequest* request, exec->GetOutputs(&node_io_map); } - } else if (e->GetDispatchType() == DispatchType::DP_ANYONE) { + } else if (e->GetDispatchType() == DispatchType::DP_ANYONE || + e->GetDispatchType() == DispatchType::DP_SELF) { // exec locally if (execution_core_) { execute_locally(std::move(ctx), &node_io_map); @@ -118,6 +132,14 @@ void Predictor::Predict(const apis::PredictRequest* request, exec->WaitToFinish(); exec->GetOutputs(&node_io_map); } + } else if (e->GetDispatchType() == DispatchType::DP_PEER) { + 
SERVING_ENFORCE(opts_.channels->size() == 1, + errors::ErrorCode::UNEXPECTED_ERROR); + auto iter = opts_.channels->begin(); + auto exec = BuildRemoteExecute(std::move(ctx), iter->first, iter->second); + exec->Run(); + exec->WaitToFinish(); + exec->GetOutputs(&node_io_map); } else { SERVING_THROW(errors::ErrorCode::UNEXPECTED_ERROR, "unsupported dispatch type: {}", @@ -128,7 +150,7 @@ void Predictor::Predict(const apis::PredictRequest* request, auto iter = node_io_map.find(exit_node_name); SERVING_ENFORCE(iter != node_io_map.end(), errors::ErrorCode::UNEXPECTED_ERROR); - DealFinalResult(iter->second, response); + DealFinalResult(iter->second, request, opts_.party_id, response); } std::unique_ptr Predictor::BuildRemoteExecute( diff --git a/secretflow_serving/framework/propagator.cc b/secretflow_serving/framework/propagator.cc index f16b776..0e9cd06 100644 --- a/secretflow_serving/framework/propagator.cc +++ b/secretflow_serving/framework/propagator.cc @@ -17,11 +17,25 @@ namespace secretflow::serving { Propagator::Propagator( - const std::unordered_map>& nodes) { + const std::unordered_map>& nodes, + const std::string& self_party_id, const std::vector& party_ids, + const std::string& requester_id) { + std::set others_party_ids; + for (const auto& id : party_ids) { + if (id == self_party_id) { + continue; + } + others_party_ids.insert(id); + } + for (const auto& [node_name, node] : nodes) { auto frame = std::make_unique(); frame->pending_count = node->GetInputNum(); frame->compute_ctx.inputs.resize(frame->pending_count); + frame->compute_ctx.other_party_ids = others_party_ids; + frame->compute_ctx.self_id = self_party_id; + frame->compute_ctx.requester_id = requester_id; + frame->compute_ctx.he_kit_mgm = he::HeKitMgm::GetInstance(); SERVING_ENFORCE(node_frame_map_.emplace(node_name, std::move(frame)).second, errors::ErrorCode::LOGIC_ERROR); diff --git a/secretflow_serving/framework/propagator.h b/secretflow_serving/framework/propagator.h index 296f0ad..84c1f04 100644 
--- a/secretflow_serving/framework/propagator.h +++ b/secretflow_serving/framework/propagator.h @@ -28,7 +28,10 @@ struct FrameState { class Propagator { public: explicit Propagator( - const std::unordered_map>& nodes); + const std::unordered_map>& nodes, + const std::string& self_party_id, + const std::vector& party_ids, + const std::string& requester_id); FrameState* GetFrame(const std::string& node_name); diff --git a/secretflow_serving/ops/BUILD.bazel b/secretflow_serving/ops/BUILD.bazel index 79cf9c2..20988f8 100644 --- a/secretflow_serving/ops/BUILD.bazel +++ b/secretflow_serving/ops/BUILD.bazel @@ -25,6 +25,10 @@ serving_cc_library( ":arrow_processing", ":dot_product", ":merge_y", + "//secretflow_serving/ops/phe_linear:phe_2p_decrypt_peer_y", + "//secretflow_serving/ops/phe_linear:phe_2p_dot_product", + "//secretflow_serving/ops/phe_linear:phe_2p_merge_y", + "//secretflow_serving/ops/phe_linear:phe_2p_reduce", ], ) @@ -90,6 +94,7 @@ serving_cc_library( "//secretflow_serving/core:exception", "//secretflow_serving/protos:op_cc_proto", "//secretflow_serving/util:arrow_helper", + "//secretflow_serving/util:he_mgm", ], ) diff --git a/secretflow_serving/ops/graph_version.h b/secretflow_serving/ops/graph_version.h index 67d4b1f..84c54e0 100644 --- a/secretflow_serving/ops/graph_version.h +++ b/secretflow_serving/ops/graph_version.h @@ -16,7 +16,7 @@ // Version upgrade when `GraphDef` changed. 
#define SERVING_GRAPH_MAJOR_VERSION 0 -#define SERVING_GRAPH_MINOR_VERSION 1 +#define SERVING_GRAPH_MINOR_VERSION 2 #define SERVING_GRAPH_PATCH_VERSION 0 #define SERVING_STR_HELPER(x) #x diff --git a/secretflow_serving/ops/op_kernel.h b/secretflow_serving/ops/op_kernel.h index cd28830..491c7f6 100644 --- a/secretflow_serving/ops/op_kernel.h +++ b/secretflow_serving/ops/op_kernel.h @@ -24,6 +24,7 @@ #include "secretflow_serving/core/exception.h" #include "secretflow_serving/ops/node.h" #include "secretflow_serving/util/arrow_helper.h" +#include "secretflow_serving/util/he_mgm.h" #include "secretflow_serving/protos/op.pb.h" @@ -44,6 +45,12 @@ struct ComputeContext { // TODO: Session OpComputeInputs inputs; std::shared_ptr output; + + std::set other_party_ids; + std::string self_id; + std::string requester_id; + + he::HeKitMgm* he_kit_mgm = nullptr; }; class OpKernel { @@ -108,8 +115,6 @@ class OpKernel { DoCompute(ctx); - SERVING_ENFORCE_EQ(rows, ctx->output->num_rows(), - "rows of input and output be equal"); if (output_schema_->num_fields() > 0) { // only check when output schema is valid SERVING_ENFORCE( diff --git a/secretflow_serving/ops/phe_linear/BUILD.bazel b/secretflow_serving/ops/phe_linear/BUILD.bazel new file mode 100644 index 0000000..3194c71 --- /dev/null +++ b/secretflow_serving/ops/phe_linear/BUILD.bazel @@ -0,0 +1,122 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +load("//bazel:serving.bzl", "serving_cc_library", "serving_cc_test") + +package(default_visibility = ["//visibility:public"]) + +serving_cc_library( + name = "test_utils", + srcs = ["test_utils.cc"], + hdrs = ["test_utils.h"], + deps = [ + "//secretflow_serving/core:types", + "//secretflow_serving/util:he_mgm", + ], +) + +serving_cc_library( + name = "phe_2p_dot_product", + srcs = ["phe_2p_dot_product.cc"], + hdrs = ["phe_2p_dot_product.h"], + deps = [ + "//secretflow_serving/core:types", + "//secretflow_serving/ops:node_def_util", + "//secretflow_serving/ops:op_factory", + "//secretflow_serving/ops:op_kernel_factory", + "//secretflow_serving/util:he_mgm", + "//secretflow_serving/util:utils", + "@yacl//yacl/crypto/rand", + ], + alwayslink = True, +) + +serving_cc_test( + name = "phe_2p_dot_product_test", + srcs = ["phe_2p_dot_product_test.cc"], + deps = [ + ":phe_2p_dot_product", + ":test_utils", + "@yacl//yacl/utils:elapsed_timer", + ], +) + +serving_cc_library( + name = "phe_2p_decrypt_peer_y", + srcs = ["phe_2p_decrypt_peer_y.cc"], + hdrs = ["phe_2p_decrypt_peer_y.h"], + deps = [ + "//secretflow_serving/ops:node_def_util", + "//secretflow_serving/ops:op_factory", + "//secretflow_serving/ops:op_kernel_factory", + "//secretflow_serving/util:he_mgm", + "//secretflow_serving/util:utils", + ], + alwayslink = True, +) + +serving_cc_test( + name = "phe_2p_decrypt_peer_y_test", + srcs = ["phe_2p_decrypt_peer_y_test.cc"], + deps = [ + ":phe_2p_decrypt_peer_y", + ":test_utils", + ], +) + +serving_cc_library( + name = "phe_2p_merge_y", + srcs = ["phe_2p_merge_y.cc"], + hdrs = ["phe_2p_merge_y.h"], + deps = [ + "//secretflow_serving/core:link_func", + "//secretflow_serving/ops:node_def_util", + "//secretflow_serving/ops:op_factory", + "//secretflow_serving/ops:op_kernel_factory", + "//secretflow_serving/util:he_mgm", + "//secretflow_serving/util:utils", + ], + alwayslink = True, +) + +serving_cc_test( + name = "phe_2p_merge_y_test", + srcs = ["phe_2p_merge_y_test.cc"], + 
deps = [ + ":phe_2p_merge_y", + ":test_utils", + "@yacl//yacl/utils:elapsed_timer", + ], +) + +serving_cc_library( + name = "phe_2p_reduce", + srcs = ["phe_2p_reduce.cc"], + hdrs = ["phe_2p_reduce.h"], + deps = [ + "//secretflow_serving/ops:node_def_util", + "//secretflow_serving/ops:op_factory", + "//secretflow_serving/ops:op_kernel_factory", + ], + alwayslink = True, +) + +serving_cc_test( + name = "phe_2p_reduce_test", + srcs = ["phe_2p_reduce_test.cc"], + deps = [ + ":phe_2p_reduce", + ":test_utils", + ], +) diff --git a/secretflow_serving/ops/phe_linear/phe_2p_decrypt_peer_y.cc b/secretflow_serving/ops/phe_linear/phe_2p_decrypt_peer_y.cc new file mode 100644 index 0000000..ac4883f --- /dev/null +++ b/secretflow_serving/ops/phe_linear/phe_2p_decrypt_peer_y.cc @@ -0,0 +1,92 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "secretflow_serving/ops/phe_linear/phe_2p_decrypt_peer_y.h" + +#include "secretflow_serving/ops/node_def_util.h" +#include "secretflow_serving/ops/op_factory.h" +#include "secretflow_serving/ops/op_kernel_factory.h" +#include "secretflow_serving/util/arrow_helper.h" +#include "secretflow_serving/util/he_mgm.h" + +namespace secretflow::serving::op::phe_2p { + +PheDecryptPeerY::PheDecryptPeerY(OpKernelOptions opts) + : OpKernel(std::move(opts)) { + // feature name + partial_y_col_name_ = + GetNodeAttr(opts_.node_def, "partial_y_col_name"); + decrypted_col_name_ = + GetNodeAttr(opts_.node_def, "decrypted_col_name"); + + BuildInputSchema(); + BuildOutputSchema(); +} + +void PheDecryptPeerY::DoCompute(ComputeContext* ctx) { + SERVING_ENFORCE(ctx->inputs.size() == 1, errors::ErrorCode::LOGIC_ERROR); + SERVING_ENFORCE(ctx->inputs.front().size() == 1, + errors::ErrorCode::LOGIC_ERROR); + SERVING_ENFORCE(ctx->other_party_ids.size() == 1, + errors::ErrorCode::LOGIC_ERROR); + SERVING_ENFORCE(ctx->he_kit_mgm, errors::ErrorCode::LOGIC_ERROR); + + // get peer y + auto peer_y_array = ctx->inputs.front().front()->column(0); + auto peer_y_buf = + std::static_pointer_cast(peer_y_array)->Value(0); + auto peer_y_matrix = heu_matrix::CMatrix::LoadFrom(peer_y_buf); + + // decrypt + auto matrix_decryptor = ctx->he_kit_mgm->GetLocalMatrixDecryptor(); + auto p_y_matrix = matrix_decryptor->Decrypt(peer_y_matrix); + auto p_y_buf = p_y_matrix.Serialize(); + + std::shared_ptr ye_array; + arrow::BinaryBuilder ye_builder; + SERVING_CHECK_ARROW_STATUS( + ye_builder.Append(p_y_buf.data(), p_y_buf.size())); + SERVING_CHECK_ARROW_STATUS(ye_builder.Finish(&ye_array)); + + ctx->output = MakeRecordBatch(output_schema_, 1, {ye_array}); +} + +void PheDecryptPeerY::BuildInputSchema() { + // build input schema + input_schema_list_.emplace_back( + arrow::schema({arrow::field(partial_y_col_name_, arrow::binary())})); +} + +void PheDecryptPeerY::BuildOutputSchema() { + // build output schema + 
output_schema_ = + arrow::schema({arrow::field(decrypted_col_name_, arrow::binary())}); +} + +REGISTER_OP_KERNEL(PHE_2P_DECRYPT_PEER_Y, PheDecryptPeerY) +REGISTER_OP(PHE_2P_DECRYPT_PEER_Y, "0.0.1", + "Two-party computation operator. Decrypt the obfuscated partial_y " + "and add a random number.") + .StringAttr("partial_y_col_name", + "The name of the partial_y(which can be decrypt by self) " + "column in the input", + false, false) + .StringAttr("decrypted_col_name", + "The name of the decrypted result column in the output", false, + false) + .Input("crypted_data", "Input feature table") + .Output("decrypted_data", + "Decrypted partial_y with the added random number."); + +} // namespace secretflow::serving::op::phe_2p diff --git a/secretflow_serving/ops/phe_linear/phe_2p_decrypt_peer_y.h b/secretflow_serving/ops/phe_linear/phe_2p_decrypt_peer_y.h new file mode 100644 index 0000000..ed5778c --- /dev/null +++ b/secretflow_serving/ops/phe_linear/phe_2p_decrypt_peer_y.h @@ -0,0 +1,37 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "secretflow_serving/ops/op_kernel.h" + +namespace secretflow::serving::op::phe_2p { + +class PheDecryptPeerY : public OpKernel { + public: + explicit PheDecryptPeerY(OpKernelOptions opts); + + void DoCompute(ComputeContext* ctx) override; + + protected: + void BuildInputSchema() override; + + void BuildOutputSchema() override; + + private: + std::string partial_y_col_name_; + std::string decrypted_col_name_; +}; + +} // namespace secretflow::serving::op::phe_2p diff --git a/secretflow_serving/ops/phe_linear/phe_2p_decrypt_peer_y_test.cc b/secretflow_serving/ops/phe_linear/phe_2p_decrypt_peer_y_test.cc new file mode 100644 index 0000000..12805b7 --- /dev/null +++ b/secretflow_serving/ops/phe_linear/phe_2p_decrypt_peer_y_test.cc @@ -0,0 +1,145 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "secretflow_serving/ops/phe_linear/phe_2p_decrypt_peer_y.h" + +#include "arrow/ipc/api.h" +#include "gtest/gtest.h" + +#include "secretflow_serving/ops/op_factory.h" +#include "secretflow_serving/ops/op_kernel_factory.h" +#include "secretflow_serving/ops/phe_linear/test_utils.h" +#include "secretflow_serving/util/arrow_helper.h" +#include "secretflow_serving/util/utils.h" + +namespace secretflow::serving::op::phe_2p { + +class PheDecryptPeerYTest : public ::testing::Test { + protected: + void SetUp() override { + he_kit_mgm_ = std::make_unique(); + he_kit_mgm_->InitLocalKit(alice_kit_.GetPublicKey()->Serialize(), + alice_kit_.GetSecretKey()->Serialize(), 1e6); + he_kit_mgm_->InitDstKit("bob", bob_kit_.GetPublicKey()->Serialize()); + } + void TearDown() override {} + + protected: + std::unique_ptr he_kit_mgm_; + + heu_phe::HeKit alice_kit_ = + heu_phe::HeKit(heu_phe::SchemaType::ZPaillier, 2048); + heu_matrix::HeKit m_alice_kit_ = heu_matrix::HeKit(alice_kit_); + heu_phe::HeKit bob_kit_ = + heu_phe::HeKit(heu_phe::SchemaType::ZPaillier, 2048); + heu_matrix::HeKit m_bob_kit_ = heu_matrix::HeKit(bob_kit_); +}; + +TEST_F(PheDecryptPeerYTest, Works) { + std::string json_content = R"JSON( +{ + "name": "test_node", + "op": "PHE_2P_DECRYPT_PEER_Y", + "attr_values": { + "partial_y_col_name": { + "s": "partial_y", + }, + "decrypted_col_name": { + "s": "decrypted_y", + } + }, + "op_version": "0.0.1", +} +)JSON"; + NodeDef node_def; + JsonToPb(json_content, &node_def); + + auto compute_encoder = he_kit_mgm_->GetEncoder(he::kFeatureScale * + he_kit_mgm_->GetEncodeScale()); + auto base_encoder = he_kit_mgm_->GetEncoder(he_kit_mgm_->GetEncodeScale()); + + // build input&output schema + auto expect_input_schema = + arrow::schema({arrow::field("partial_y", arrow::binary())}); + auto expect_output_schema = + arrow::schema({arrow::field("decrypted_y", arrow::binary())}); + + // create kernel + auto mock_node = std::make_shared(std::move(node_def)); + 
ASSERT_EQ(mock_node->GetOpDef()->inputs_size(), 1); + ASSERT_FALSE(mock_node->GetOpDef()->tag().mergeable()); + OpKernelOptions opts{mock_node->node_def(), mock_node->GetOpDef()}; + auto kernel = OpKernelFactory::GetInstance()->Create(std::move(opts)); + + // check input schema + ASSERT_EQ(kernel->GetInputsNum(), mock_node->GetOpDef()->inputs_size()); + const auto& input_schema_list = kernel->GetAllInputSchema(); + ASSERT_EQ(input_schema_list.size(), kernel->GetInputsNum()); + for (size_t i = 0; i < input_schema_list.size(); ++i) { + const auto& input_schema = input_schema_list[i]; + ASSERT_TRUE(input_schema->Equals(expect_input_schema)); + } + // check output schema + auto output_schema = kernel->GetOutputSchema(); + std::cout << "real output schema: " << output_schema->ToString() << std::endl; + + ASSERT_TRUE(output_schema->Equals(expect_output_schema)); + + // build input + // generate reduce output + auto bob_y = test::GenRawMatrix(2, 1, 2); + auto p_bob_y = test::EncodeMatrix(bob_y, compute_encoder.get()); + auto c_bob_y = m_alice_kit_.GetEncryptor()->Encrypt(p_bob_y); + + ComputeContext compute_ctx; + compute_ctx.other_party_ids = {"bob"}; + compute_ctx.self_id = "alice"; + compute_ctx.he_kit_mgm = he_kit_mgm_.get(); + + // bob y + std::shared_ptr bob_c_y_array; + arrow::BinaryBuilder bob_c_y_builder; + auto c_bob_y_buf = c_bob_y.Serialize(); + SERVING_CHECK_ARROW_STATUS( + bob_c_y_builder.Append(c_bob_y_buf.data(), c_bob_y_buf.size())); + SERVING_CHECK_ARROW_STATUS(bob_c_y_builder.Finish(&bob_c_y_array)); + + compute_ctx.inputs.emplace_back( + std::vector>{ + MakeRecordBatch(expect_input_schema, 1, {bob_c_y_array})}); + + kernel->Compute(&compute_ctx); + + // check output schema + ASSERT_TRUE(compute_ctx.output); + ASSERT_TRUE(compute_ctx.output->schema()->Equals(output_schema)); + ASSERT_EQ(compute_ctx.output->num_rows(), 1); + + // check decrypted_ye col + auto expect_ye = bob_y; + auto decrypted_ye_col = compute_ctx.output->GetColumnByName("decrypted_y"); 
+ auto decrypted_ye_buf = + std::static_pointer_cast(decrypted_ye_col)->Value(0); + auto decrypted_ye_matrix = heu_matrix::PMatrix::LoadFrom(decrypted_ye_buf); + ASSERT_EQ(decrypted_ye_matrix.rows(), 2); + for (int i = 0; i < decrypted_ye_matrix.rows(); ++i) { + std::cout << i << " expect_ye: " << expect_ye(i, 0) << std::endl; + std::cout << i << " actual_ye: " << decrypted_ye_matrix(i, 0) << std::endl; + + ASSERT_EQ(compute_encoder->Decode(decrypted_ye_matrix(i, 0)), + expect_ye(i, 0)); + } +} + +} // namespace secretflow::serving::op::phe_2p diff --git a/secretflow_serving/ops/phe_linear/phe_2p_dot_product.cc b/secretflow_serving/ops/phe_linear/phe_2p_dot_product.cc new file mode 100644 index 0000000..357714b --- /dev/null +++ b/secretflow_serving/ops/phe_linear/phe_2p_dot_product.cc @@ -0,0 +1,314 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "secretflow_serving/ops/phe_linear/phe_2p_dot_product.h" + +#include + +#include "yacl/crypto/rand/rand.h" + +#include "secretflow_serving/ops/node_def_util.h" +#include "secretflow_serving/ops/op_factory.h" +#include "secretflow_serving/ops/op_kernel_factory.h" +#include "secretflow_serving/util/arrow_helper.h" + +#include "secretflow_serving/protos/data_type.pb.h" + +namespace secretflow::serving::op::phe_2p { + +namespace { + +heu_matrix::PMatrix TableToPMatrix( + const std::shared_ptr& table, + const heu_phe::PlainEncoder* encoder, const std::string& offset_col_name) { + int rows = table->num_rows(); + int cols = table->num_columns(); + + if (!offset_col_name.empty()) { + SERVING_ENFORCE_EQ(table->column_name(cols - 1), offset_col_name); + cols -= 1; + } + + heu_matrix::PMatrix plain_matrix(rows, cols); + for (int i = 0; i < cols; ++i) { + auto double_array = CastToDoubleArray(table->column(i)); + for (int64_t j = 0; j < double_array->length(); ++j) { + plain_matrix(j, i) = encoder->Encode(double_array->Value(j)); + } + } + + return plain_matrix; +} + +inline void BuildBinaryArray(const ::yacl::Buffer& buf, + std::shared_ptr* array) { + arrow::BinaryBuilder builder; + SERVING_CHECK_ARROW_STATUS(builder.Append(buf.data(), buf.size())); + SERVING_CHECK_ARROW_STATUS(builder.Finish(array)); +} + +} // namespace + +PheDotProduct::PheDotProduct(OpKernelOptions opts) + : OpKernel(std::move(opts)), c_w_matrix_(1, 1) { + // feature name + feature_name_list_ = GetNodeAttr>( + opts_.node_def, *opts_.op_def, "feature_names"); + std::set f_name_set; + for (const auto& feature_name : feature_name_list_) { + SERVING_ENFORCE(f_name_set.emplace(feature_name).second, + errors::ErrorCode::LOGIC_ERROR, + "found duplicate feature name:{}", feature_name); + } + // feature types + feature_type_list_ = GetNodeAttr>( + opts_.node_def, *opts_.op_def, "feature_types"); + SERVING_ENFORCE_EQ(feature_name_list_.size(), feature_type_list_.size(), + "attr:feature_names size={} 
does not match " + "attr:feature_types size={}, node:{}, op:{}", + feature_name_list_.size(), feature_type_list_.size(), + opts_.node_def.name(), opts_.node_def.op()); + // offset + offset_col_name_ = GetNodeAttr(opts_.node_def, *opts_.op_def, + "offset_col_name"); + if (!offset_col_name_.empty()) { + SERVING_ENFORCE( + !feature_name_list_.empty(), errors::ErrorCode::LOGIC_ERROR, + "attr:offset_col_name is set, but get empty attr:feature_names"); + SERVING_ENFORCE_EQ(feature_name_list_.back(), offset_col_name_, + "the offset column name must be placed at the end of " + "the feature list."); + } + + result_col_name_ = + GetNodeAttr(opts_.node_def, "result_col_name"); + rand_number_col_name_ = + GetNodeAttr(opts_.node_def, "rand_number_col_name"); + + auto intercept_bytes = GetNodeBytesAttr( + opts_.node_def, *opts_.op_def, "intercept_ciphertext"); + if (!intercept_bytes.empty()) { + has_intercept_ = true; + try { + c_intercept_.Deserialize(intercept_bytes); + } catch (const std::exception& e) { + SPDLOG_WARN( + "failed to load intercept ciphertext, reason: {}. 
now try load " + "intercept matrix " + "ciphertext.", + e.what()); + + heu_matrix::CMatrix c_intercept_matrix(1, 1); + try { + c_intercept_matrix = heu_matrix::CMatrix::LoadFrom(intercept_bytes); + } catch (const std::exception& ne) { + SPDLOG_WARN("failed to load intercept ciphertext matrix, reason: {}", + ne.what()); + SERVING_THROW( + errors::ErrorCode::UNEXPECTED_ERROR, + "failed to load intercept for node({}), reason [{}] or [{}]", + opts_.node_def.name(), e.what(), ne.what()); + } + SERVING_ENFORCE_EQ(c_intercept_matrix.size(), 1); + c_intercept_ = c_intercept_matrix(0, 0); + } + } + + if (feature_name_list_.empty()) { + no_feature_ = true; + } else if (feature_name_list_.size() == 1 && !offset_col_name_.empty()) { + no_feature_ = true; + } else { + auto encrypted_weight_bytes = GetNodeBytesAttr( + opts_.node_def, "feature_weights_ciphertext"); + c_w_matrix_ = heu_matrix::CMatrix::LoadFrom(encrypted_weight_bytes); + + int32_t compute_feature_num = offset_col_name_.empty() + ? feature_name_list_.size() + : feature_name_list_.size() - 1; + SERVING_ENFORCE_EQ(c_w_matrix_.ndim(), 1); + SERVING_ENFORCE_EQ(c_w_matrix_.size(), compute_feature_num, + "The shape of weights ciphertext matrix mismatch with " + "the compute feature number, {} vs {}", + c_w_matrix_.size(), compute_feature_num); + } + + BuildInputSchema(); + BuildOutputSchema(); +} + +void PheDotProduct::DoCompute(ComputeContext* ctx) { + SERVING_ENFORCE(ctx->inputs.size() == 1, errors::ErrorCode::LOGIC_ERROR); + SERVING_ENFORCE(ctx->inputs.front().size() == 1, + errors::ErrorCode::LOGIC_ERROR); + SERVING_ENFORCE(ctx->other_party_ids.size() == 1, + errors::ErrorCode::LOGIC_ERROR); + SERVING_ENFORCE(ctx->he_kit_mgm, errors::ErrorCode::LOGIC_ERROR); + + auto remote_party_id = *(ctx->other_party_ids.begin()); + + auto feature_encoder = ctx->he_kit_mgm->GetEncoder(he::kFeatureScale); + auto compute_encoder = ctx->he_kit_mgm->GetEncoder( + he::kFeatureScale * ctx->he_kit_mgm->GetEncodeScale()); + + // 1. 
gen rand num E + auto rand_num = yacl::crypto::FastRandU64(); + heu_phe::Plaintext p_e(ctx->he_kit_mgm->GetSchemaType(), rand_num); + + // 2.1. W * X + offset + const auto& dst_m_evaluator = + ctx->he_kit_mgm->GetDstMatrixEvaluator(remote_party_id); + const heu_phe::Ciphertext zero_c = + ctx->he_kit_mgm->GetDstEncryptor(remote_party_id)->EncryptZero(); + heu_matrix::CMatrix c_wx_matrix(ctx->inputs.front().front()->num_rows(), 1); + if (no_feature_) { + for (int32_t i = 0; i < c_wx_matrix.rows(); ++i) { + c_wx_matrix(i, 0) = zero_c; + } + } else { + auto p_x_matrix = TableToPMatrix(ctx->inputs.front().front(), + feature_encoder.get(), offset_col_name_); + c_wx_matrix = dst_m_evaluator->MatMul(p_x_matrix, c_w_matrix_); + } + // offset + if (!offset_col_name_.empty()) { + auto offset_array = + ctx->inputs.front().front()->GetColumnByName(offset_col_name_); + auto double_array = CastToDoubleArray(offset_array); + heu_matrix::PMatrix p_offset_matrix(double_array->length(), 1); + for (int64_t i = 0; i < double_array->length(); ++i) { + p_offset_matrix(i, 0) = compute_encoder->Encode(double_array->Value(i)); + } + c_wx_matrix = dst_m_evaluator->Add(c_wx_matrix, p_offset_matrix); + } + + // 2.2 W * X + offset - E + heu_matrix::PMatrix p_e_matrix(c_wx_matrix.shape()); + for (int i = 0; i < p_e_matrix.rows(); ++i) { + for (int j = 0; j < p_e_matrix.cols(); ++j) { + p_e_matrix(i, j) = p_e; + } + } + auto c_wxe_matrix = dst_m_evaluator->Sub(c_wx_matrix, p_e_matrix); + + // 2.3 W * X + offset - E + I + auto result_matrix = c_wxe_matrix; + if (has_intercept_) { + // intercept ciphertext * kFeatureScale + const auto& evaluator = ctx->he_kit_mgm->GetDstEvaluator(remote_party_id); + auto c_i = evaluator->Mul(c_intercept_, feature_encoder->Encode(1)); + + heu_matrix::CMatrix c_i_matrix(c_wxe_matrix.shape()); + for (int i = 0; i < c_i_matrix.rows(); ++i) { + for (int j = 0; j < c_i_matrix.cols(); ++j) { + c_i_matrix(i, j) = c_i; + } + } + result_matrix = 
dst_m_evaluator->Add(c_wxe_matrix, c_i_matrix); + } + + // build result array + auto result_buf = result_matrix.Serialize(); + std::shared_ptr wx_array; + BuildBinaryArray(result_buf, &wx_array); + + // build rand num array + std::shared_ptr e_array; + if (ctx->requester_id == ctx->self_id) { + // no need encrypt + auto p_e_buf = p_e.Serialize(); + BuildBinaryArray(p_e_buf, &e_array); + } else { + auto c_e = ctx->he_kit_mgm->GetLocalEncryptor()->Encrypt(p_e); + auto c_e_buf = c_e.Serialize(); + BuildBinaryArray(c_e_buf, &e_array); + } + + // build party id array + std::shared_ptr p_array; + arrow::StringBuilder p_builder; + SERVING_CHECK_ARROW_STATUS(p_builder.Append(ctx->self_id)); + SERVING_CHECK_ARROW_STATUS(p_builder.Finish(&p_array)); + + ctx->output = + MakeRecordBatch(output_schema_, 1, {wx_array, e_array, p_array}); +} + +void PheDotProduct::BuildInputSchema() { + // build input schema + std::vector> fields; + for (size_t i = 0; i < feature_name_list_.size(); ++i) { + auto data_type = DataTypeToArrowDataType(feature_type_list_[i]); + SERVING_ENFORCE( + arrow::is_numeric(data_type->id()), errors::INVALID_ARGUMENT, + "feature type must be numeric, get:{}", feature_type_list_[i]); + fields.emplace_back(arrow::field(feature_name_list_[i], data_type)); + } + + if (!offset_col_name_.empty()) { + SERVING_ENFORCE_EQ( + fields.rbegin()->get()->name(), offset_col_name_, + "offset column({}) must be the last column of the input schema.", + offset_col_name_); + } + + input_schema_list_.emplace_back(arrow::schema(std::move(fields))); +} + +void PheDotProduct::BuildOutputSchema() { + // build output schema + output_schema_ = + arrow::schema({arrow::field(result_col_name_, arrow::binary()), + arrow::field(rand_number_col_name_, arrow::binary()), + arrow::field("party", arrow::utf8())}); +} + +REGISTER_OP_KERNEL(PHE_2P_DOT_PRODUCT, PheDotProduct) +REGISTER_OP( + PHE_2P_DOT_PRODUCT, "0.0.1", + "Two-party computation operator. 
Load the encrypted feature weights, " + "compute their dot product with the " + "feature values, and add random noise to the result for obfuscation. Only " + "supports computation between two parties, with the weights being " + "encrypted using the other party's key.") + .StringAttr("feature_names", + "List of feature names. Note that if there is an offset " + "column, it needs to be the last one in the list", + true, true, std::vector()) + .BytesAttr("feature_weights_ciphertext", + "feature weight ciphertext matrix bytes", false, true, + std::string()) + .StringAttr("feature_types", + "List of input feature data types. Optional " + "value: DT_UINT8, " + "DT_INT8, DT_UINT16, DT_INT16, DT_UINT32, DT_INT32, DT_UINT64, " + "DT_INT64, DT_FLOAT, DT_DOUBLE", + true, true, std::vector()) + .BytesAttr("intercept_ciphertext", + "Intercept ciphertext bytes or matrix bytes", false, true, + std::string()) + .StringAttr("offset_col_name", + "The name of the offset column(feature) in the input", false, + true, std::string()) + .StringAttr( + "result_col_name", + "The name of the calculation result(partial_y) column in the output", + false, false) + .StringAttr("rand_number_col_name", + "The name of the generated rand number column in the output", + false, false) + .Input("features", "Input features") + .Output("partial_y", "Calculation results"); + +} // namespace secretflow::serving::op::phe_2p diff --git a/secretflow_serving/ops/phe_linear/phe_2p_dot_product.h b/secretflow_serving/ops/phe_linear/phe_2p_dot_product.h new file mode 100644 index 0000000..a5abf4b --- /dev/null +++ b/secretflow_serving/ops/phe_linear/phe_2p_dot_product.h @@ -0,0 +1,48 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "secretflow_serving/ops/op_kernel.h" +#include "secretflow_serving/util/he_mgm.h" + +namespace secretflow::serving::op::phe_2p { + +class PheDotProduct : public OpKernel { + public: + explicit PheDotProduct(OpKernelOptions opts); + + void DoCompute(ComputeContext* ctx) override; + + protected: + void BuildInputSchema() override; + + void BuildOutputSchema() override; + + private: + std::vector feature_name_list_; + std::vector feature_type_list_; + + heu_matrix::CMatrix c_w_matrix_; + heu_phe::Ciphertext c_intercept_; + + std::string offset_col_name_; + std::string result_col_name_; + std::string rand_number_col_name_; + + bool no_feature_ = false; + bool has_intercept_ = false; +}; + +} // namespace secretflow::serving::op::phe_2p diff --git a/secretflow_serving/ops/phe_linear/phe_2p_dot_product_test.cc b/secretflow_serving/ops/phe_linear/phe_2p_dot_product_test.cc new file mode 100644 index 0000000..007c007 --- /dev/null +++ b/secretflow_serving/ops/phe_linear/phe_2p_dot_product_test.cc @@ -0,0 +1,291 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "secretflow_serving/ops/phe_linear/phe_2p_dot_product.h" + +#include "gtest/gtest.h" +#include "yacl/utils/elapsed_timer.h" + +#include "secretflow_serving/ops/op_factory.h" +#include "secretflow_serving/ops/op_kernel_factory.h" +#include "secretflow_serving/ops/phe_linear/test_utils.h" +#include "secretflow_serving/util/arrow_helper.h" +#include "secretflow_serving/util/utils.h" + +namespace secretflow::serving::op::phe_2p { + +struct Param { + size_t feature_num; + bool has_offset; + bool has_intercept; + bool self_request; +}; + +class PheDotProductTest : public ::testing::TestWithParam { + protected: + void SetUp() override { + he_kit_mgm_ = std::make_unique(); + he_kit_mgm_->InitLocalKit(kit_.GetPublicKey()->Serialize(), + kit_.GetSecretKey()->Serialize(), 1e6); + he_kit_mgm_->InitDstKit("bob", remote_kit_.GetPublicKey()->Serialize()); + } + void TearDown() override {} + + protected: + std::unique_ptr he_kit_mgm_; + + heu_phe::HeKit kit_ = heu_phe::HeKit(heu_phe::SchemaType::ZPaillier, 2048); + heu_matrix::HeKit m_kit_ = heu_matrix::HeKit(kit_); + heu_phe::HeKit remote_kit_ = + heu_phe::HeKit(heu_phe::SchemaType::ZPaillier, 2048); + heu_matrix::HeKit m_remote_kit_ = heu_matrix::HeKit(remote_kit_); +}; + +TEST_P(PheDotProductTest, Works) { + auto param = GetParam(); + size_t row_num = 2; + + std::string json_content = R"JSON( +{ + "name": "test_node", + "op": "PHE_2P_DOT_PRODUCT", + "attr_values": { + "result_col_name": { + "s": "wxe", + }, + "rand_number_col_name": { + "s": "rand", + } + }, + "op_version": "0.0.1", +} +)JSON"; + NodeDef node_def; + JsonToPb(json_content, &node_def); + std::vector feature_names; + for (size_t i = 0; i < param.feature_num; ++i) { + feature_names.emplace_back("f_" + std::to_string(i)); + } + std::vector feature_types(feature_names.size(), "DT_DOUBLE"); + auto compute_feature_num = feature_names.size(); + + if 
(!feature_names.empty()) { + AttrValue feature_names_value; + feature_names_value.mutable_ss()->mutable_data()->Assign( + feature_names.begin(), feature_names.end()); + node_def.mutable_attr_values()->insert( + {"feature_names", feature_names_value}); + + AttrValue feature_types_value; + feature_types_value.mutable_ss()->mutable_data()->Assign( + feature_types.begin(), feature_types.end()); + node_def.mutable_attr_values()->insert( + {"feature_types", feature_types_value}); + } + if (param.has_offset) { + AttrValue offset_col_name_value; + offset_col_name_value.set_s(feature_names.back()); + node_def.mutable_attr_values()->insert( + {"offset_col_name", offset_col_name_value}); + compute_feature_num--; + } + + auto compute_encoder = he_kit_mgm_->GetEncoder(he::kFeatureScale * + he_kit_mgm_->GetEncodeScale()); + auto base_encoder = he_kit_mgm_->GetEncoder(he_kit_mgm_->GetEncodeScale()); + + // generate weight ciphertext + Double::Matrix weight_matrix; + if (compute_feature_num > 0) { + weight_matrix = test::GenRawMatrix(compute_feature_num, 1, 1); + AttrValue feature_weights_ciphertext; + { + auto p_w_m = test::EncodeMatrix(weight_matrix, base_encoder.get()); + auto c_w_m = m_remote_kit_.GetEncryptor()->Encrypt(p_w_m); + auto c_w_buf = c_w_m.Serialize(); + feature_weights_ciphertext.set_by( + std::string(c_w_buf.data(), c_w_buf.size())); + } + node_def.mutable_attr_values()->insert( + {"feature_weights_ciphertext", feature_weights_ciphertext}); + } + + // generate intercept + double intercept = 0; + if (param.has_intercept) { + intercept = 1.5; + AttrValue intercept_ciphertext; + { + auto p_i = base_encoder->Encode(intercept); + auto c_i = remote_kit_.GetEncryptor()->Encrypt(p_i); + auto c_i_buf = c_i.Serialize(); + intercept_ciphertext.set_by( + std::string(c_i_buf.data(), c_i_buf.size())); + } + node_def.mutable_attr_values()->insert( + {"intercept_ciphertext", intercept_ciphertext}); + } + + // build expect schema + std::vector> input_fields; + for (const auto& n 
: feature_names) { + input_fields.emplace_back(arrow::field(n, arrow::float64())); + } + auto expect_input_schema = arrow::schema(input_fields); + auto expect_output_schema = + arrow::schema({arrow::field("wxe", arrow::binary()), + arrow::field("rand", arrow::binary()), + arrow::field("party", arrow::utf8())}); + + auto mock_node = std::make_shared(std::move(node_def)); + ASSERT_EQ(mock_node->GetOpDef()->inputs_size(), 1); + + OpKernelOptions opts{mock_node->node_def(), mock_node->GetOpDef()}; + auto kernel = OpKernelFactory::GetInstance()->Create(std::move(opts)); + + // check input schema + ASSERT_EQ(kernel->GetInputsNum(), mock_node->GetOpDef()->inputs_size()); + const auto& input_schema_list = kernel->GetAllInputSchema(); + ASSERT_EQ(input_schema_list.size(), kernel->GetInputsNum()); + for (size_t i = 0; i < input_schema_list.size(); ++i) { + const auto& input_schema = input_schema_list[i]; + ASSERT_TRUE(input_schema->Equals(expect_input_schema)); + } + // check output schema + auto output_schema = kernel->GetOutputSchema(); + ASSERT_TRUE(output_schema->Equals(expect_output_schema)); + + // build input + ComputeContext compute_ctx; + compute_ctx.other_party_ids = {"bob"}; + compute_ctx.self_id = "alice"; + if (param.self_request) { + compute_ctx.requester_id = "alice"; + } else { + compute_ctx.requester_id = "bob"; + } + compute_ctx.he_kit_mgm = he_kit_mgm_.get(); + Double::Matrix input_m; + if (compute_feature_num > 0 || param.has_offset) { + input_m.resize(row_num, feature_names.size()); + std::vector> arrays; + for (int j = 0; j < input_m.cols(); ++j) { + std::shared_ptr array; + arrow::DoubleBuilder builder; + for (int i = 0; i < input_m.rows(); ++i) { + input_m(i, j) = i + 1 + (j + 1) * 0.1; + SERVING_CHECK_ARROW_STATUS(builder.Append(input_m(i, j))); + } + SERVING_CHECK_ARROW_STATUS(builder.Finish(&array)); + arrays.emplace_back(array); + } + compute_ctx.inputs.emplace_back( + std::vector>{MakeRecordBatch( + arrow::schema(input_fields), input_m.rows(), 
arrays)}); + } else { + // no feature, no offset, mock input + std::shared_ptr array; + arrow::DoubleBuilder builder; + for (size_t i = 0; i < row_num; ++i) { + SERVING_CHECK_ARROW_STATUS(builder.Append(i)); + } + SERVING_CHECK_ARROW_STATUS(builder.Finish(&array)); + compute_ctx.inputs.emplace_back( + std::vector>{MakeRecordBatch( + arrow::schema({arrow::field("mock", arrow::float64())}), row_num, + {array})}); + } + + yacl::ElapsedTimer timer; + + kernel->Compute(&compute_ctx); + + std::cout << "compute time: " << timer.CountMs() << "\n"; + + // check output + ASSERT_TRUE(compute_ctx.output); + ASSERT_TRUE(compute_ctx.output->schema()->Equals(output_schema)); + ASSERT_EQ(compute_ctx.output->num_rows(), 1); + + auto rand_array = compute_ctx.output->GetColumnByName("rand"); + heu_phe::Plaintext p_e; + if (param.self_request) { + p_e.Deserialize( + std::static_pointer_cast(rand_array)->Value(0)); + } else { + heu_phe::Ciphertext c_e; + c_e.Deserialize( + std::static_pointer_cast(rand_array)->Value(0)); + p_e = kit_.GetDecryptor()->Decrypt(c_e); + } + + std::cout << "plain rand: " << p_e << std::endl; + + // expect result + heu_phe::Plaintext p_one(kit_.GetSchemaType(), 1); + heu_phe::Plaintext p_minus_one(kit_.GetSchemaType(), -1); + auto wxe_array = compute_ctx.output->GetColumnByName("wxe"); + auto c_real_wxe_matrix = heu_matrix::CMatrix::LoadFrom( + std::static_pointer_cast(wxe_array)->Value(0)); + auto p_real_wxe_matrix = + m_remote_kit_.GetDecryptor()->Decrypt(c_real_wxe_matrix); + if (compute_feature_num > 0) { + Double::Matrix weight_with_offset_matrix; + if (param.has_offset) { + weight_with_offset_matrix.resize(weight_matrix.rows() + 1, + weight_matrix.cols()); + weight_with_offset_matrix << weight_matrix, + Double::RowVec::Constant(1, 1, 1.0); + } else { + weight_with_offset_matrix = weight_matrix; + } + Double::ColVec expect_score = input_m * weight_with_offset_matrix; + expect_score.array() += intercept; + auto p_expect_score_matrix = + 
test::EncodeMatrix(expect_score, compute_encoder.get()); + + ASSERT_EQ(p_real_wxe_matrix.rows(), expect_score.rows()); + ASSERT_EQ(p_real_wxe_matrix.cols(), expect_score.cols()); + for (int i = 0; i < p_real_wxe_matrix.rows(); ++i) { + for (int j = 0; j < p_real_wxe_matrix.cols(); ++j) { + auto p_expect = + remote_kit_.GetEvaluator()->Sub(p_expect_score_matrix(i, j), p_e); + ASSERT_TRUE((p_real_wxe_matrix(i, j) - p_expect) <= p_one && + (p_real_wxe_matrix(i, j) - p_expect) >= p_minus_one); + } + } + } else { + for (size_t i = 0; i < row_num; ++i) { + double expect_score = intercept; + if (param.has_offset) { + expect_score += input_m(i, 0); + } + auto p_expect_score = compute_encoder->Encode(expect_score); + p_expect_score = remote_kit_.GetEvaluator()->Sub(p_expect_score, p_e); + ASSERT_TRUE((p_real_wxe_matrix(i, 0) - p_expect_score) <= p_one && + (p_real_wxe_matrix(i, 0) - p_expect_score) >= p_minus_one); + } + } +} + +INSTANTIATE_TEST_SUITE_P(PheDotProductTestSuite, PheDotProductTest, + ::testing::Values(Param{4, true, true, false}, + Param{4, false, true, false}, + Param{4, true, false, false}, + Param{4, false, false, false}, + Param{4, true, true, true}, + Param{1, true, true, false}, + Param{0, false, true, false})); + +} // namespace secretflow::serving::op::phe_2p diff --git a/secretflow_serving/ops/phe_linear/phe_2p_merge_y.cc b/secretflow_serving/ops/phe_linear/phe_2p_merge_y.cc new file mode 100644 index 0000000..6de90da --- /dev/null +++ b/secretflow_serving/ops/phe_linear/phe_2p_merge_y.cc @@ -0,0 +1,153 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "secretflow_serving/ops/phe_linear/phe_2p_merge_y.h" + +#include "secretflow_serving/core/link_func.h" +#include "secretflow_serving/ops/node_def_util.h" +#include "secretflow_serving/ops/op_factory.h" +#include "secretflow_serving/ops/op_kernel_factory.h" +#include "secretflow_serving/util/arrow_helper.h" +#include "secretflow_serving/util/he_mgm.h" + +namespace secretflow::serving::op::phe_2p { + +// Reads the column names, the link function and the optional yhat_scale / +// exp_iters attributes from the node definition, hands the link function +// arguments to CheckLinkFuncAragsValid, then builds both schemas. +PheMergeY::PheMergeY(OpKernelOptions opts) : OpKernel(std::move(opts)) { + decrypted_y_col_name_ = + GetNodeAttr(opts_.node_def, "decrypted_y_col_name"); + crypted_y_col_name_ = + GetNodeAttr(opts_.node_def, "crypted_y_col_name"); + score_col_name_ = GetNodeAttr(opts_.node_def, "score_col_name"); + + auto link_function_name = + GetNodeAttr(opts_.node_def, "link_function"); + link_function_ = ParseLinkFuncType(link_function_name); + + // optional attr + yhat_scale_ = + GetNodeAttr(opts_.node_def, *opts_.op_def, "yhat_scale"); + exp_iters_ = GetNodeAttr(opts_.node_def, *opts_.op_def, "exp_iters"); + CheckLinkFuncAragsValid(link_function_, exp_iters_); + + BuildInputSchema(); + BuildOutputSchema(); +} + +// Merges the two partial results into the final score column. Input 0 holds +// the self plaintext matrix and input 1 the peer ciphertext matrix; the +// ciphertext is decrypted with the local matrix decryptor, both matrices are +// added, and each row is decoded, passed through the link function and +// multiplied by yhat_scale_. +void PheMergeY::DoCompute(ComputeContext* ctx) { + // need two single-batch inputs, one peer party and an HE kit manager + SERVING_ENFORCE(ctx->inputs.size() == 2, errors::ErrorCode::LOGIC_ERROR); + SERVING_ENFORCE(ctx->inputs[0].size() == 1, errors::ErrorCode::LOGIC_ERROR); + SERVING_ENFORCE(ctx->inputs[1].size() == 1, errors::ErrorCode::LOGIC_ERROR); + SERVING_ENFORCE(ctx->other_party_ids.size() == 1, + errors::ErrorCode::LOGIC_ERROR); + SERVING_ENFORCE(ctx->he_kit_mgm, errors::ErrorCode::LOGIC_ERROR); + + const
auto& decrypted_data = ctx->inputs[0].front(); + const auto& crypted_data = ctx->inputs[1].front(); + + // load the peer ciphertext matrix from row 0 of the crypted input + auto peer_ye_array = crypted_data->column(0); + auto peer_ye_buf = + std::static_pointer_cast(peer_ye_array)->Value(0); + auto c_peer_ye_matrix = heu_matrix::CMatrix::LoadFrom(peer_ye_buf); + + // decrypt peer y with the local matrix decryptor + auto matrix_decryptor = ctx->he_kit_mgm->GetLocalMatrixDecryptor(); + auto p_peer_ye_matrix = matrix_decryptor->Decrypt(c_peer_ye_matrix); + + // self y: plaintext matrix stored in row 0 of the decrypted input + auto self_ye_array = decrypted_data->column(0); + auto self_ye_buf = + std::static_pointer_cast(self_ye_array)->Value(0); + auto p_self_ye_matrix = heu_matrix::PMatrix::LoadFrom(self_ye_buf); + + // self_ye + peer_ye + auto matrix_evaluator = ctx->he_kit_mgm->GetLocalMatrixEvaluator(); + auto p_score_matrix = + matrix_evaluator->Add(p_self_ye_matrix, p_peer_ye_matrix); + + // encoder built from he::kFeatureScale * GetEncodeScale(), used for decoding + auto compute_encoder = ctx->he_kit_mgm->GetEncoder( + he::kFeatureScale * ctx->he_kit_mgm->GetEncodeScale()); + + // decode column 0 of every row, apply the link function, then rescale + std::shared_ptr score_array; + arrow::DoubleBuilder score_builder; + for (int i = 0; i < p_score_matrix.rows(); ++i) { + auto score = compute_encoder->Decode(p_score_matrix(i, 0)); + score = ApplyLinkFunc(score, link_function_, exp_iters_) * yhat_scale_; + SERVING_CHECK_ARROW_STATUS(score_builder.Append(score)); + } + SERVING_CHECK_ARROW_STATUS(score_builder.Finish(&score_array)); + + ctx->output = + MakeRecordBatch(output_schema_, score_array->length(), {score_array}); +} + +void PheMergeY::BuildInputSchema() { + // build input schema: one binary column per input (decrypted, then crypted) + input_schema_list_.emplace_back( + arrow::schema({arrow::field(decrypted_y_col_name_, arrow::binary())})); + input_schema_list_.emplace_back( + arrow::schema({arrow::field(crypted_y_col_name_, arrow::binary())})); +} + +void PheMergeY::BuildOutputSchema() { + // build output schema: a single float64 score column + output_schema_ = + arrow::schema({arrow::field(score_col_name_, arrow::float64())}); +} + +REGISTER_OP_KERNEL(PHE_2P_MERGE_Y, PheMergeY) +REGISTER_OP( + PHE_2P_MERGE_Y, "0.0.1", +
"Two-party computation operator. Merge the obfuscated partial_y decrypted " + "by the peer party with the " + "partial_y based on self own key to obtain the final prediction score.") + .Returnable() + .StringAttr("decrypted_y_col_name", + "The name of the decrypted partial_y column in the first input", + false, false) + .StringAttr("crypted_y_col_name", + "The name of the crypted partial_y column in the second input", + false, false) + .StringAttr("score_col_name", "The name of the score column in the output", + false, false) + .DoubleAttr( + "yhat_scale", + "In order to prevent value overflow, GLM training is performed on the " + "scaled y label. So in the prediction process, you need to enlarge " + "yhat back to get the real predicted value, `yhat = yhat_scale * " + "link(X * W)`", + false, true, 1.0) + .StringAttr( + "link_function", + "Type of link function, defined in " + "`secretflow_serving/protos/link_function.proto`. Optional value: " + "LF_EXP, LF_EXP_TAYLOR, " + "LF_RECIPROCAL, " + "LF_IDENTITY, LF_SIGMOID_RAW, LF_SIGMOID_MM1, LF_SIGMOID_MM3, " + "LF_SIGMOID_GA, " + "LF_SIGMOID_T1, LF_SIGMOID_T3, " + "LF_SIGMOID_T5, LF_SIGMOID_T7, LF_SIGMOID_T9, LF_SIGMOID_LS7, " + "LF_SIGMOID_SEG3, " + "LF_SIGMOID_SEG5, LF_SIGMOID_DF, LF_SIGMOID_SR, LF_SIGMOID_SEGLS", + false, false) + .Int32Attr("exp_iters", + "Number of iterations of `exp` approximation, valid when " + "`link_function` set `LF_EXP_TAYLOR`", + false, true, 0) + .Input("decrypted_data", + "The decrypted data output by `PHE_2P_DECRYPT_PEER_Y`") + .Input("crypted_data", "The crypted data selected by `PHE_2P_REDUCE`") + .Output("score", "The final linear predict score."); + +} // namespace secretflow::serving::op::phe_2p diff --git a/secretflow_serving/ops/phe_linear/phe_2p_merge_y.h b/secretflow_serving/ops/phe_linear/phe_2p_merge_y.h new file mode 100644 index 0000000..6737c09 --- /dev/null +++ b/secretflow_serving/ops/phe_linear/phe_2p_merge_y.h @@ -0,0 +1,44 @@ +// Copyright 2024 Ant Group Co., Ltd.
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "secretflow_serving/ops/op_kernel.h" + +#include "secretflow_serving/protos/link_function.pb.h" + +namespace secretflow::serving::op::phe_2p { + +// Op kernel registered for PHE_2P_MERGE_Y: adds the self plaintext partial y +// to the decrypted peer partial y and emits the score column. +class PheMergeY : public OpKernel { + public: + // Reads the node attributes and builds the input and output schemas. + explicit PheMergeY(OpKernelOptions opts); + + // Expects two inputs (decrypted data, crypted data) and sets ctx->output. + void DoCompute(ComputeContext* ctx) override; + + protected: + void BuildInputSchema() override; + + void BuildOutputSchema() override; + + private: + // column names read from the node attributes + std::string score_col_name_; + std::string decrypted_y_col_name_; + std::string crypted_y_col_name_; + + // score = link(y) * yhat_scale_; exp_iters_ is passed to ApplyLinkFunc + double yhat_scale_ = 1.0; + LinkFunctionType link_function_; + int32_t exp_iters_ = 0; +}; + +} // namespace secretflow::serving::op::phe_2p diff --git a/secretflow_serving/ops/phe_linear/phe_2p_merge_y_test.cc b/secretflow_serving/ops/phe_linear/phe_2p_merge_y_test.cc new file mode 100644 index 0000000..f7ce4a6 --- /dev/null +++ b/secretflow_serving/ops/phe_linear/phe_2p_merge_y_test.cc @@ -0,0 +1,208 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "secretflow_serving/ops/phe_linear/phe_2p_merge_y.h" + +#include "gtest/gtest.h" +#include "yacl/utils/elapsed_timer.h" + +#include "secretflow_serving/core/link_func.h" +#include "secretflow_serving/ops/op_factory.h" +#include "secretflow_serving/ops/op_kernel_factory.h" +#include "secretflow_serving/ops/phe_linear/test_utils.h" +#include "secretflow_serving/util/arrow_helper.h" +#include "secretflow_serving/util/utils.h" + +namespace secretflow::serving::op::phe_2p { + +// One test case: link function name, yhat_scale and optional exp_iters. +struct Param { + std::string link_func; + double yhat_scale; + int32_t exp_iters = 0; +}; + +// Fixture playing alice: the manager gets alice's key pair as the local kit +// and bob's public key as the destination kit for "bob". +class PheMergeYTest : public ::testing::TestWithParam { + protected: + void SetUp() override { + he_kit_mgm_ = std::make_unique(); + he_kit_mgm_->InitLocalKit(alice_kit_.GetPublicKey()->Serialize(), + alice_kit_.GetSecretKey()->Serialize(), 1e6); + he_kit_mgm_->InitDstKit("bob", bob_kit_.GetPublicKey()->Serialize()); + } + void TearDown() override {} + + protected: + std::unique_ptr he_kit_mgm_; + + heu_phe::HeKit alice_kit_ = + heu_phe::HeKit(heu_phe::SchemaType::ZPaillier, 2048); + heu_matrix::HeKit m_alice_kit_ = heu_matrix::HeKit(alice_kit_); + heu_phe::HeKit bob_kit_ = + heu_phe::HeKit(heu_phe::SchemaType::ZPaillier, 2048); + heu_matrix::HeKit m_bob_kit_ = heu_matrix::HeKit(bob_kit_); +}; + +// Feeds the kernel one plaintext and one ciphertext input and compares the +// scores with ApplyLinkFunc(alice_y + bob_y) * yhat_scale. +TEST_P(PheMergeYTest, Works) { + auto param = GetParam(); + + std::string json_content = R"JSON( +{ + "name": "test_node", + "op": "PHE_2P_MERGE_Y", + "attr_values": { + "decrypted_y_col_name": { + "s": "decrypted_y", + }, + "crypted_y_col_name": { + "s": "crypted_y", + }, + "score_col_name": { + "s": "score" + }, + }, + "op_version": "0.0.1", +} +)JSON"; + NodeDef node_def; + JsonToPb(json_content, &node_def); + // the per-case attrs are inserted programmatically + { + AttrValue link_func_value; + link_func_value.set_s(param.link_func); + node_def.mutable_attr_values()->insert({"link_function", link_func_value}); + } + { + AttrValue scale_value; +
scale_value.set_d(param.yhat_scale); + node_def.mutable_attr_values()->insert({"yhat_scale", scale_value}); + } + { + AttrValue exp_iters_value; + exp_iters_value.set_i32(param.exp_iters); + node_def.mutable_attr_values()->insert({"exp_iters", exp_iters_value}); + } + + auto compute_encoder = he_kit_mgm_->GetEncoder(he::kFeatureScale * + he_kit_mgm_->GetEncodeScale()); + auto base_encoder = he_kit_mgm_->GetEncoder(he_kit_mgm_->GetEncodeScale()); + + // build input&output schema + auto expect_decrypted_schema = + arrow::schema({arrow::field("decrypted_y", arrow::binary())}); + auto expect_crypted_schema = + arrow::schema({arrow::field("crypted_y", arrow::binary())}); + auto expect_output_schema = + arrow::schema({arrow::field("score", arrow::float64())}); + + // create kernel + auto mock_node = std::make_shared(std::move(node_def)); + ASSERT_EQ(mock_node->GetOpDef()->inputs_size(), 2); + OpKernelOptions opts{mock_node->node_def(), mock_node->GetOpDef()}; + auto kernel = OpKernelFactory::GetInstance()->Create(std::move(opts)); + + // check input schema + ASSERT_EQ(kernel->GetInputsNum(), mock_node->GetOpDef()->inputs_size()); + const auto& input_schema_list = kernel->GetAllInputSchema(); + ASSERT_EQ(input_schema_list.size(), kernel->GetInputsNum()); + ASSERT_TRUE(input_schema_list[0]->Equals(expect_decrypted_schema)); + ASSERT_TRUE(input_schema_list[1]->Equals(expect_crypted_schema)); + // check output schema + auto output_schema = kernel->GetOutputSchema(); + ASSERT_TRUE(output_schema->Equals(expect_output_schema)); + + auto alice_y = test::GenRawMatrix(2, 1, 1); + auto bob_y = test::GenRawMatrix(2, 1, 2); + // expect output + Double::Matrix expect_y = alice_y + bob_y; + + // build input: alice's y stays plaintext, bob's y is encrypted by alice's kit + auto p_alice_y = test::EncodeMatrix(alice_y, compute_encoder.get()); + auto p_bob_y = test::EncodeMatrix(bob_y, compute_encoder.get()); + auto c_bob_y = m_alice_kit_.GetEncryptor()->Encrypt(p_bob_y); + + ComputeContext compute_ctx; + compute_ctx.other_party_ids = {"bob"}; +
compute_ctx.self_id = "alice"; + compute_ctx.he_kit_mgm = he_kit_mgm_.get(); + + // build input record_batch + // decrypted data: the serialized plaintext matrix as a single binary value + std::shared_ptr alice_p_y_array; + arrow::BinaryBuilder alice_p_y_builder; + auto p_alice_y_buf = p_alice_y.Serialize(); + SERVING_CHECK_ARROW_STATUS(alice_p_y_builder.Append( + p_alice_y_buf.data(), p_alice_y_buf.size())); + SERVING_CHECK_ARROW_STATUS(alice_p_y_builder.Finish(&alice_p_y_array)); + + compute_ctx.inputs.emplace_back( + std::vector>{ + MakeRecordBatch(expect_decrypted_schema, 1, {alice_p_y_array})}); + + // crypted data: the serialized ciphertext matrix as a single binary value + std::shared_ptr bob_c_y_array; + arrow::BinaryBuilder bob_c_y_builder; + auto c_bob_y_buf = c_bob_y.Serialize(); + SERVING_CHECK_ARROW_STATUS( + bob_c_y_builder.Append(c_bob_y_buf.data(), c_bob_y_buf.size())); + SERVING_CHECK_ARROW_STATUS(bob_c_y_builder.Finish(&bob_c_y_array)); + + compute_ctx.inputs.emplace_back( + std::vector>{ + MakeRecordBatch(expect_crypted_schema, 1, {bob_c_y_array})}); + + yacl::ElapsedTimer timer; + + kernel->Compute(&compute_ctx); + + std::cout << "---compute time: " << timer.CountMs() << "\n"; + + // check output: two score rows are expected + ASSERT_TRUE(compute_ctx.output); + ASSERT_TRUE(compute_ctx.output->schema()->Equals(output_schema)); + ASSERT_EQ(compute_ctx.output->num_rows(), 2); + + std::cout << "compute result: " << compute_ctx.output->ToString() + << std::endl; + + // expected scores: same link function and scale applied to alice_y + bob_y + std::shared_ptr expect_score_array; + arrow::DoubleBuilder expect_score_builder; + for (int i = 0; i < expect_y.rows(); ++i) { + SERVING_CHECK_ARROW_STATUS(expect_score_builder.Append( + ApplyLinkFunc(expect_y(i, 0), ParseLinkFuncType(param.link_func), + param.exp_iters) * + param.yhat_scale)); + } + SERVING_CHECK_ARROW_STATUS(expect_score_builder.Finish(&expect_score_array)); + + std::cout << "expect result: " << expect_score_array->ToString() << std::endl; + + // compare with an absolute tolerance + double epsilon = 1E-13; + ASSERT_TRUE(compute_ctx.output->column(0)->ApproxEquals( + expect_score_array, arrow::EqualOptions::Defaults().atol(epsilon))); +} + +// One case per link function name; only LF_EXP_TAYLOR sets exp_iters.
+INSTANTIATE_TEST_SUITE_P( + PheMergeYTestSuite, PheMergeYTest, + ::testing::Values( + Param{"LF_EXP", 1.0}, Param{"LF_EXP_TAYLOR", 1.0, 4}, + Param{"LF_RECIPROCAL", 1.1}, Param{"LF_IDENTITY", 1.2}, + Param{"LF_SIGMOID_RAW", 1.3}, Param{"LF_SIGMOID_MM1", 1.4}, + Param{"LF_SIGMOID_MM3", 1.5}, Param{"LF_SIGMOID_GA", 1.6}, + Param{"LF_SIGMOID_T1", 1.7}, Param{"LF_SIGMOID_T3", 1.8}, + Param{"LF_SIGMOID_T5", 1.9}, Param{"LF_SIGMOID_T7", 1.01}, + Param{"LF_SIGMOID_T9", 1.02}, Param{"LF_SIGMOID_LS7", 1.03}, + Param{"LF_SIGMOID_SEG3", 1.04}, Param{"LF_SIGMOID_SEG5", 1.05}, + Param{"LF_SIGMOID_DF", 1.06}, Param{"LF_SIGMOID_SR", 1.07}, + Param{"LF_SIGMOID_SEGLS", 1.08})); + +} // namespace secretflow::serving::op::phe_2p diff --git a/secretflow_serving/ops/phe_linear/phe_2p_reduce.cc b/secretflow_serving/ops/phe_linear/phe_2p_reduce.cc new file mode 100644 index 0000000..868345c --- /dev/null +++ b/secretflow_serving/ops/phe_linear/phe_2p_reduce.cc @@ -0,0 +1,157 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +#include "secretflow_serving/ops/phe_linear/phe_2p_reduce.h" + +#include "secretflow_serving/ops/node_def_util.h" +#include "secretflow_serving/ops/op_factory.h" +#include "secretflow_serving/ops/op_kernel_factory.h" +#include "secretflow_serving/util/arrow_helper.h" + +namespace secretflow::serving::op::phe_2p { + +namespace { + +// Loads the ciphertext matrix serialized in y_buf, deserializes one rand value +// of type T from rand_buf, adds it to every row and returns the serialized sum. +template +::yacl::Buffer AddRand(heu_matrix::Evaluator* evaluator, std::string_view y_buf, + std::string_view rand_buf) { + auto y_matrix = heu_matrix::CMatrix::LoadFrom(y_buf); + T rand_num; + rand_num.Deserialize(rand_buf); + + if constexpr (std::is_same_v) { + heu_matrix::CMatrix rand_matrix(y_matrix.rows(), 1); + for (int i = 0; i < y_matrix.rows(); ++i) { + rand_matrix(i, 0) = rand_num; + } + return evaluator->Add(y_matrix, rand_matrix).Serialize(); + } else { + heu_matrix::PMatrix rand_matrix(y_matrix.rows(), 1); + for (int i = 0; i < y_matrix.rows(); ++i) { + rand_matrix(i, 0) = rand_num; + } + return evaluator->Add(y_matrix, rand_matrix).Serialize(); + } +} + +} // namespace + +// Reads the node attributes and builds the input and output schemas. +PheReduce::PheReduce(OpKernelOptions opts) : OpKernel(std::move(opts)) { + // required attrs: the two column names and the selection switch + partial_y_col_name_ = + GetNodeAttr(opts_.node_def, "partial_y_col_name"); + rand_number_col_name_ = + GetNodeAttr(opts_.node_def, "rand_number_col_name"); + select_peer_crypted_ = + GetNodeAttr(opts_.node_def, "select_crypted_for_peer"); + + BuildInputSchema(); + BuildOutputSchema(); +} + +// Combines the partial y of one party with the rand number of the other. +// The single input must hold two record batches (self and peer) and the +// kernel must run on the requester. +void PheReduce::DoCompute(ComputeContext* ctx) { + SERVING_ENFORCE(ctx->inputs.size() == 1, errors::ErrorCode::LOGIC_ERROR); + SERVING_ENFORCE(ctx->inputs.front().size() == 2, + errors::ErrorCode::LOGIC_ERROR); + SERVING_ENFORCE(ctx->other_party_ids.size() == 1, + errors::ErrorCode::LOGIC_ERROR); + SERVING_ENFORCE_EQ(ctx->requester_id, ctx->self_id); + // both branches below fetch their evaluator from the HE kit manager + SERVING_ENFORCE(ctx->he_kit_mgm, errors::ErrorCode::LOGIC_ERROR); + + auto remote_party_id = *(ctx->other_party_ids.begin()); + + // column 2 carries the party id that produced the record batch + std::shared_ptr peer_record_batch; + std::shared_ptr self_record_batch; + for (const auto& r : ctx->inputs.front()) { + auto p_array = r->column(2); + auto p =
std::static_pointer_cast(p_array)->Value(0); + if (p == remote_party_id) { + peer_record_batch = r; + } else { + self_record_batch = r; + } + } + SERVING_ENFORCE(peer_record_batch, errors::ErrorCode::UNEXPECTED_ERROR); + SERVING_ENFORCE(self_record_batch, errors::ErrorCode::UNEXPECTED_ERROR); + + // wraps a serialized buffer into a one-element binary array + auto build_array = [](const ::yacl::Buffer& buf, + std::shared_ptr* array) { + arrow::BinaryBuilder builder; + SERVING_CHECK_ARROW_STATUS(builder.Append(buf.data(), buf.size())); + SERVING_CHECK_ARROW_STATUS(builder.Finish(array)); + }; + + std::shared_ptr array; + if (select_peer_crypted_) { + // self partial y + peer rand, added with the peer's (dst) evaluator + const auto& evaluator = + ctx->he_kit_mgm->GetDstMatrixEvaluator(remote_party_id); + auto y_buf = std::static_pointer_cast( + self_record_batch->column(0)) + ->Value(0); + auto rand_buf = std::static_pointer_cast( + peer_record_batch->column(1)) + ->Value(0); + auto ye_matrix_buf = + AddRand(evaluator.get(), y_buf, rand_buf); + build_array(ye_matrix_buf, &array); + } else { + // peer partial y + self rand, added with the local evaluator + const auto& evaluator = ctx->he_kit_mgm->GetLocalMatrixEvaluator(); + auto y_buf = std::static_pointer_cast( + peer_record_batch->column(0)) + ->Value(0); + auto rand_buf = std::static_pointer_cast( + self_record_batch->column(1)) + ->Value(0); + auto ye_matrix_buf = + AddRand(evaluator.get(), y_buf, rand_buf); + build_array(ye_matrix_buf, &array); + } + + ctx->output = MakeRecordBatch(output_schema_, 1, {array}); +} + +void PheReduce::BuildInputSchema() { + // build input schema: partial y, rand number and the producing party + input_schema_list_.emplace_back( + arrow::schema({arrow::field(partial_y_col_name_, arrow::binary()), + arrow::field(rand_number_col_name_, arrow::binary()), + arrow::field("party", arrow::utf8())})); +} + +void PheReduce::BuildOutputSchema() { + // build output schema: only the partial y column is emitted + output_schema_ = + arrow::schema({arrow::field(partial_y_col_name_, arrow::binary())}); +} + +REGISTER_OP_KERNEL(PHE_2P_REDUCE, PheReduce) +REGISTER_OP(PHE_2P_REDUCE, "0.0.1", + "Two-party computation operator.
Select data encrypted by either " + "our side or the peer party " + "according to the configuration.") + .Mergeable() + .StringAttr("partial_y_col_name", + "The name of the partial_y column in the input and output", + false, false) + .StringAttr("rand_number_col_name", + "The name of the rand number column in the input and output", + false, false) + .BoolAttr("select_crypted_for_peer", + "If `True`, select the data can be decrypted by peer, " + "including self calculated partial_y and peer's rand, " + "otherwise select selfs.", + false, false) + .Input("compute results", "The compute results from both self and peer's") + .Output("selected results", "The selected data"); + +} // namespace secretflow::serving::op::phe_2p diff --git a/secretflow_serving/ops/phe_linear/phe_2p_reduce.h b/secretflow_serving/ops/phe_linear/phe_2p_reduce.h new file mode 100644 index 0000000..523a90d --- /dev/null +++ b/secretflow_serving/ops/phe_linear/phe_2p_reduce.h @@ -0,0 +1,38 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +#pragma once + +#include "secretflow_serving/ops/op_kernel.h" + +namespace secretflow::serving::op::phe_2p { + +// Op kernel registered for PHE_2P_REDUCE: adds the partial y of one party to +// the rand number of the other and emits the result as the partial y column. +class PheReduce : public OpKernel { + public: + // Reads the node attributes and builds the input and output schemas. + explicit PheReduce(OpKernelOptions opts); + + // Expects one input holding the self and the peer record batch. + void DoCompute(ComputeContext* ctx) override; + + protected: + void BuildInputSchema() override; + + void BuildOutputSchema() override; + + private: + // column names read from the node attributes + std::string partial_y_col_name_; + std::string rand_number_col_name_; + // true: self partial y + peer rand; false: peer partial y + self rand + bool select_peer_crypted_ = false; +}; + +} // namespace secretflow::serving::op::phe_2p diff --git a/secretflow_serving/ops/phe_linear/phe_2p_reduce_test.cc b/secretflow_serving/ops/phe_linear/phe_2p_reduce_test.cc new file mode 100644 index 0000000..68f0757 --- /dev/null +++ b/secretflow_serving/ops/phe_linear/phe_2p_reduce_test.cc @@ -0,0 +1,200 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +#include "secretflow_serving/ops/phe_linear/phe_2p_reduce.h" + +#include "arrow/ipc/api.h" +#include "gtest/gtest.h" + +#include "secretflow_serving/ops/op_factory.h" +#include "secretflow_serving/ops/op_kernel_factory.h" +#include "secretflow_serving/ops/phe_linear/test_utils.h" +#include "secretflow_serving/util/arrow_helper.h" +#include "secretflow_serving/util/utils.h" + +namespace secretflow::serving::op::phe_2p { + +struct Param { + bool select_crypted_for_peer; +}; + +class PheReduceTest : public ::testing::TestWithParam { + protected: + void SetUp() override { + he_kit_mgm_ = std::make_unique(); + he_kit_mgm_->InitLocalKit(alice_kit_.GetPublicKey()->Serialize(), + alice_kit_.GetSecretKey()->Serialize(), 1e6); + he_kit_mgm_->InitDstKit("bob", bob_kit_.GetPublicKey()->Serialize()); + } + void TearDown() override {} + + protected: + std::unique_ptr he_kit_mgm_; + + heu_phe::HeKit alice_kit_ = + heu_phe::HeKit(heu_phe::SchemaType::ZPaillier, 2048); + heu_matrix::HeKit m_alice_kit_ = heu_matrix::HeKit(alice_kit_); + heu_phe::HeKit bob_kit_ = + heu_phe::HeKit(heu_phe::SchemaType::ZPaillier, 2048); + heu_matrix::HeKit m_bob_kit_ = heu_matrix::HeKit(bob_kit_); +}; + +TEST_P(PheReduceTest, Works) { + auto param = GetParam(); + + std::string json_content = R"JSON( +{ + "name": "test_node", + "op": "PHE_2P_REDUCE", + "attr_values": { + "partial_y_col_name": { + "s": "partial_y", + }, + "rand_number_col_name": { + "s": "rand" + } + }, + "op_version": "0.0.1", +} +)JSON"; + NodeDef node_def; + JsonToPb(json_content, &node_def); + { + AttrValue select_peer_crypted_value; + select_peer_crypted_value.set_b(param.select_crypted_for_peer); + node_def.mutable_attr_values()->insert( + {"select_crypted_for_peer", select_peer_crypted_value}); + } + + // build input&output schema + auto expect_input_schema = + arrow::schema({arrow::field("partial_y", arrow::binary()), + arrow::field("rand", arrow::binary()), + arrow::field("party", arrow::utf8())}); + auto expect_output_schema 
= + arrow::schema({arrow::field("partial_y", arrow::binary())}); + + // create kernel + auto mock_node = std::make_shared(std::move(node_def)); + ASSERT_EQ(mock_node->GetOpDef()->inputs_size(), 1); + OpKernelOptions opts{mock_node->node_def(), mock_node->GetOpDef()}; + auto kernel = OpKernelFactory::GetInstance()->Create(std::move(opts)); + + // check input schema + ASSERT_EQ(kernel->GetInputsNum(), mock_node->GetOpDef()->inputs_size()); + const auto& input_schema_list = kernel->GetAllInputSchema(); + ASSERT_EQ(input_schema_list.size(), kernel->GetInputsNum()); + ASSERT_TRUE(input_schema_list[0]->Equals(expect_input_schema)); + // check output schema + auto output_schema = kernel->GetOutputSchema(); + ASSERT_TRUE(output_schema->Equals(expect_output_schema)); + + // build input + uint16_t alice_rand = 11; + uint16_t bob_rand = 22; + auto alice_y = test::GenRawMatrix(2, 1, 1); + auto bob_y = test::GenRawMatrix(2, 1, 2); + + auto compute_encoder = he_kit_mgm_->GetEncoder(he::kFeatureScale * + he_kit_mgm_->GetEncodeScale()); + auto base_encoder = he_kit_mgm_->GetEncoder(he_kit_mgm_->GetEncodeScale()); + + alice_y.array() -= alice_rand; + bob_y.array() -= bob_rand; + auto p_alice_y = test::EncodeMatrix(alice_y, compute_encoder.get()); + auto p_bob_y = test::EncodeMatrix(bob_y, compute_encoder.get()); + auto c_alice_y = m_bob_kit_.GetEncryptor()->Encrypt(p_alice_y); + auto c_bob_y = m_alice_kit_.GetEncryptor()->Encrypt(p_bob_y); + auto p_alice_rand = compute_encoder->Encode(alice_rand); + // auto c_alice_rand = + // alice_kit_.GetEncryptor()->Encrypt(compute_encoder->Encode(alice_rand)); + auto c_bob_rand = + bob_kit_.GetEncryptor()->Encrypt(compute_encoder->Encode(bob_rand)); + + ComputeContext compute_ctx; + compute_ctx.other_party_ids = {"bob"}; + compute_ctx.self_id = "alice"; + compute_ctx.requester_id = "alice"; + compute_ctx.he_kit_mgm = he_kit_mgm_.get(); + + // build input record_batch + // decrypted data + std::shared_ptr alice_y_array, bob_y_array, 
alice_rand_array, + bob_rand_array, alice_party_array, bob_party_array; + { + auto build_y_array_func = [](const heu_matrix::CMatrix& c_m, + std::shared_ptr* array) { + arrow::BinaryBuilder builder; + auto buf = c_m.Serialize(); + SERVING_CHECK_ARROW_STATUS( + builder.Append(buf.data(), buf.size())); + SERVING_CHECK_ARROW_STATUS(builder.Finish(array)); + }; + + auto build_rand_array_func = [](const yacl::Buffer& buf, + std::shared_ptr* array) { + arrow::BinaryBuilder builder; + SERVING_CHECK_ARROW_STATUS( + builder.Append(buf.data(), buf.size())); + SERVING_CHECK_ARROW_STATUS(builder.Finish(array)); + }; + + build_y_array_func(c_alice_y, &alice_y_array); + build_y_array_func(c_bob_y, &bob_y_array); + build_rand_array_func(p_alice_rand.Serialize(), &alice_rand_array); + build_rand_array_func(c_bob_rand.Serialize(), &bob_rand_array); + + using arrow::ipc::internal::json::ArrayFromJSON; + SERVING_GET_ARROW_RESULT(ArrayFromJSON(arrow::utf8(), "[\"alice\"]"), + alice_party_array); + SERVING_GET_ARROW_RESULT(ArrayFromJSON(arrow::utf8(), "[\"bob\"]"), + bob_party_array); + } + + compute_ctx.inputs.emplace_back( + std::vector>{ + MakeRecordBatch(expect_input_schema, 1, + {alice_y_array, alice_rand_array, alice_party_array}), + MakeRecordBatch(expect_input_schema, 1, + {bob_y_array, bob_rand_array, bob_party_array})}); + + kernel->Compute(&compute_ctx); + + // check output + ASSERT_TRUE(compute_ctx.output); + ASSERT_TRUE(compute_ctx.output->schema()->Equals(output_schema)); + ASSERT_EQ(compute_ctx.output->num_rows(), 1); + + auto partial_y_array = compute_ctx.output->GetColumnByName("partial_y"); + auto partial_y_matrix = heu_matrix::CMatrix::LoadFrom( + std::static_pointer_cast(partial_y_array)->Value(0)); + ASSERT_EQ(partial_y_matrix.rows(), 2); + + if (param.select_crypted_for_peer) { + for (int i = 0; i < partial_y_matrix.rows(); ++i) { + auto expect = bob_kit_.GetEvaluator()->Add(c_alice_y(i, 0), c_bob_rand); + ASSERT_EQ(partial_y_matrix(i, 0), expect); + } + } else { 
+ for (int i = 0; i < partial_y_matrix.rows(); ++i) { + auto expect = alice_kit_.GetEvaluator()->Add(c_bob_y(i, 0), p_alice_rand); + ASSERT_EQ(partial_y_matrix(i, 0), expect); + } + } +} + +INSTANTIATE_TEST_SUITE_P(PheReduceTestSuite, PheReduceTest, + ::testing::Values(Param{true}, Param{false})); + +} // namespace secretflow::serving::op::phe_2p diff --git a/secretflow_serving/ops/phe_linear/test_utils.cc b/secretflow_serving/ops/phe_linear/test_utils.cc new file mode 100644 index 0000000..197dba2 --- /dev/null +++ b/secretflow_serving/ops/phe_linear/test_utils.cc @@ -0,0 +1,45 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "secretflow_serving/ops/phe_linear/test_utils.h" + +namespace secretflow::serving::test { + +Double::Matrix GenRawMatrix(int rows, int cols, int64_t start) { + Double::Matrix m(rows, cols); + for (int i = 0; i < rows; ++i) { + for (int j = 0; j < cols; ++j) { + m(i, j) = start++; + } + } + return m; +} + +heu_matrix::PMatrix EncodeMatrix(const Double::Matrix& m, + heu_phe::PlainEncoder* encoder) { + int64_t ndim = 2; + if (m.cols() == 1) { + ndim = 1; + } + heu_matrix::PMatrix plain_matrix(m.rows(), m.cols(), ndim); + for (int i = 0; i < plain_matrix.rows(); ++i) { + for (int j = 0; j < plain_matrix.cols(); ++j) { + plain_matrix(i, j) = encoder->Encode(m(i, j)); + } + } + + return plain_matrix; +} + +} // namespace secretflow::serving::test diff --git a/secretflow_serving/ops/phe_linear/test_utils.h b/secretflow_serving/ops/phe_linear/test_utils.h new file mode 100644 index 0000000..2057acb --- /dev/null +++ b/secretflow_serving/ops/phe_linear/test_utils.h @@ -0,0 +1,27 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "secretflow_serving/core/types.h" +#include "secretflow_serving/util/he_mgm.h" + +namespace secretflow::serving::test { + +Double::Matrix GenRawMatrix(int rows, int cols, int64_t start = 0); + +heu_matrix::PMatrix EncodeMatrix(const Double::Matrix& m, + heu_phe::PlainEncoder* encoder); + +} // namespace secretflow::serving::test diff --git a/secretflow_serving/protos/graph.proto b/secretflow_serving/protos/graph.proto index d924ca2..bc3a6eb 100644 --- a/secretflow_serving/protos/graph.proto +++ b/secretflow_serving/protos/graph.proto @@ -30,6 +30,11 @@ enum DispatchType { DP_ANYONE = 2; // Dispatch specified participant. DP_SPECIFIED = 3; + // Dispatch self. + DP_SELF = 4; + + // For 2-parties, Dispatch peer participant. + DP_PEER = 12; } // The runtime config of the execution. @@ -103,6 +108,10 @@ message GraphDef { repeated NodeDef node_list = 2; repeated ExecutionDef execution_list = 3; + + HeConfig he_config = 4; + + int32 party_num = 10; } // The view of a graph is used to display the structure of the graph, containing @@ -114,4 +123,26 @@ message GraphView { repeated NodeView node_list = 2; repeated ExecutionDef execution_list = 3; + + HeInfo he_info = 4; + + int32 party_num = 10; +} + +// The config for HE compute. +message HeConfig { + // Serialized public key bytes + bytes pk_buf = 1; + // Serialized secret key bytes + bytes sk_buf = 2; + // Encode scale for data + int64 encode_scale = 3; +} + +// The public info for HE compute. +message HeInfo { + // Serialized public key bytes + bytes pk_buf = 1; + // Encode scale for data + int64 encode_scale = 3; } diff --git a/secretflow_serving/protos/op.proto b/secretflow_serving/protos/op.proto index 0a137ae..67b8a17 100644 --- a/secretflow_serving/protos/op.proto +++ b/secretflow_serving/protos/op.proto @@ -38,6 +38,7 @@ message OpTag { bool mergeable = 2; // The operator needs to be executed in session. + // TODO: not supported yet. 
bool session_run = 3; // Whether this op has variable input argument. default `false`. diff --git a/secretflow_serving/server/execution_core.cc b/secretflow_serving/server/execution_core.cc index 689cb3f..ab49a5b 100644 --- a/secretflow_serving/server/execution_core.cc +++ b/secretflow_serving/server/execution_core.cc @@ -113,6 +113,7 @@ void ExecutionCore::Execute(const apis::ExecuteRequest* request, // executable run Executable::Task task; task.id = request->task().execution_id(); + task.requester_id = request->requester_id(); task.features = features; for (const auto& n : request->task().nodes()) { SERVING_ENFORCE_EQ(n.ios_size(), 1, diff --git a/secretflow_serving/server/kuscia/BUILD.bazel b/secretflow_serving/server/kuscia/BUILD.bazel index 676a790..e4c97d5 100644 --- a/secretflow_serving/server/kuscia/BUILD.bazel +++ b/secretflow_serving/server/kuscia/BUILD.bazel @@ -13,7 +13,6 @@ # limitations under the License. load("//bazel:serving.bzl", "serving_cc_library", "serving_cc_test") -load("@rules_cc//cc:defs.bzl", "cc_proto_library") package(default_visibility = ["//visibility:public"]) diff --git a/secretflow_serving/server/kuscia/config_parser.cc b/secretflow_serving/server/kuscia/config_parser.cc index 1af474d..bca3ee9 100644 --- a/secretflow_serving/server/kuscia/config_parser.cc +++ b/secretflow_serving/server/kuscia/config_parser.cc @@ -30,6 +30,22 @@ namespace secretflow::serving::kuscia { +namespace { +const char* kSpiTlsConfigKey = "spi_tls_config"; + +const char* kServingIdKey = "serving_id"; + +const char* kClusterDefKey = "cluster_def"; + +const char* kInputConfigKey = "input_config"; + +const char* kAllocatedPortKey = "allocated_ports"; + +const char* kOssMetaKey = "oss_meta"; + +const char* kHttpSourceMetaKey = "http_source_meta"; +} // namespace + namespace kusica_proto = ::kuscia::proto::api::v1alpha1::appconfig; KusciaConfigParser::KusciaConfigParser(const std::string& config_file) { @@ -48,19 +64,19 @@ 
KusciaConfigParser::KusciaConfigParser(const std::string& config_file) { raw_config_str); // get services id - SERVING_ENFORCE(doc["serving_id"].IsString(), + SERVING_ENFORCE(doc[kServingIdKey].IsString(), errors::ErrorCode::INVALID_ARGUMENT); - service_id_ = {doc["serving_id"].GetString(), - doc["serving_id"].GetStringLength()}; + service_id_ = {doc[kServingIdKey].GetString(), + doc[kServingIdKey].GetStringLength()}; int self_party_idx = 0; std::string self_party_id; { // parse cluster_def - SERVING_ENFORCE(doc["cluster_def"].IsString(), + SERVING_ENFORCE(doc[kClusterDefKey].IsString(), errors::ErrorCode::INVALID_ARGUMENT); - std::string cluster_def_str = {doc["cluster_def"].GetString(), - doc["cluster_def"].GetStringLength()}; + std::string cluster_def_str = {doc[kClusterDefKey].GetString(), + doc[kClusterDefKey].GetStringLength()}; kusica_proto::ClusterDefine cluster_def; JsonToPb(cluster_def_str, &cluster_def); @@ -91,10 +107,10 @@ KusciaConfigParser::KusciaConfigParser(const std::string& config_file) { { // parse input config - SERVING_ENFORCE(doc["input_config"].IsString(), + SERVING_ENFORCE(doc[kInputConfigKey].IsString(), errors::ErrorCode::INVALID_ARGUMENT); - std::string input_config_str = {doc["input_config"].GetString(), - doc["input_config"].GetStringLength()}; + std::string input_config_str = {doc[kInputConfigKey].GetString(), + doc[kInputConfigKey].GetStringLength()}; KusciaServingConfig serving_config; JsonToPb(input_config_str, &serving_config); @@ -112,11 +128,11 @@ KusciaConfigParser::KusciaConfigParser(const std::string& config_file) { { // parse allocated_ports - SERVING_ENFORCE(doc["allocated_ports"].IsString(), + SERVING_ENFORCE(doc[kAllocatedPortKey].IsString(), errors::ErrorCode::INVALID_ARGUMENT); std::string allocated_ports_str = { - doc["allocated_ports"].GetString(), - doc["allocated_ports"].GetStringLength()}; + doc[kAllocatedPortKey].GetString(), + doc[kAllocatedPortKey].GetStringLength()}; kusica_proto::AllocatedPorts 
allocated_ports; JsonToPb(allocated_ports_str, &allocated_ports); @@ -144,15 +160,24 @@ KusciaConfigParser::KusciaConfigParser(const std::string& config_file) { // load oss config if (model_config_.source_type() == SourceType::ST_OSS) { - SERVING_ENFORCE(doc["oss_meta"].IsString(), + SERVING_ENFORCE(doc.HasMember(kOssMetaKey), + errors::ErrorCode::INVALID_ARGUMENT); + SERVING_ENFORCE(doc[kOssMetaKey].IsString(), errors::ErrorCode::INVALID_ARGUMENT); - std::string oss_meta_str = {doc["oss_meta"].GetString(), - doc["oss_meta"].GetStringLength()}; - if (!oss_meta_str.empty()) { - OSSSourceMeta oss_meta; - JsonToPb(oss_meta_str, model_config_.mutable_oss_source_meta()); - } else { - SPDLOG_WARN("oss meta is null"); + std::string oss_meta_str = {doc[kOssMetaKey].GetString(), + doc[kOssMetaKey].GetStringLength()}; + SERVING_ENFORCE(!oss_meta_str.empty(), errors::ErrorCode::INVALID_ARGUMENT, + "get empty `oss_meta`"); + JsonToPb(oss_meta_str, model_config_.mutable_oss_source_meta()); + } else if (model_config_.source_type() == SourceType::ST_HTTP) { + if (doc.HasMember(kHttpSourceMetaKey)) { + SERVING_ENFORCE(doc[kHttpSourceMetaKey].IsString(), + errors::ErrorCode::INVALID_ARGUMENT); + std::string meta_str = {doc[kHttpSourceMetaKey].GetString(), + doc[kHttpSourceMetaKey].GetStringLength()}; + if (!meta_str.empty()) { + JsonToPb(meta_str, model_config_.mutable_http_source_meta()); + } } } @@ -190,20 +215,15 @@ KusciaConfigParser::KusciaConfigParser(const std::string& config_file) { // fill spi tls config if (feature_config_.has_value() && feature_config_->has_http_opts()) { auto* http_opts = feature_config_->mutable_http_opts(); - const char* KSpiTlsConfigKey = "spi_tls_config"; - if (doc.HasMember(KSpiTlsConfigKey)) { - SERVING_ENFORCE(doc[KSpiTlsConfigKey].IsString(), + if (doc.HasMember(kSpiTlsConfigKey)) { + SERVING_ENFORCE(doc[kSpiTlsConfigKey].IsString(), errors::ErrorCode::INVALID_ARGUMENT); std::string spi_tls_config_str = { - doc[KSpiTlsConfigKey].GetString(), - 
doc[KSpiTlsConfigKey].GetStringLength()}; + doc[kSpiTlsConfigKey].GetString(), + doc[kSpiTlsConfigKey].GetStringLength()}; if (!spi_tls_config_str.empty()) { SPDLOG_INFO("spi tls config: {}", spi_tls_config_str); - TlsConfig spi_tls_config; - JsonToPb(spi_tls_config_str, &spi_tls_config); - http_opts->mutable_tls_config()->CopyFrom(spi_tls_config); - } else { - SPDLOG_WARN("spi tls config is empty"); + JsonToPb(spi_tls_config_str, http_opts->mutable_tls_config()); } } } diff --git a/secretflow_serving/server/kuscia/config_parser_test.cc b/secretflow_serving/server/kuscia/config_parser_test.cc index b6e0b0a..f10f3aa 100644 --- a/secretflow_serving/server/kuscia/config_parser_test.cc +++ b/secretflow_serving/server/kuscia/config_parser_test.cc @@ -32,11 +32,12 @@ TEST_F(KusciaConfigParserTest, Works) { tmpfile.save(1 + R"JSON( { "serving_id": "kd-1", - "input_config": "{\"partyConfigs\":{\"alice\":{\"serverConfig\":{\"featureMapping\":{\"v24\":\"x24\",\"v22\":\"x22\",\"v21\":\"x21\",\"v25\":\"x25\",\"v23\":\"x23\"}},\"modelConfig\":{\"modelId\":\"glm-test-1\",\"basePath\":\"/tmp/alice\",\"sourceSha256\":\"3b6a3b76a8d5bbf0e45b83f2d44772a0a6aa9a15bf382cee22cbdc8f59d55522\",\"sourcePath\":\"examples/alice/glm-test.tar.gz\",\"sourceType\":\"ST_FILE\"},\"featureSourceConfig\":{\"httpOpts\":{\"endpoint\":\"alice_ep\"}},\"channel_desc\":{\"protocol\":\"http\"}},\"bob\":{\"serverConfig\":{\"featureMapping\":{\"v6\":\"x6\",\"v7\":\"x7\",\"v8\":\"x8\",\"v9\":\"x9\",\"v10\":\"x10\"}},\"modelConfig\":{\"modelId\":\"glm-test-1\",\"basePath\":\"/tmp/bob\",\"sourceSha256\":\"330192f3a51f9498dd882478bfe08a06501e2ed4aa2543a0fb586180925eb309\",\"sourcePath\":\"examples/bob/glm-test.tar.gz\",\"sourceType\":\"ST_FILE\"},\"featureSourceConfig\":{\"httpOpts\":{\"endpoint\":\"bob_ep\"}},\"channel_desc\":{\"protocol\":\"http\"}}}}", + "input_config": 
"{\"partyConfigs\":{\"alice\":{\"serverConfig\":{\"featureMapping\":{\"v24\":\"x24\",\"v22\":\"x22\",\"v21\":\"x21\",\"v25\":\"x25\",\"v23\":\"x23\"}},\"modelConfig\":{\"modelId\":\"glm-test-1\",\"basePath\":\"/tmp/alice\",\"sourceSha256\":\"3b6a3b76a8d5bbf0e45b83f2d44772a0a6aa9a15bf382cee22cbdc8f59d55522\",\"sourcePath\":\"examples/alice/glm-test.tar.gz\",\"sourceType\":\"ST_FILE\"},\"featureSourceConfig\":{\"httpOpts\":{\"endpoint\":\"alice_ep\"}},\"channelDesc\":{\"protocol\":\"http\"}},\"bob\":{\"serverConfig\":{\"featureMapping\":{\"v6\":\"x6\",\"v7\":\"x7\",\"v8\":\"x8\",\"v9\":\"x9\",\"v10\":\"x10\"}},\"modelConfig\":{\"modelId\":\"glm-test-1\",\"basePath\":\"/tmp/bob\",\"sourceSha256\":\"330192f3a51f9498dd882478bfe08a06501e2ed4aa2543a0fb586180925eb309\",\"sourcePath\":\"examples/bob/glm-test.tar.gz\",\"sourceType\":\"ST_FILE\"},\"featureSourceConfig\":{\"httpOpts\":{\"endpoint\":\"bob_ep\"}},\"channelDesc\":{\"protocol\":\"http\"}}}}", "cluster_def": "{\"parties\":[{\"name\":\"alice\",\"role\":\"\",\"services\":[{\"portName\":\"service\",\"endpoints\":[\"kd-1-service.alice.svc:53508\"]},{\"portName\":\"internal\",\"endpoints\":[\"kd-1-internal.alice.svc:53510\"]},{\"portName\":\"brpc-builtin\",\"endpoints\":[\"kd-1-brpc-builtin.alice.svc:53511\"]},{\"portName\":\"communication\",\"endpoints\":[\"kd-1-communication.alice.svc\"]}]},{\"name\":\"bob\",\"role\":\"\",\"services\":[{\"portName\":\"brpc-builtin\",\"endpoints\":[\"kd-1-brpc-builtin.bob.svc:53511\"]},{\"portName\":\"service\",\"endpoints\":[\"kd-1-service.bob.svc:53508\"]},{\"portName\":\"internal\",\"endpoints\":[\"kd-1-internal.bob.svc:53510\"]},{\"portName\":\"communication\",\"endpoints\":[\"kd-1-communication.bob.svc\"]}]}],\"selfPartyIdx\":0,\"selfEndpointIdx\":0}", "allocated_ports": 
"{\"ports\":[{\"name\":\"service\",\"port\":53509,\"scope\":\"Domain\",\"protocol\":\"HTTP\"},{\"name\":\"communication\",\"port\":53508,\"scope\":\"Cluster\",\"protocol\":\"HTTP\"},{\"name\":\"internal\",\"port\":53510,\"scope\":\"Domain\",\"protocol\":\"HTTP\"},{\"name\":\"brpc-builtin\",\"port\":53511,\"scope\":\"Domain\",\"protocol\":\"HTTP\"}]}", "oss_meta": "", - "spi_tls_config": "{\"certificate_path\":\"abc\", \"private_key_path\":\"def\",\"ca_file_path\":\"gkh\"}" + "spi_tls_config": "{\"certificatePath\":\"abc\", \"privateKeyPath\":\"def\",\"caFilePath\":\"gkh\"}", + "http_source_meta": "" } )JSON"); @@ -78,12 +79,73 @@ TEST_F(KusciaConfigParserTest, Works) { EXPECT_EQ(53508, server_config.communication_port()); } +TEST_F(KusciaConfigParserTest, OSSMeta) { + butil::TempFile tmpfile; + tmpfile.save(1 + R"JSON( +{ + "serving_id": "kd-1", + "input_config": "{\"partyConfigs\":{\"alice\":{\"serverConfig\":{\"featureMapping\":{\"v24\":\"x24\",\"v22\":\"x22\",\"v21\":\"x21\",\"v25\":\"x25\",\"v23\":\"x23\"}},\"modelConfig\":{\"modelId\":\"glm-test-1\",\"basePath\":\"/tmp/alice\",\"sourceSha256\":\"3b6a3b76a8d5bbf0e45b83f2d44772a0a6aa9a15bf382cee22cbdc8f59d55522\",\"sourcePath\":\"examples/alice/glm-test.tar.gz\",\"sourceType\":\"ST_OSS\"},\"featureSourceConfig\":{\"httpOpts\":{\"endpoint\":\"alice_ep\"}},\"channelDesc\":{\"protocol\":\"http\"}},\"bob\":{\"serverConfig\":{\"featureMapping\":{\"v6\":\"x6\",\"v7\":\"x7\",\"v8\":\"x8\",\"v9\":\"x9\",\"v10\":\"x10\"}},\"modelConfig\":{\"modelId\":\"glm-test-1\",\"basePath\":\"/tmp/bob\",\"sourceSha256\":\"330192f3a51f9498dd882478bfe08a06501e2ed4aa2543a0fb586180925eb309\",\"sourcePath\":\"examples/bob/glm-test.tar.gz\",\"sourceType\":\"ST_OSS\"},\"featureSourceConfig\":{\"httpOpts\":{\"endpoint\":\"bob_ep\"}},\"channelDesc\":{\"protocol\":\"http\"}}}}", + "cluster_def": 
"{\"parties\":[{\"name\":\"alice\",\"role\":\"\",\"services\":[{\"portName\":\"service\",\"endpoints\":[\"kd-1-service.alice.svc:53508\"]},{\"portName\":\"internal\",\"endpoints\":[\"kd-1-internal.alice.svc:53510\"]},{\"portName\":\"brpc-builtin\",\"endpoints\":[\"kd-1-brpc-builtin.alice.svc:53511\"]},{\"portName\":\"communication\",\"endpoints\":[\"kd-1-communication.alice.svc\"]}]},{\"name\":\"bob\",\"role\":\"\",\"services\":[{\"portName\":\"brpc-builtin\",\"endpoints\":[\"kd-1-brpc-builtin.bob.svc:53511\"]},{\"portName\":\"service\",\"endpoints\":[\"kd-1-service.bob.svc:53508\"]},{\"portName\":\"internal\",\"endpoints\":[\"kd-1-internal.bob.svc:53510\"]},{\"portName\":\"communication\",\"endpoints\":[\"kd-1-communication.bob.svc\"]}]}],\"selfPartyIdx\":0,\"selfEndpointIdx\":0}", + "allocated_ports": "{\"ports\":[{\"name\":\"service\",\"port\":53509,\"scope\":\"Domain\",\"protocol\":\"HTTP\"},{\"name\":\"communication\",\"port\":53508,\"scope\":\"Cluster\",\"protocol\":\"HTTP\"},{\"name\":\"internal\",\"port\":53510,\"scope\":\"Domain\",\"protocol\":\"HTTP\"},{\"name\":\"brpc-builtin\",\"port\":53511,\"scope\":\"Domain\",\"protocol\":\"HTTP\"}]}", + "oss_meta": "{\"accessKey\":\"test_ak\", \"secretKey\":\"test_sk\", \"virtualHosted\":true, \"endpoint\":\"test_endpoint\", \"bucket\":\"test_bucket\"}", + "spi_tls_config": "", + "http_source_meta": "" +} +)JSON"); + + KusciaConfigParser config_parser(tmpfile.fname()); + + auto model_config = config_parser.model_config(); + EXPECT_EQ("glm-test-1", model_config.model_id()); + EXPECT_EQ("/tmp/alice", model_config.base_path()); + EXPECT_EQ(SourceType::ST_OSS, model_config.source_type()); + EXPECT_EQ("test_ak", model_config.oss_source_meta().access_key()); + EXPECT_EQ("test_sk", model_config.oss_source_meta().secret_key()); + EXPECT_TRUE(model_config.oss_source_meta().virtual_hosted()); + EXPECT_EQ("test_endpoint", model_config.oss_source_meta().endpoint()); + EXPECT_EQ("test_bucket", 
model_config.oss_source_meta().bucket()); + + EXPECT_FALSE(config_parser.feature_config()->http_opts().has_tls_config()); +} + +TEST_F(KusciaConfigParserTest, HttpSourceMeta) { + butil::TempFile tmpfile; + tmpfile.save(1 + R"JSON( +{ + "serving_id": "kd-1", + "input_config": "{\"partyConfigs\":{\"alice\":{\"serverConfig\":{\"featureMapping\":{\"v24\":\"x24\",\"v22\":\"x22\",\"v21\":\"x21\",\"v25\":\"x25\",\"v23\":\"x23\"}},\"modelConfig\":{\"modelId\":\"glm-test-1\",\"basePath\":\"/tmp/alice\",\"sourceSha256\":\"3b6a3b76a8d5bbf0e45b83f2d44772a0a6aa9a15bf382cee22cbdc8f59d55522\",\"sourcePath\":\"examples/alice/glm-test.tar.gz\",\"sourceType\":\"ST_HTTP\"},\"featureSourceConfig\":{\"httpOpts\":{\"endpoint\":\"alice_ep\"}},\"channelDesc\":{\"protocol\":\"http\"}},\"bob\":{\"serverConfig\":{\"featureMapping\":{\"v6\":\"x6\",\"v7\":\"x7\",\"v8\":\"x8\",\"v9\":\"x9\",\"v10\":\"x10\"}},\"modelConfig\":{\"modelId\":\"glm-test-1\",\"basePath\":\"/tmp/bob\",\"sourceSha256\":\"330192f3a51f9498dd882478bfe08a06501e2ed4aa2543a0fb586180925eb309\",\"sourcePath\":\"examples/bob/glm-test.tar.gz\",\"sourceType\":\"ST_HTTP\"},\"featureSourceConfig\":{\"httpOpts\":{\"endpoint\":\"bob_ep\"}},\"channelDesc\":{\"protocol\":\"http\"}}}}", + "cluster_def": 
"{\"parties\":[{\"name\":\"alice\",\"role\":\"\",\"services\":[{\"portName\":\"service\",\"endpoints\":[\"kd-1-service.alice.svc:53508\"]},{\"portName\":\"internal\",\"endpoints\":[\"kd-1-internal.alice.svc:53510\"]},{\"portName\":\"brpc-builtin\",\"endpoints\":[\"kd-1-brpc-builtin.alice.svc:53511\"]},{\"portName\":\"communication\",\"endpoints\":[\"kd-1-communication.alice.svc\"]}]},{\"name\":\"bob\",\"role\":\"\",\"services\":[{\"portName\":\"brpc-builtin\",\"endpoints\":[\"kd-1-brpc-builtin.bob.svc:53511\"]},{\"portName\":\"service\",\"endpoints\":[\"kd-1-service.bob.svc:53508\"]},{\"portName\":\"internal\",\"endpoints\":[\"kd-1-internal.bob.svc:53510\"]},{\"portName\":\"communication\",\"endpoints\":[\"kd-1-communication.bob.svc\"]}]}],\"selfPartyIdx\":0,\"selfEndpointIdx\":0}", + "allocated_ports": "{\"ports\":[{\"name\":\"service\",\"port\":53509,\"scope\":\"Domain\",\"protocol\":\"HTTP\"},{\"name\":\"communication\",\"port\":53508,\"scope\":\"Cluster\",\"protocol\":\"HTTP\"},{\"name\":\"internal\",\"port\":53510,\"scope\":\"Domain\",\"protocol\":\"HTTP\"},{\"name\":\"brpc-builtin\",\"port\":53511,\"scope\":\"Domain\",\"protocol\":\"HTTP\"}]}", + "oss_meta": "", + "spi_tls_config": "", + "http_source_meta": "{\"connectTimeoutMs\":60000,\"timeoutMs\":120000,\"tlsConfig\":{\"certificatePath\":\"abc\", \"privateKeyPath\":\"def\",\"caFilePath\":\"gkh\"}}" +} +)JSON"); + + KusciaConfigParser config_parser(tmpfile.fname()); + + auto model_config = config_parser.model_config(); + EXPECT_EQ("glm-test-1", model_config.model_id()); + EXPECT_EQ("/tmp/alice", model_config.base_path()); + EXPECT_EQ(SourceType::ST_HTTP, model_config.source_type()); + EXPECT_EQ(60000, model_config.http_source_meta().connect_timeout_ms()); + EXPECT_EQ(120000, model_config.http_source_meta().timeout_ms()); + EXPECT_TRUE(model_config.http_source_meta().has_tls_config()); + EXPECT_EQ("abc", + model_config.http_source_meta().tls_config().certificate_path()); + EXPECT_EQ("def", + 
model_config.http_source_meta().tls_config().private_key_path()); + EXPECT_EQ("gkh", model_config.http_source_meta().tls_config().ca_file_path()); + + EXPECT_FALSE(config_parser.feature_config()->http_opts().has_tls_config()); +} + TEST_F(KusciaConfigParserTest, DPWorks) { butil::TempFile tmpfile; tmpfile.save(1 + R"JSON( { "serving_id": "kd-1", - "input_config": "{\"partyConfigs\":{\"alice\":{\"serverConfig\":{\"featureMapping\":{\"v24\":\"x24\",\"v22\":\"x22\",\"v21\":\"x21\",\"v25\":\"x25\",\"v23\":\"x23\"}},\"modelConfig\":{\"modelId\":\"glm-test-1\",\"basePath\":\"/tmp/alice\",\"sourceSha256\":\"3b6a3b76a8d5bbf0e45b83f2d44772a0a6aa9a15bf382cee22cbdc8f59d55522\",\"sourcePath\":\"alice-1234\",\"sourceType\":\"ST_DP\",\"dpSourceMeta\":{\"dmHost\":\"127.0.0.1:8071\",\"tls_config\":{\"certificatePath\":\"kusciaapi-server.crt\",\"privateKeyPath\":\"kusciaapi-server.key\",\"caFilePath\":\"ca.crt\"}}},\"featureSourceConfig\":{\"httpOpts\":{\"endpoint\":\"alice_ep\"}},\"channel_desc\":{\"protocol\":\"http\"}},\"bob\":{\"serverConfig\":{\"featureMapping\":{\"v6\":\"x6\",\"v7\":\"x7\",\"v8\":\"x8\",\"v9\":\"x9\",\"v10\":\"x10\"}},\"modelConfig\":{\"modelId\":\"glm-test-1\",\"basePath\":\"/tmp/bob\",\"sourceSha256\":\"330192f3a51f9498dd882478bfe08a06501e2ed4aa2543a0fb586180925eb309\",\"sourcePath\":\"alice-1234\",\"sourceType\":\"ST_DP\"},\"featureSourceConfig\":{\"httpOpts\":{\"endpoint\":\"bob_ep\"}},\"channel_desc\":{\"protocol\":\"http\"}}}}", + "input_config": 
"{\"partyConfigs\":{\"alice\":{\"serverConfig\":{\"featureMapping\":{\"v24\":\"x24\",\"v22\":\"x22\",\"v21\":\"x21\",\"v25\":\"x25\",\"v23\":\"x23\"}},\"modelConfig\":{\"modelId\":\"glm-test-1\",\"basePath\":\"/tmp/alice\",\"sourceSha256\":\"3b6a3b76a8d5bbf0e45b83f2d44772a0a6aa9a15bf382cee22cbdc8f59d55522\",\"sourcePath\":\"alice-1234\",\"sourceType\":\"ST_DP\",\"dpSourceMeta\":{\"dmHost\":\"127.0.0.1:8071\",\"tls_config\":{\"certificatePath\":\"kusciaapi-server.crt\",\"privateKeyPath\":\"kusciaapi-server.key\",\"caFilePath\":\"ca.crt\"}}},\"featureSourceConfig\":{\"httpOpts\":{\"endpoint\":\"alice_ep\"}},\"channelDesc\":{\"protocol\":\"http\"}},\"bob\":{\"serverConfig\":{\"featureMapping\":{\"v6\":\"x6\",\"v7\":\"x7\",\"v8\":\"x8\",\"v9\":\"x9\",\"v10\":\"x10\"}},\"modelConfig\":{\"modelId\":\"glm-test-1\",\"basePath\":\"/tmp/bob\",\"sourceSha256\":\"330192f3a51f9498dd882478bfe08a06501e2ed4aa2543a0fb586180925eb309\",\"sourcePath\":\"alice-1234\",\"sourceType\":\"ST_DP\"},\"featureSourceConfig\":{\"httpOpts\":{\"endpoint\":\"bob_ep\"}},\"channelDesc\":{\"protocol\":\"http\"}}}}", "cluster_def": "{\"parties\":[{\"name\":\"alice\",\"role\":\"\",\"services\":[{\"portName\":\"service\",\"endpoints\":[\"kd-1-service.alice.svc:53508\"]},{\"portName\":\"internal\",\"endpoints\":[\"kd-1-internal.alice.svc:53510\"]},{\"portName\":\"brpc-builtin\",\"endpoints\":[\"kd-1-brpc-builtin.alice.svc:53511\"]},{\"portName\":\"communication\",\"endpoints\":[\"kd-1-communication.alice.svc\"]}]},{\"name\":\"bob\",\"role\":\"\",\"services\":[{\"portName\":\"brpc-builtin\",\"endpoints\":[\"kd-1-brpc-builtin.bob.svc:53511\"]},{\"portName\":\"service\",\"endpoints\":[\"kd-1-service.bob.svc:53508\"]},{\"portName\":\"internal\",\"endpoints\":[\"kd-1-internal.bob.svc:53510\"]},{\"portName\":\"communication\",\"endpoints\":[\"kd-1-communication.bob.svc\"]}]}],\"selfPartyIdx\":0,\"selfEndpointIdx\":0}", "allocated_ports": 
"{\"ports\":[{\"name\":\"service\",\"port\":53509,\"scope\":\"Domain\",\"protocol\":\"HTTP\"},{\"name\":\"communication\",\"port\":53508,\"scope\":\"Cluster\",\"protocol\":\"HTTP\"},{\"name\":\"internal\",\"port\":53510,\"scope\":\"Domain\",\"protocol\":\"HTTP\"},{\"name\":\"brpc-builtin\",\"port\":53511,\"scope\":\"Domain\",\"protocol\":\"HTTP\"}]}", "spi_tls_config": "{\"certificate_path\":\"abc\", \"private_key_path\":\"def\",\"ca_file_path\":\"gkh\"}" @@ -95,10 +157,10 @@ TEST_F(KusciaConfigParserTest, DPWorks) { EXPECT_EQ("alice-1234", model_config.source_path()); EXPECT_EQ(SourceType::ST_DP, model_config.source_type()); - auto dp_source_meta = model_config.dp_source_meta(); + const auto& dp_source_meta = model_config.dp_source_meta(); EXPECT_EQ("127.0.0.1:8071", dp_source_meta.dm_host()); - auto dp_tls_config = dp_source_meta.tls_config(); + const auto& dp_tls_config = dp_source_meta.tls_config(); EXPECT_EQ(dp_tls_config.certificate_path(), "kusciaapi-server.crt"); EXPECT_EQ(dp_tls_config.private_key_path(), "kusciaapi-server.key"); EXPECT_EQ(dp_tls_config.ca_file_path(), "ca.crt"); diff --git a/secretflow_serving/server/main.cc b/secretflow_serving/server/main.cc index c00f2bd..a9f0dec 100644 --- a/secretflow_serving/server/main.cc +++ b/secretflow_serving/server/main.cc @@ -23,6 +23,7 @@ #include "secretflow_serving/server/server.h" #include "secretflow_serving/server/trace/trace.h" #include "secretflow_serving/server/version.h" +#include "secretflow_serving/util/network.h" #include "secretflow_serving/util/utils.h" #include "secretflow_serving/config/serving_config.pb.h" @@ -34,6 +35,9 @@ DEFINE_string(config_mode, "", DEFINE_string(serving_config_file, "", "read an ascii config protobuf from the supplied file name."); +DEFINE_bool(enable_peers_load_balancer, false, + "whether to enable load balancer between parties"); + // logging config DEFINE_string( logging_config_file, "", @@ -53,8 +57,12 @@ DEFINE_string( void InitLogger() { 
secretflow::serving::LoggingConfig log_config; if (!FLAGS_logging_config_file.empty()) { - secretflow::serving::LoadPbFromJsonFile(FLAGS_logging_config_file, - &log_config); + auto logging_config_str = + secretflow::serving::ReadFileContent(FLAGS_logging_config_file); + if (!secretflow::serving::CheckContentEmpty(logging_config_str)) { + secretflow::serving::JsonToPb( + secretflow::serving::UnescapeJson(logging_config_str), &log_config); + } } secretflow::serving::SetupLogging(log_config); } @@ -62,8 +70,12 @@ void InitLogger() { void InitTracer() { secretflow::serving::TraceConfig trace_log; if (!FLAGS_trace_config_file.empty()) { - secretflow::serving::LoadPbFromJsonFile(FLAGS_trace_config_file, - &trace_log); + auto trace_config_str = + secretflow::serving::ReadFileContent(FLAGS_trace_config_file); + if (!secretflow::serving::CheckContentEmpty(trace_config_str)) { + secretflow::serving::JsonToPb( + secretflow::serving::UnescapeJson(trace_config_str), &trace_log); + } } secretflow::serving::InitTracer(trace_log); } @@ -91,7 +103,6 @@ int main(int argc, char* argv[]) { [&](const std::shared_ptr& o) { op_names.emplace_back(o->name()); }); - SPDLOG_INFO("op list: {}", fmt::join(op_names.begin(), op_names.end(), ", ")); } @@ -121,8 +132,10 @@ int main(int argc, char* argv[]) { server_opts.service_id = serving_conf.id(); } + auto clannels = secretflow::serving::BuildChannelsFromConfig( + server_opts.cluster_config, FLAGS_enable_peers_load_balancer); secretflow::serving::Server server(std::move(server_opts)); - server.Start(); + server.Start(std::move(clannels)); server.WaitForEnd(); } catch (const secretflow::serving::Exception& e) { diff --git a/secretflow_serving/server/prediction_core.cc b/secretflow_serving/server/prediction_core.cc index 6d55283..e757179 100644 --- a/secretflow_serving/server/prediction_core.cc +++ b/secretflow_serving/server/prediction_core.cc @@ -31,12 +31,7 @@ PredictionCore::PredictionCore(Options opts) : opts_(std::move(opts)) { void 
PredictionCore::Predict(const apis::PredictRequest* request, apis::PredictResponse* response) noexcept { try { - response->mutable_service_spec()->CopyFrom(request->service_spec()); - auto* status = response->mutable_status(); - - CheckArgument(request); - opts_.predictor->Predict(request, response); - status->set_code(errors::ErrorCode::OK); + PredictImpl(request, response); } catch (const Exception& e) { SPDLOG_ERROR("Predict failed, request: {}, code:{}, msg:{}, stack:{}", PbToJsonNoExcept(request), e.code(), e.what(), @@ -51,6 +46,17 @@ void PredictionCore::Predict(const apis::PredictRequest* request, } } +void PredictionCore::PredictImpl(const apis::PredictRequest* request, + apis::PredictResponse* response) { + response->mutable_service_spec()->CopyFrom(request->service_spec()); + auto* status = response->mutable_status(); + + CheckArgument(request); + + opts_.predictor->Predict(request, response); + status->set_code(errors::ErrorCode::OK); +} + void PredictionCore::CheckArgument(const apis::PredictRequest* request) { SERVING_ENFORCE_EQ(request->service_spec().id(), opts_.service_id, "invalid service spec id: {}", diff --git a/secretflow_serving/server/prediction_core.h b/secretflow_serving/server/prediction_core.h index 45264d7..36f46e3 100644 --- a/secretflow_serving/server/prediction_core.h +++ b/secretflow_serving/server/prediction_core.h @@ -39,6 +39,9 @@ class PredictionCore { void Predict(const apis::PredictRequest* request, apis::PredictResponse* response) noexcept; + void PredictImpl(const apis::PredictRequest* request, + apis::PredictResponse* response); + const std::string& GetServiceID() const { return opts_.service_id; } const std::string& GetPartyID() const { return opts_.party_id; } diff --git a/secretflow_serving/server/prediction_service_impl.cc b/secretflow_serving/server/prediction_service_impl.cc index 317232a..cdc7766 100644 --- a/secretflow_serving/server/prediction_service_impl.cc +++ 
b/secretflow_serving/server/prediction_service_impl.cc @@ -23,19 +23,12 @@ namespace secretflow::serving { -PredictionServiceImpl::PredictionServiceImpl(const std::string& party_id) +PredictionServiceImpl::PredictionServiceImpl( + const std::string& party_id, + const std::shared_ptr& prediction_core) : party_id_(party_id), - stats_({{"handler", "PredictionService"}, {"party_id", party_id_}}), - init_flag_(false) {} - -void PredictionServiceImpl::Init( - const std::shared_ptr& prediction_core) { - SERVING_ENFORCE(prediction_core, errors::ErrorCode::LOGIC_ERROR); - SERVING_ENFORCE(!init_flag_, errors::ErrorCode::LOGIC_ERROR); - - prediction_core_ = prediction_core; - init_flag_ = true; -} + prediction_core_(prediction_core), + stats_({{"handler", "PredictionService"}, {"party_id", party_id_}}) {} void PredictionServiceImpl::Predict( ::google::protobuf::RpcController* controller, @@ -56,14 +49,7 @@ void PredictionServiceImpl::Predict( SPDLOG_DEBUG("predict begin, request: {}", request->ShortDebugString()); yacl::ElapsedTimer timer; - if (!init_flag_) { - response->mutable_service_spec()->CopyFrom(request->service_spec()); - response->mutable_status()->set_code(errors::ErrorCode::NOT_READY); - response->mutable_status()->set_msg( - "prediction service is not ready to serve, please retry later."); - } else { - prediction_core_->Predict(request, response); - } + prediction_core_->Predict(request, response); timer.Pause(); SPDLOG_DEBUG("predict end, time: {}", timer.CountMs()); @@ -95,7 +81,7 @@ void PredictionServiceImpl::RecordMetrics(const apis::PredictRequest& request, } PredictionServiceImpl::Stats::Stats( - std::map labels, + const std::map& labels, const std::shared_ptr<::prometheus::Registry>& registry) : api_request_counter_family( ::prometheus::BuildCounter() diff --git a/secretflow_serving/server/prediction_service_impl.h b/secretflow_serving/server/prediction_service_impl.h index 86cc6ab..0ed847b 100644 --- a/secretflow_serving/server/prediction_service_impl.h 
+++ b/secretflow_serving/server/prediction_service_impl.h @@ -29,9 +29,9 @@ namespace secretflow::serving { // 预测 - 服务入口 class PredictionServiceImpl : public apis::PredictionService { public: - explicit PredictionServiceImpl(const std::string& party_id); - - void Init(const std::shared_ptr& prediction_core); + explicit PredictionServiceImpl( + const std::string& party_id, + const std::shared_ptr& prediction_core); void Predict(::google::protobuf::RpcController* controller, const apis::PredictRequest* request, @@ -48,7 +48,7 @@ class PredictionServiceImpl : public apis::PredictionService { ::prometheus::Family<::prometheus::Counter>& predict_counter_family; ::prometheus::Counter& predict_counter; - explicit Stats(std::map labels, + explicit Stats(const std::map& labels, const std::shared_ptr<::prometheus::Registry>& registry = metrics::GetDefaultRegistry()); }; @@ -60,11 +60,9 @@ class PredictionServiceImpl : public apis::PredictionService { private: const std::string& party_id_; - Stats stats_; - std::shared_ptr prediction_core_; - std::atomic init_flag_; + Stats stats_; }; } // namespace secretflow::serving diff --git a/secretflow_serving/server/server.cc b/secretflow_serving/server/server.cc index 4e979c6..5813ad9 100644 --- a/secretflow_serving/server/server.cc +++ b/secretflow_serving/server/server.cc @@ -21,7 +21,6 @@ #include "secretflow_serving/framework/model_loader.h" #include "secretflow_serving/ops/graph.h" #include "secretflow_serving/server/execution_service_impl.h" -#include "secretflow_serving/server/health.h" #include "secretflow_serving/server/metrics/default_metrics_registry.h" #include "secretflow_serving/server/metrics/metrics_service.h" #include "secretflow_serving/server/prediction_service_impl.h" @@ -34,16 +33,10 @@ #include "secretflow_serving/apis/metrics.pb.h" #include "secretflow_serving/apis/prediction_service.pb.h" -DEFINE_bool(enable_peers_load_balancer, false, - "whether to enable load balancer between parties"); - namespace 
secretflow::serving { namespace { -const int32_t kPeerConnectTimeoutMs = 500; -const int32_t kPeerRpcTimeoutMs = 2000; - void SetServerTLSOpts(const TlsConfig& tls_config, brpc::ServerSSLOptions* server_ssl_opts) { server_ssl_opts->default_cert.certificate = tls_config.certificate_path(); @@ -61,6 +54,7 @@ Server::Server(Options opts) : opts_(std::move(opts)) { errors::ErrorCode::INVALID_ARGUMENT, "too few parties params for cluster config, get: {}", opts_.cluster_config.parties_size()); + hr_ = std::make_unique(); } Server::~Server() { @@ -71,10 +65,15 @@ Server::~Server() { communication_server_.Join(); metrics_server_.Join(); + hr_ = nullptr; model_service_ = nullptr; } -void Server::Start() { +void Server::Start( + std::shared_ptr< + std::map>> + channels, + google::protobuf::Service* additional_service) { const auto& self_party_id = opts_.cluster_config.self_id(); // get model package @@ -86,36 +85,11 @@ void Server::Start() { SERVING_ENFORCE(!host.empty(), errors::ErrorCode::INVALID_ARGUMENT, "get empty host."); - // build channels std::vector cluster_ids; - auto channels = std::make_shared(); - for (const auto& party : opts_.cluster_config.parties()) { - cluster_ids.emplace_back(party.id()); - if (party.id() == self_party_id) { - continue; - } - const auto& channel_desc = opts_.cluster_config.channel_desc(); - channels->emplace( - party.id(), - CreateBrpcChannel( - party.id(), party.address(), channel_desc.protocol(), - FLAGS_enable_peers_load_balancer, - channel_desc.rpc_timeout_ms() > 0 ? channel_desc.rpc_timeout_ms() - : kPeerRpcTimeoutMs, - channel_desc.connect_timeout_ms() > 0 - ? channel_desc.connect_timeout_ms() - : kPeerConnectTimeoutMs, - channel_desc.has_tls_config() ? &channel_desc.tls_config() - : nullptr, - channel_desc.has_retry_policy_config() - ? 
&channel_desc.retry_policy_config() - : nullptr)); - } - - auto com_address = - fmt::format("{}:{}", host, opts_.server_config.communication_port()); - auto service_address = - fmt::format("{}:{}", host, opts_.server_config.service_port()); + std::transform(opts_.cluster_config.parties().begin(), + opts_.cluster_config.parties().end(), + std::back_inserter(cluster_ids), + [](const auto& p) { return p.id(); }); // load model package auto loader = std::make_unique(); @@ -126,7 +100,7 @@ void Server::Start() { // build execution core std::vector executors; for (const auto& execution : graph.GetExecutions()) { - executors.emplace_back(Executor(execution)); + executors.emplace_back(execution, self_party_id, cluster_ids); } ExecutionCore::Options exec_opts; exec_opts.id = opts_.service_id; @@ -197,12 +171,22 @@ void Server::Start() { } // start commnication server + auto com_address = + fmt::format("{}:{}", host, opts_.server_config.communication_port()); std::map model_info_map = { {opts_.service_id, model_info_collector.GetSelfModelInfo()}}; model_service_ = std::make_unique(std::move(model_info_map), self_party_id); { brpc::ServerOptions com_server_options; + com_server_options.health_reporter = hr_.get(); + if (opts_.server_config.brpc_builtin_service_port() > 0) { + com_server_options.has_builtin_services = true; + com_server_options.internal_port = + opts_.server_config.brpc_builtin_service_port(); + SPDLOG_INFO("brpc built-in service port: {}", + com_server_options.internal_port); + } if (opts_.server_config.worker_num() > 0) { com_server_options.num_threads = opts_.server_config.worker_num(); } @@ -221,6 +205,13 @@ void Server::Start() { SERVING_THROW(errors::ErrorCode::UNEXPECTED_ERROR, "fail to add model service into com brpc server."); } + if (additional_service != nullptr) { + if (communication_server_.AddService( + additional_service, brpc::SERVER_DOESNT_OWN_SERVICE) != 0) { + SERVING_THROW(errors::ErrorCode::UNEXPECTED_ERROR, + "fail to add additional 
service into com brpc server."); + } + } // start services server communication_server_.set_version(SERVING_VERSION_STRING); if (communication_server_.Start(com_address.c_str(), &com_server_options) != @@ -229,61 +220,21 @@ void Server::Start() { "fail to start communication brpc server at {}", com_address); } - SPDLOG_INFO("begin communication server listen at {}, ", com_address); - } - // start service server - brpc::ServerOptions server_options; - server_options.max_concurrency = opts_.server_config.max_concurrency(); - if (opts_.server_config.worker_num() > 0) { - server_options.num_threads = opts_.server_config.worker_num(); - } - if (opts_.server_config.brpc_builtin_service_port() > 0) { - server_options.has_builtin_services = true; - server_options.internal_port = - opts_.server_config.brpc_builtin_service_port(); - SPDLOG_INFO("brpc built-in service port: {}", server_options.internal_port); - } - if (opts_.server_config.has_tls_config()) { - SetServerTLSOpts(opts_.server_config.tls_config(), - server_options.mutable_ssl_options()); - } - health::ServingHealthReporter hr; - server_options.health_reporter = &hr; - // FIXME: - // kuscia场景需要在服务启动后,使服务状态可用,此时才能挂载路由。但服务需要完成 - // exchange model info 才能 ready - hr.SetStatusCode(200); - - if (service_server_.AddService(model_service_.get(), - brpc::SERVER_DOESNT_OWN_SERVICE) != 0) { - SERVING_THROW(errors::ErrorCode::UNEXPECTED_ERROR, - "fail to add model service into brpc server."); - } - auto* prediction_service = new PredictionServiceImpl(self_party_id); - if (service_server_.AddService(prediction_service, - brpc::SERVER_OWNS_SERVICE) != 0) { - SERVING_THROW(errors::ErrorCode::UNEXPECTED_ERROR, - "fail to add prediction service into brpc server."); - } - service_server_.set_version(SERVING_VERSION_STRING); - if (service_server_.Start(service_address.c_str(), &server_options) != 0) { - SERVING_THROW(errors::ErrorCode::UNEXPECTED_ERROR, - "fail to start service brpc server at {}", service_address); + // FIXME: + // 
kuscia场景需要在服务启动后,使服务状态可用,此时才能挂载路由。 + // 但服务需要完成 exchange model info 才能 ready + hr_->SetStatusCode(200); } - SPDLOG_INFO("begin service server listen at {}, ", service_address); - // exchange model info SPDLOG_INFO("start exchange model_info"); - model_info_collector.DoCollect(); auto specific_map = model_info_collector.GetSpecificMap(); - SPDLOG_INFO("end exchange model_info"); - // build prediction core, let prediction service begin to serve + // build prediction core Predictor::Options predictor_opts; predictor_opts.party_id = self_party_id; predictor_opts.channels = channels; @@ -297,9 +248,44 @@ void Server::Start() { prediction_core_opts.party_id = self_party_id; prediction_core_opts.cluster_ids = std::move(cluster_ids); prediction_core_opts.predictor = predictor; - auto prediction_core = + prediction_core_ = std::make_shared(std::move(prediction_core_opts)); - prediction_service->Init(prediction_core); + + if (opts_.server_config.service_port() > 0) { + auto service_address = + fmt::format("{}:{}", host, opts_.server_config.service_port()); + + // start service server + brpc::ServerOptions server_options; + server_options.max_concurrency = opts_.server_config.max_concurrency(); + if (opts_.server_config.worker_num() > 0) { + server_options.num_threads = opts_.server_config.worker_num(); + } + if (opts_.server_config.has_tls_config()) { + SetServerTLSOpts(opts_.server_config.tls_config(), + server_options.mutable_ssl_options()); + } + if (service_server_.AddService(model_service_.get(), + brpc::SERVER_DOESNT_OWN_SERVICE) != 0) { + SERVING_THROW(errors::ErrorCode::UNEXPECTED_ERROR, + "fail to add model service into brpc server."); + } + auto* prediction_service = + new PredictionServiceImpl(self_party_id, prediction_core_); + if (service_server_.AddService(prediction_service, + brpc::SERVER_OWNS_SERVICE) != 0) { + SERVING_THROW(errors::ErrorCode::UNEXPECTED_ERROR, + "fail to add prediction service into brpc server."); + } + 
service_server_.set_version(SERVING_VERSION_STRING); + if (service_server_.Start(service_address.c_str(), &server_options) != 0) { + SERVING_THROW(errors::ErrorCode::UNEXPECTED_ERROR, + "fail to start service brpc server at {}", service_address); + } + SPDLOG_INFO("begin service server listen at {}, ", service_address); + } else { + SPDLOG_INFO("service port is 0, prediction service is disabled."); + } } void Server::WaitForEnd() { diff --git a/secretflow_serving/server/server.h b/secretflow_serving/server/server.h index 32893df..b674734 100644 --- a/secretflow_serving/server/server.h +++ b/secretflow_serving/server/server.h @@ -20,7 +20,9 @@ #include "brpc/server.h" +#include "secretflow_serving/server/health.h" #include "secretflow_serving/server/model_service_impl.h" +#include "secretflow_serving/server/prediction_core.h" #include "secretflow_serving/config/cluster_config.pb.h" #include "secretflow_serving/config/feature_config.pb.h" @@ -45,11 +47,18 @@ class Server { explicit Server(Options opts); ~Server(); - void Start(); + void Start(std::shared_ptr>> + channels, + google::protobuf::Service* additional_service = nullptr); // This will block the current thread until termination is successful. void WaitForEnd(); + std::shared_ptr GetPredictionCore() { + return prediction_core_; + } + private: const Options opts_; @@ -57,7 +66,11 @@ class Server { brpc::Server communication_server_; brpc::Server metrics_server_; + std::unique_ptr hr_; + std::unique_ptr model_service_; + + std::shared_ptr prediction_core_; }; } // namespace secretflow::serving diff --git a/secretflow_serving/tools/inferencer/BUILD.bazel b/secretflow_serving/tools/inferencer/BUILD.bazel new file mode 100644 index 0000000..66ac95e --- /dev/null +++ b/secretflow_serving/tools/inferencer/BUILD.bazel @@ -0,0 +1,83 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_proto//proto:defs.bzl", "proto_library") +load("@rules_proto_grpc//python:defs.bzl", "python_proto_compile") +load("//bazel:serving.bzl", "serving_cc_binary", "serving_cc_library") + +package(default_visibility = ["//visibility:public"]) + +proto_library( + name = "inference_service_proto", + srcs = ["inference_service.proto"], + deps = [ + "//secretflow_serving/apis:status_proto", + ], +) + +cc_proto_library( + name = "inference_service_cc_proto", + deps = [":inference_service_proto"], +) + +proto_library( + name = "config_proto", + srcs = ["config.proto"], +) + +cc_proto_library( + name = "config_cc_proto", + deps = [":config_proto"], +) + +serving_cc_library( + name = "control_service_impl", + srcs = ["control_service_impl.cc"], + hdrs = ["control_service_impl.h"], + deps = [ + ":inference_service_cc_proto", + "//secretflow_serving/apis:error_code_cc_proto", + "@com_github_brpc_brpc//:brpc", + ], +) + +serving_cc_library( + name = "inference_executor", + srcs = ["inference_executor.cc"], + hdrs = ["inference_executor.h"], + deps = [ + ":config_cc_proto", + ":control_service_impl", + "//secretflow_serving/server", + ], +) + +serving_cc_binary( + name = "inferencer", + srcs = ["main.cc"], + deps = [ + ":inference_executor", + "//secretflow_serving/core:exception", + "//secretflow_serving/core:logging", + "@com_google_absl//absl/debugging:failure_signal_handler", + "@com_google_absl//absl/debugging:symbolize", + ], +) + +python_proto_compile( + name = "inference_config_py_proto", + output_mode = "NO_PREFIX", + prefix_path = "../../../", + 
protos = [":config_proto"], +) diff --git a/secretflow_serving/tools/inferencer/README.md b/secretflow_serving/tools/inferencer/README.md new file mode 100644 index 0000000..24010e1 --- /dev/null +++ b/secretflow_serving/tools/inferencer/README.md @@ -0,0 +1,29 @@ +# Usage + +## C++ + +### Alice + +```bash +bazel-bin/secretflow_serving/tools/inferencer/inferencer --serving_config_file=secretflow_serving/tools/inferencer/example/alice/serving.config --inference_config_file=secretflow_serving/tools/inferencer/example/alice/inference.config +``` + +### Bob + +```bash +bazel-bin/secretflow_serving/tools/inferencer/inferencer --serving_config_file=secretflow_serving/tools/inferencer/example/bob/serving.config --inference_config_file=secretflow_serving/tools/inferencer/example/bob/inference.config +``` + +## Python + +```python + import importlib + + with importlib.resources.path('secretflow_serving.tools.inferencer', 'inferencer') as tool_path: + + # dump serving config file + + # dump inference config file + + # run inferencer... +``` diff --git a/secretflow_serving/tools/inferencer/config.proto b/secretflow_serving/tools/inferencer/config.proto new file mode 100644 index 0000000..ef51c56 --- /dev/null +++ b/secretflow_serving/tools/inferencer/config.proto @@ -0,0 +1,37 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +syntax = "proto3"; + +package secretflow.serving.tools; + +message InferenceConfig { + // Inference requester id, inference result file is only output in requester + // party + string requester_id = 1; + + // The path of the inference result file. + string result_file_path = 2; + + // Optional. Additional columns must exist in the requester's feature input + // file and are added to the output file. + repeated string additional_col_names = 3; + + // The name of the inference score column in the result file. + string score_col_name = 4; + + // Optional. This determines the size of each request batch. + int32 block_size = 11; +} diff --git a/secretflow_serving/tools/inferencer/control_service_impl.cc b/secretflow_serving/tools/inferencer/control_service_impl.cc new file mode 100644 index 0000000..bb27a12 --- /dev/null +++ b/secretflow_serving/tools/inferencer/control_service_impl.cc @@ -0,0 +1,104 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "secretflow_serving/tools/inferencer/control_service_impl.h" + +#include "brpc/closure_guard.h" +#include "brpc/controller.h" +#include "fmt/format.h" +#include "spdlog/spdlog.h" + +#include "secretflow_serving/apis/error_code.pb.h" + +namespace secretflow::serving::tools { + +InferenceControlServiceImpl::InferenceControlServiceImpl( + const std::string& requester_id, int32_t row_num) + : requester_id_(requester_id), row_num_(row_num) {} + +void InferenceControlServiceImpl::ReadyToServe() { + std::lock_guard lock(mux_); + ready_flag_ = true; +} + +void InferenceControlServiceImpl::Push( + ::google::protobuf::RpcController* controller, + const ControlRequest* request, ControlResponse* response, + ::google::protobuf::Closure* done) { + brpc::ClosureGuard done_guard(done); + + std::lock_guard lock(mux_); + + if (!ready_flag_) { + response->mutable_status()->set_code(errors::ErrorCode::UNEXPECTED_ERROR); + response->mutable_status()->set_msg("waiting serving server ready"); + SPDLOG_WARN(response->status().msg()); + return; + } + + if (request->party_id() != requester_id_) { + response->mutable_status()->set_code(errors::ErrorCode::UNEXPECTED_ERROR); + response->mutable_status()->set_msg( + fmt::format("recv control msg from {}, but requseter should be {}", + request->party_id(), requester_id_)); + SPDLOG_ERROR(response->status().msg()); + stop_flag_ = true; + return; + } + response->mutable_status()->set_code(errors::ErrorCode::OK); + + switch (request->type()) { + case CM_INIT: { + SPDLOG_INFO("recv init msg from {}", request->party_id()); + init_flag_ = true; + response->mutable_init_msg()->set_row_num(row_num_); + break; + } + case CM_STOP: { + SPDLOG_INFO("recv stop msg from {}", request->party_id()); + stop_flag_ = true; + break; + } + case CM_KEEPALIVE: { + heart_beat_count_++; + break; + } + default: { + response->mutable_status()->set_code(errors::ErrorCode::UNEXPECTED_ERROR); + response->mutable_status()->set_msg(fmt::format( + "unsupport msg 
type: {}", static_cast(request->type()))); + SPDLOG_ERROR("deal request from {} failed, msg: {}", request->party_id(), + response->status().msg()); + stop_flag_ = true; + break; + } + } +} + +bool InferenceControlServiceImpl::stop_flag() { + std::lock_guard lock(mux_); + return stop_flag_; +} + +bool InferenceControlServiceImpl::init_flag() { + std::lock_guard lock(mux_); + return init_flag_; +} + +uint64_t InferenceControlServiceImpl::heart_beat_count() { + std::lock_guard lock(mux_); + return heart_beat_count_; +} + +} // namespace secretflow::serving::tools diff --git a/secretflow_serving/tools/inferencer/control_service_impl.h b/secretflow_serving/tools/inferencer/control_service_impl.h new file mode 100644 index 0000000..50322c6 --- /dev/null +++ b/secretflow_serving/tools/inferencer/control_service_impl.h @@ -0,0 +1,56 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include + +#include "secretflow_serving/tools/inferencer/inference_service.pb.h" + +namespace secretflow::serving::tools { + +// 预测 - 服务入口 +class InferenceControlServiceImpl : public InferenceControlService { + public: + InferenceControlServiceImpl(const std::string& requester_id, int32_t row_num); + + void Push(::google::protobuf::RpcController* controller, + const ControlRequest* request, ControlResponse* response, + ::google::protobuf::Closure* done) override; + + void ReadyToServe(); + + [[nodiscard]] bool stop_flag(); + + [[nodiscard]] bool init_flag(); + + uint64_t heart_beat_count(); + + private: + std::mutex mux_; + + const std::string requester_id_; + + const int32_t row_num_; + + bool ready_flag_ = false; + + bool init_flag_ = false; + + bool stop_flag_ = false; + + uint64_t heart_beat_count_ = 0; +}; + +} // namespace secretflow::serving::tools diff --git a/secretflow_serving/tools/inferencer/example/alice/inference.config b/secretflow_serving/tools/inferencer/example/alice/inference.config new file mode 100644 index 0000000..04418ce --- /dev/null +++ b/secretflow_serving/tools/inferencer/example/alice/inference.config @@ -0,0 +1,8 @@ +{ + "requester_id": "alice", + "result_file_path": "./tmp/alice/score.csv", + "additional_col_names": [ + "y" + ], + "score_col_name": "score" +} diff --git a/secretflow_serving/tools/inferencer/example/alice/serving.config b/secretflow_serving/tools/inferencer/example/alice/serving.config new file mode 100644 index 0000000..6f16208 --- /dev/null +++ b/secretflow_serving/tools/inferencer/example/alice/serving.config @@ -0,0 +1,43 @@ +{ + "id": "test_service_id", + "serverConf": { + "host": "0.0.0.0", + "communicationPort": "8110", + }, + "modelConf": { + "modelId": "glm-test-1", + "basePath": "./tmp/alice", + "sourcePath": ".ci/test_data/bin_onehot_glm_alice_no_feature/alice/s_model.tar.gz", + "sourceType": "ST_FILE" + }, + "clusterConf": { + "selfId": "alice", + "parties": [ + { + "id": "alice", + 
"address": "0.0.0.0:8110" + }, + { + "id": "bob", + "address": "0.0.0.0:8111" + } + ], + "channel_desc": { + "protocol": "http", + "retryPolicyConfig": { + "retryCustom": "true", + "retryAggressive": "true", + "maxRetryCount": "3", + "fixedBackoffConfig": { + "intervalMs": "100" + }, + } + } + }, + "featureSourceConf": { + "streamingOpts": { + "file_path": ".ci/test_data/bin_onehot_glm_alice_no_feature/alice/alice.csv", + "id_name": "id", + } + } +} diff --git a/secretflow_serving/tools/inferencer/example/bob/inference.config b/secretflow_serving/tools/inferencer/example/bob/inference.config new file mode 100644 index 0000000..8c26014 --- /dev/null +++ b/secretflow_serving/tools/inferencer/example/bob/inference.config @@ -0,0 +1,4 @@ +{ + "requester_id": "alice", + "score_col_name": "score" +} diff --git a/secretflow_serving/tools/inferencer/example/bob/serving.config b/secretflow_serving/tools/inferencer/example/bob/serving.config new file mode 100644 index 0000000..c365fd4 --- /dev/null +++ b/secretflow_serving/tools/inferencer/example/bob/serving.config @@ -0,0 +1,43 @@ +{ + "id": "test_service_id", + "serverConf": { + "host": "0.0.0.0", + "communicationPort": "8111", + }, + "modelConf": { + "modelId": "glm-test-1", + "basePath": "./tmp/bob", + "sourcePath": ".ci/test_data/bin_onehot_glm_alice_no_feature/bob/s_model.tar.gz", + "sourceType": "ST_FILE" + }, + "clusterConf": { + "selfId": "bob", + "parties": [ + { + "id": "alice", + "address": "0.0.0.0:8110" + }, + { + "id": "bob", + "address": "0.0.0.0:8111" + } + ], + "channel_desc": { + "protocol": "http", + "retryPolicyConfig": { + "retryCustom": "true", + "retryAggressive": "true", + "maxRetryCount": "3", + "fixedBackoffConfig": { + "intervalMs": "100" + }, + } + } + }, + "featureSourceConf": { + "streamingOpts": { + "file_path": ".ci/test_data/bin_onehot_glm_alice_no_feature/bob/bob.csv", + "id_name": "id", + } + } +} diff --git a/secretflow_serving/tools/inferencer/inference_executor.cc 
b/secretflow_serving/tools/inferencer/inference_executor.cc new file mode 100644 index 0000000..71ff188 --- /dev/null +++ b/secretflow_serving/tools/inferencer/inference_executor.cc @@ -0,0 +1,351 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "secretflow_serving/tools/inferencer/inference_executor.h" + +#include + +#include "spdlog/spdlog.h" + +#include "secretflow_serving/tools/inferencer/control_service_impl.h" +#include "secretflow_serving/util/arrow_helper.h" +#include "secretflow_serving/util/csv_util.h" +#include "secretflow_serving/util/network.h" +#include "secretflow_serving/util/utils.h" + +#include "secretflow_serving/tools/inferencer/inference_service.pb.h" + +namespace secretflow::serving::tools { + +namespace { + +const size_t kInitCheckRetryCount = 60; +const size_t kInitCheckRetryIntervalMs = 1000; + +const size_t kSendRetryCount = 30; +const size_t kSendRetryIntervalMs = 1000; + +const size_t kHeartBeatCheckIntervalMs = 3000; +const size_t kHeartBeatCheckFailedNum = 10; + +} // namespace + +namespace { +int32_t GetCsvFileRowNum(const std::string& file_path) { + std::ifstream file(file_path); + SERVING_ENFORCE(file.is_open(), errors::ErrorCode::IO_ERROR, "open {} failed", + file_path); + + std::string line; + // skip header + std::getline(file, line); + + int32_t row_num = 0; + while (std::getline(file, line)) { + row_num++; + } + file.close(); + + return row_num; +} +} // namespace + 
+InferenceExecutor::InferenceExecutor(Options opts) : opts_(std::move(opts)) { + // TODO: check config validity + SERVING_ENFORCE(opts_.serving_conf.has_feature_source_conf(), + errors::INVALID_ARGUMENT); + SERVING_ENFORCE(opts_.serving_conf.feature_source_conf().has_streaming_opts(), + errors::INVALID_ARGUMENT); + + row_num_ = GetCsvFileRowNum( + opts_.serving_conf.feature_source_conf().streaming_opts().file_path()); + + channels_ = BuildChannelsFromConfig(opts_.serving_conf.cluster_conf()); + + // begin services + Server::Options server_opts{ + .service_id = opts_.serving_conf.id(), + .server_config = opts_.serving_conf.server_conf(), + .cluster_config = opts_.serving_conf.cluster_conf(), + .model_config = opts_.serving_conf.model_conf(), + .feature_source_config = opts_.serving_conf.feature_source_conf()}; + server_ = std::make_unique(std::move(server_opts)); + + cntl_svc_ = std::make_unique( + opts_.inference_conf.requester_id(), row_num_); + server_->Start(channels_, cntl_svc_.get()); + + cntl_svc_->ReadyToServe(); + + prediction_core_ = server_->GetPredictionCore(); +} + +InferenceExecutor::~InferenceExecutor() { + StopKeepAlive(); + prediction_core_ = nullptr; +} + +void InferenceExecutor::Run() { + try { + OnRun(); + } catch (...) 
{ + stop_flag_ = true; + throw; + } +} + +void InferenceExecutor::OnRun() { + SPDLOG_INFO("begin batch predict."); + + if (opts_.inference_conf.requester_id() != + opts_.serving_conf.cluster_conf().self_id()) { + SPDLOG_INFO("begin waiting requester init...."); + + RetryRunner runner(kInitCheckRetryCount, kInitCheckRetryIntervalMs); + SERVING_ENFORCE(runner.Run([this]() { + return cntl_svc_->init_flag() || cntl_svc_->stop_flag(); + }), + serving::errors::UNEXPECTED_ERROR, + "waiting init msg from {} timeout", + opts_.inference_conf.requester_id()); + + SPDLOG_INFO("init finish."); + + if (cntl_svc_->stop_flag()) { + SPDLOG_INFO("stop flag is true, just stop."); + return; + } + WaitForEnd(); + return; + } + + // init other party + RetryRunner runner(kSendRetryCount, kSendRetryIntervalMs); + std::vector row_num_list; + for (const auto& [p, c] : *channels_) { + ControlResponse res; + SERVING_ENFORCE(runner.Run( + [this](const std::string& party_id, + ::google::protobuf::RpcChannel* channel, + ControlResponse* response) { + return SendMsg(party_id, channel, + ControlMessageType::CM_INIT, response); + }, + p, c.get(), &res), + serving::errors::UNEXPECTED_ERROR, + "send init msg to {} failed.", p); + row_num_list.emplace_back(res.init_msg().row_num()); + } + SERVING_ENFORCE(std::all_of(row_num_list.begin(), row_num_list.end(), + [this](auto e) { return e == row_num_; }), + errors::UNEXPECTED_ERROR, + "The number of input file lines of different participants " + "does not match. 
{} vs {}", + row_num_, + fmt::join(row_num_list.begin(), row_num_list.end(), ",")); + + // start keepalive + keepalive_thread_ = std::thread(&InferenceExecutor::KeepAlive, this); + + // build read column types + std::unordered_map> col_types{ + {opts_.serving_conf.feature_source_conf().streaming_opts().id_name(), + arrow::utf8()}}; + for (const auto& c : opts_.inference_conf.additional_col_names()) { + col_types.emplace(c, arrow::utf8()); + } + + // build output schema + std::shared_ptr output_schema; + { + std::vector> fields{ + arrow::field(opts_.inference_conf.score_col_name(), arrow::float64())}; + std::transform( + col_types.begin(), col_types.end(), std::back_inserter(fields), + [](const auto& p) { return arrow::field(p.first, arrow::utf8()); }); + output_schema = arrow::schema(std::move(fields)); + } + + // csv reader + arrow::csv::ReadOptions read_opts = arrow::csv::ReadOptions::Defaults(); + if (opts_.inference_conf.block_size() > 0) { + read_opts.block_size = opts_.inference_conf.block_size(); + } + std::shared_ptr csv_reader = + csv::BuildStreamingReader( + opts_.serving_conf.feature_source_conf().streaming_opts().file_path(), + std::move(col_types), read_opts); + + // build result writer + std::shared_ptr csv_writer = + csv::BuildeStreamingWriter(opts_.inference_conf.result_file_path(), + output_schema); + + // begin batch predict + apis::PredictRequest pred_request; + pred_request.mutable_service_spec()->set_id(opts_.serving_conf.id()); + auto* fs_params = pred_request.mutable_fs_params(); + fs_params->insert({opts_.serving_conf.cluster_conf().self_id(), {}}); + for (const auto& [p, c] : *channels_) { + fs_params->insert({p, {}}); + } + int32_t idx = 0; + std::shared_ptr batch; + while (true) { + ++idx; + SERVING_CHECK_ARROW_STATUS(csv_reader->ReadNext(&batch)); + if (!batch) { + // read finish + break; + } + SERVING_ENFORCE_GT( + batch->num_rows(), 0, + "may be because `block_size` is configured too small: {}", + 
opts_.inference_conf.block_size()); + 
+ auto id_array = std::static_pointer_cast( + batch->GetColumnByName(opts_.serving_conf.feature_source_conf() + .streaming_opts() + .id_name())); + + // build fs_params + for (auto& [party, param] : *fs_params) { + param.clear_query_datas(); + param.set_query_context(std::to_string(idx)); + for (int64_t i = 0; i < id_array->length(); ++i) { + auto item = id_array->Value(i); + param.add_query_datas(item.data(), item.length()); + } + } + + // batch predict + apis::PredictResponse pred_response; + prediction_core_->PredictImpl(&pred_request, &pred_response); + + // build score array + std::shared_ptr score_array; + arrow::DoubleBuilder builder; + for (const auto& r : pred_response.results()) { + // TODO: only support one score in result + SERVING_CHECK_ARROW_STATUS(builder.Append(r.scores(0).value())); + } + SERVING_CHECK_ARROW_STATUS(builder.Finish(&score_array)); + + // write result + std::vector> arrays; + arrays.reserve(output_schema->num_fields()); + for (const auto& f : output_schema->fields()) { + if (f->name() == opts_.inference_conf.score_col_name()) { + arrays.emplace_back(std::move(score_array)); + continue; + } + arrays.emplace_back(batch->GetColumnByName(f->name())); + } + auto result_batch = MakeRecordBatch( + output_schema, pred_response.results_size(), std::move(arrays)); + SERVING_CHECK_ARROW_STATUS(csv_writer->WriteRecordBatch(*result_batch)); + } + SERVING_CHECK_ARROW_STATUS(csv_writer->Close()); + + // send end msg + for (const auto& [p, c] : *channels_) { + SERVING_ENFORCE(runner.Run( + [this](const std::string& party_id, + ::google::protobuf::RpcChannel* channel) { + ControlResponse res; + return SendMsg(party_id, channel, + ControlMessageType::CM_STOP, &res); + }, + p, c.get()), + serving::errors::UNEXPECTED_ERROR, + "send stop msg to {} failed.", p); + } + + SPDLOG_INFO("batch predict finish."); +} + +void InferenceExecutor::WaitForEnd() { + SPDLOG_INFO("waiting for end"); + uint64_t last_heart_beat_count = 0; + size_t failed_count = 0; + while 
(failed_count <= kHeartBeatCheckFailedNum && !cntl_svc_->stop_flag()) { + std::this_thread::sleep_for( + std::chrono::milliseconds(kHeartBeatCheckIntervalMs)); + + auto now = cntl_svc_->heart_beat_count(); + if (now == last_heart_beat_count) { + ++failed_count; + } else { + last_heart_beat_count = now; + failed_count = 0; + } + } + if (failed_count > kHeartBeatCheckFailedNum) { + SPDLOG_ERROR("heart beat timeout..."); + } + SPDLOG_INFO("finish wait."); +} + +void InferenceExecutor::KeepAlive() { + while (true) { + if (stop_flag_) { + SPDLOG_INFO("stop send keepalive msg."); + return; + } + for (const auto& [p, c] : *channels_) { + ControlResponse res; + SendMsg(p, c.get(), ControlMessageType::CM_KEEPALIVE, &res); + } + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + } +} + +void InferenceExecutor::StopKeepAlive() { + stop_flag_ = true; + if (keepalive_thread_.joinable()) { + keepalive_thread_.join(); + } +} + +bool InferenceExecutor::SendMsg(const std::string& target_id, + ::google::protobuf::RpcChannel* channel, + ControlMessageType type, + ControlResponse* response) { + brpc::Controller cntl; + // close brpc retry. 
+ cntl.set_max_retry(0); + + ControlRequest request; + request.set_party_id(opts_.serving_conf.cluster_conf().self_id()); + request.set_type(type); + + InferenceControlService_Stub stub(channel); + stub.Push(&cntl, &request, response, nullptr); + if (cntl.Failed()) { + SPDLOG_WARN( + "call ({}) init control failed, msg:{}, may need " + "retry", + target_id, cntl.ErrorText()); + return false; + } else if (!CheckStatusOk(response->status())) { + SPDLOG_WARN( + "call ({}) init control msg failed, msg:{}, may need " + "retry", + target_id, response->status().msg()); + return false; + } + return true; +} + +} // namespace secretflow::serving::tools diff --git a/secretflow_serving/tools/inferencer/inference_executor.h b/secretflow_serving/tools/inferencer/inference_executor.h new file mode 100644 index 0000000..5488ba8 --- /dev/null +++ b/secretflow_serving/tools/inferencer/inference_executor.h @@ -0,0 +1,76 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include +#include + +#include "secretflow_serving/server/server.h" +#include "secretflow_serving/tools/inferencer/control_service_impl.h" + +#include "secretflow_serving/config/serving_config.pb.h" +#include "secretflow_serving/tools/inferencer/config.pb.h" + +namespace secretflow::serving::tools { + +class InferenceExecutor { + public: + struct Options { + ServingConfig serving_conf; + + InferenceConfig inference_conf; + }; + + public: + explicit InferenceExecutor(Options opts); + ~InferenceExecutor(); + + void Run(); + + private: + void OnRun(); + + void WaitForEnd(); + + void KeepAlive(); + + void StopKeepAlive(); + + bool SendMsg(const std::string& target_id, + ::google::protobuf::RpcChannel* channel, ControlMessageType type, + ControlResponse* response); + + private: + const Options opts_; + + std::unique_ptr server_; + + std::unique_ptr cntl_svc_; + + std::shared_ptr< + std::map>> + channels_; + + std::shared_ptr prediction_core_; + + int32_t row_num_; + + std::thread keepalive_thread_; + + std::atomic stop_flag_ = false; +}; + +} // namespace secretflow::serving::tools diff --git a/secretflow_serving/tools/inferencer/inference_service.proto b/secretflow_serving/tools/inferencer/inference_service.proto new file mode 100644 index 0000000..fc3d53e --- /dev/null +++ b/secretflow_serving/tools/inferencer/inference_service.proto @@ -0,0 +1,54 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. +// + +syntax = "proto3"; + +package secretflow.serving.tools; + +import "secretflow_serving/apis/status.proto"; + +option cc_generic_services = true; + +enum ControlMessageType { + INVALID_MESSAGE_TYPE = 0; + + CM_INIT = 1; + + CM_STOP = 2; + + CM_KEEPALIVE = 3; +} + +message InitMsg { + int32 row_num = 1; +} + +message ControlRequest { + string party_id = 1; + + ControlMessageType type = 2; +} + +message ControlResponse { + apis.Status status = 1; + + oneof kind { + InitMsg init_msg = 3; + } +} + +service InferenceControlService { + rpc Push(ControlRequest) returns (ControlResponse); +} diff --git a/secretflow_serving/tools/inferencer/main.cc b/secretflow_serving/tools/inferencer/main.cc new file mode 100644 index 0000000..4185eeb --- /dev/null +++ b/secretflow_serving/tools/inferencer/main.cc @@ -0,0 +1,104 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "absl/debugging/failure_signal_handler.h" +#include "absl/debugging/symbolize.h" +#include "gflags/gflags.h" + +#include "secretflow_serving/core/exception.h" +#include "secretflow_serving/core/logging.h" +#include "secretflow_serving/server/version.h" +#include "secretflow_serving/tools/inferencer/inference_executor.h" +#include "secretflow_serving/util/utils.h" + +#include "secretflow_serving/config/serving_config.pb.h" + +DEFINE_string(serving_config_file, "", + "read an ascii config protobuf from the supplied file name."); +DEFINE_string(inference_config_file, "", ""); + +// logging config +DEFINE_string( + logging_config_file, "", + "read an ascii LoggingConfig protobuf from the supplied file name."); + +DECLARE_bool(inferencer_mode); + +#define STRING_EMPTY_VALIDATOR(str_config) \ + if (str_config.empty()) { \ + SERVING_THROW(secretflow::serving::errors::ErrorCode::INVALID_ARGUMENT, \ + "{} get empty value", #str_config); \ + } + +void InitLogger() { + secretflow::serving::LoggingConfig log_config; + if (!FLAGS_logging_config_file.empty()) { + auto logging_config_str = + secretflow::serving::ReadFileContent(FLAGS_logging_config_file); + if (!secretflow::serving::CheckContentEmpty(logging_config_str)) { + secretflow::serving::JsonToPb( + secretflow::serving::UnescapeJson(logging_config_str), &log_config); + } + } + secretflow::serving::SetupLogging(log_config); +} + +int main(int argc, char* argv[]) { + // Initialize the symbolizer to get a human-readable stack trace + absl::InitializeSymbolizer(argv[0]); + + gflags::SetVersionString(SERVING_VERSION_STRING); + gflags::AllowCommandLineReparsing(); + gflags::ParseCommandLineFlags(&argc, &argv, true); + + FLAGS_inferencer_mode = true; + + try { + InitLogger(); + + SPDLOG_INFO("version: {}", SERVING_VERSION_STRING); + + STRING_EMPTY_VALIDATOR(FLAGS_serving_config_file); + STRING_EMPTY_VALIDATOR(FLAGS_inference_config_file); + + // init server options + secretflow::serving::ServingConfig serving_conf; + 
secretflow::serving::LoadPbFromJsonFile(FLAGS_serving_config_file, + &serving_conf); + + secretflow::serving::tools::InferenceConfig inference_conf; + secretflow::serving::LoadPbFromJsonFile(FLAGS_inference_config_file, + &inference_conf); + + secretflow::serving::tools::InferenceExecutor server( + secretflow::serving::tools::InferenceExecutor::Options{ + .serving_conf = std::move(serving_conf), + .inference_conf = std::move(inference_conf)}); + server.Run(); + } catch (const secretflow::serving::Exception& e) { + std::string msg = + fmt::format("inferencer run failed, code: {}, msg: {}, stack: {}", + e.code(), e.what(), e.stack_trace()); + SPDLOG_ERROR(msg); + std::cerr << msg << std::endl; + return -1; + } catch (const std::exception& e) { + std::string msg = fmt::format("inferencer run failed, msg:{}", e.what()); + SPDLOG_ERROR(msg); + std::cerr << msg << std::endl; + return -1; + } + + return 0; +} diff --git a/secretflow_serving/tools/simple_feature_service/simple_feature_service.h b/secretflow_serving/tools/simple_feature_service/simple_feature_service.h index 007e8e0..6f9dd30 100644 --- a/secretflow_serving/tools/simple_feature_service/simple_feature_service.h +++ b/secretflow_serving/tools/simple_feature_service/simple_feature_service.h @@ -34,7 +34,7 @@ class SimpleBatchFeatureService : public spis::BatchFeatureService { ::google::protobuf::Closure *done) override; private: - CSVExtractor extractor_; + csv::CSVExtractor extractor_; }; } // namespace secretflow::serving diff --git a/secretflow_serving/util/BUILD.bazel b/secretflow_serving/util/BUILD.bazel index 9457715..a0b60bf 100644 --- a/secretflow_serving/util/BUILD.bazel +++ b/secretflow_serving/util/BUILD.bazel @@ -59,15 +59,21 @@ serving_cc_library( ], ) +serving_cc_library( + name = "csv_util", + srcs = ["csv_util.cc"], + hdrs = ["csv_util.h"], + deps = [ + ":arrow_helper", + ], +) + serving_cc_library( name = "csv_extractor", srcs = ["csv_extractor.cc"], hdrs = ["csv_extractor.h"], deps = [ - 
":arrow_helper", - ":utils", - "//secretflow_serving/core:exception", - "//secretflow_serving/protos:data_type_cc_proto", + ":csv_util", "//secretflow_serving/spis:batch_feature_service_cc_proto", ], ) @@ -133,6 +139,7 @@ serving_cc_library( hdrs = ["network.h"], deps = [ ":retry_policy", + "//secretflow_serving/config:cluster_config_cc_proto", "//secretflow_serving/config:tls_config_cc_proto", "//secretflow_serving/core:exception", "@com_github_brpc_brpc//:brpc", @@ -155,3 +162,14 @@ serving_cc_test( ":thread_safe_queue", ], ) + +serving_cc_library( + name = "he_mgm", + srcs = ["he_mgm.cc"], + hdrs = ["he_mgm.h"], + deps = [ + "//secretflow_serving/core:exception", + "//secretflow_serving/core:singleton", + "@com_alipay_sf_heu//heu/library/numpy", + ], +) diff --git a/secretflow_serving/util/arrow_helper.cc b/secretflow_serving/util/arrow_helper.cc index 6770556..ba80aca 100644 --- a/secretflow_serving/util/arrow_helper.cc +++ b/secretflow_serving/util/arrow_helper.cc @@ -19,6 +19,7 @@ #include #include "arrow/compute/api.h" +#include "arrow/io/api.h" #include "arrow/ipc/api.h" #include "secretflow_serving/core/exception.h" @@ -386,35 +387,6 @@ void CheckReferenceFields(const std::shared_ptr& src, } } -std::shared_ptr ReadCsvFileToTable( - const std::string& path, - const std::shared_ptr& feature_schema) { - // read csv file - std::shared_ptr file; - SERVING_GET_ARROW_RESULT(arrow::io::ReadableFile::Open(path), file); - - arrow::csv::ConvertOptions convert_options; - - for (int i = 0; i < feature_schema->num_fields(); ++i) { - std::shared_ptr field = feature_schema->field(i); - - convert_options.include_columns.push_back(field->name()); - convert_options.column_types[field->name()] = field->type(); - } - - std::shared_ptr csv_reader; - SERVING_GET_ARROW_RESULT( - arrow::csv::TableReader::Make(arrow::io::default_io_context(), file, - arrow::csv::ReadOptions::Defaults(), - arrow::csv::ParseOptions::Defaults(), - convert_options), - csv_reader); - - std::shared_ptr 
table; - SERVING_GET_ARROW_RESULT(csv_reader->Read(), table); - return table; -} - arrow::Datum GetRowsFilter( const std::shared_ptr& id_column, const std::vector& ids) { @@ -442,19 +414,6 @@ arrow::Datum GetRowsFilter( return filter; } -std::shared_ptr GetIdColumnFromFile( - const std::string& filename, const std::string& id_name) { - std::vector> fields; - fields.push_back(arrow::field(id_name, arrow::utf8())); - auto schema = arrow::schema(fields); - auto table = ReadCsvFileToTable(filename, schema); - auto id_column = table->GetColumnByName(id_name); - SERVING_ENFORCE(id_column, errors::ErrorCode::INVALID_ARGUMENT, - "column: {} is not in csv file: {}", id_name, filename); - - return id_column; -} - std::shared_ptr ExtractRowsFromTable( const std::shared_ptr& table, const arrow::Datum& filter) { arrow::Datum filtered_table; diff --git a/secretflow_serving/util/arrow_helper.h b/secretflow_serving/util/arrow_helper.h index 02e9c57..4541c14 100644 --- a/secretflow_serving/util/arrow_helper.h +++ b/secretflow_serving/util/arrow_helper.h @@ -17,8 +17,6 @@ #include #include "arrow/api.h" -#include "arrow/csv/api.h" -#include "arrow/io/api.h" #include "google/protobuf/repeated_field.h" #include "secretflow_serving/core/exception.h" @@ -103,10 +101,6 @@ void CheckReferenceFields(const std::shared_ptr& src, const std::shared_ptr& dst, const std::string& additional_msg = ""); -std::shared_ptr ReadCsvFileToTable( - const std::string& path, - const std::shared_ptr& feature_schema); - arrow::Datum GetRowsFilter( const std::shared_ptr& id_column, const std::vector& ids); @@ -114,9 +108,6 @@ arrow::Datum GetRowsFilter( std::shared_ptr ExtractRowsFromTable( const std::shared_ptr& table, const arrow::Datum& filter); -std::shared_ptr GetIdColumnFromFile( - const std::string& filename, const std::string& id_name); - std::shared_ptr CastToDoubleArray( const std::shared_ptr& array); diff --git a/secretflow_serving/util/csv_extractor.cc b/secretflow_serving/util/csv_extractor.cc 
index c19a75a..b4b499a 100644 --- a/secretflow_serving/util/csv_extractor.cc +++ b/secretflow_serving/util/csv_extractor.cc @@ -16,20 +16,21 @@ #include #include -#include #include #include #include +#include #include "arrow/compute/api.h" #include "spdlog/spdlog.h" #include "secretflow_serving/core/exception.h" #include "secretflow_serving/util/arrow_helper.h" +#include "secretflow_serving/util/csv_util.h" #include "secretflow_serving/util/utils.h" -namespace secretflow::serving { +namespace secretflow::serving::csv { namespace { @@ -123,7 +124,7 @@ class ArrayReorderVisitor : public arrow::ArrayVisitor { }; std::shared_ptr ReorderRows( - std::shared_ptr rows, + const std::shared_ptr &rows, const std::vector &order) { const auto &raw_columns = rows->columns(); std::vector> reordered_rows; @@ -140,7 +141,7 @@ std::shared_ptr ReorderRows( std::vector GetIdsFromQueryDatas( const ::google::protobuf::RepeatedPtrField &query_data) { std::vector ids; - for (auto &data : query_data) { + for (const auto &data : query_data) { ids.push_back(data); } return ids; @@ -150,7 +151,7 @@ std::vector GetIdsFromQueryDatas( CSVExtractor::CSVExtractor(const std::shared_ptr &schema, std::string filename, std::string id_column_name) - : CSVExtractor(filename, id_column_name) { + : CSVExtractor(std::move(filename), std::move(id_column_name)) { FetchTable(schema); } @@ -235,4 +236,4 @@ std::shared_ptr CSVExtractor::ExtractRows( return reordered_rows; } -} // namespace secretflow::serving +} // namespace secretflow::serving::csv diff --git a/secretflow_serving/util/csv_extractor.h b/secretflow_serving/util/csv_extractor.h index 7846889..8107fa9 100644 --- a/secretflow_serving/util/csv_extractor.h +++ b/secretflow_serving/util/csv_extractor.h @@ -14,16 +14,12 @@ #pragma once -#include - #include "arrow/api.h" -#include "arrow/csv/api.h" -#include "arrow/io/api.h" #include "google/protobuf/repeated_field.h" #include "secretflow_serving/spis/batch_feature_service.pb.h" -namespace 
secretflow::serving { +namespace secretflow::serving::csv { class CSVExtractor { public: @@ -61,4 +57,4 @@ class CSVExtractor { std::map> schema_tables_cache_; }; -} // namespace secretflow::serving +} // namespace secretflow::serving::csv diff --git a/secretflow_serving/util/csv_extractor_test.cc b/secretflow_serving/util/csv_extractor_test.cc index 1cd39c2..93f9f7b 100644 --- a/secretflow_serving/util/csv_extractor_test.cc +++ b/secretflow_serving/util/csv_extractor_test.cc @@ -16,7 +16,6 @@ #include #include -#include #include #include @@ -31,7 +30,7 @@ #include "secretflow_serving/util/arrow_helper.h" #include "secretflow_serving/util/utils.h" -namespace secretflow::serving { +namespace secretflow::serving::csv { TEST(CSVExtractor, TestReadCsvFile) { butil::TempFile tmpfile; @@ -64,4 +63,4 @@ id,x1,x2,x3,x4 EXPECT_EQ(int_col->Value(3), 1); } -} // namespace secretflow::serving +} // namespace secretflow::serving::csv diff --git a/secretflow_serving/util/csv_util.cc b/secretflow_serving/util/csv_util.cc new file mode 100644 index 0000000..5fc97ee --- /dev/null +++ b/secretflow_serving/util/csv_util.cc @@ -0,0 +1,99 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "secretflow_serving/util/csv_util.h" + +#include "arrow/io/api.h" + +#include "secretflow_serving/util/arrow_helper.h" + +namespace secretflow::serving::csv { + +std::shared_ptr BuildStreamingReader( + const std::string& path, + std::unordered_map> col_types, + const arrow::csv::ReadOptions& read_opts) { + std::shared_ptr in_file; + SERVING_GET_ARROW_RESULT(arrow::io::ReadableFile::Open(path), in_file); + arrow::csv::ConvertOptions convert_options; + convert_options.column_types = std::move(col_types); + std::transform(convert_options.column_types.begin(), + convert_options.column_types.end(), + std::back_inserter(convert_options.include_columns), + [](const auto& p) { return p.first; }); + std::shared_ptr csv_reader; + SERVING_GET_ARROW_RESULT( + arrow::csv::StreamingReader::Make( + arrow::io::default_io_context(), in_file, read_opts, + arrow::csv::ParseOptions::Defaults(), convert_options), + csv_reader); + return csv_reader; +} + +std::shared_ptr BuildeStreamingWriter( + const std::string& path, const std::shared_ptr& schema) { + std::shared_ptr out_stream; + SERVING_GET_ARROW_RESULT(arrow::io::FileOutputStream::Open(path), out_stream); + auto writer_opts = arrow::csv::WriteOptions::Defaults(); + writer_opts.quoting_style = arrow::csv::QuotingStyle::None; + std::shared_ptr csv_writer; + SERVING_GET_ARROW_RESULT( + arrow::csv::MakeCSVWriter(out_stream, schema, writer_opts), csv_writer); + + return csv_writer; +} + +std::shared_ptr ReadCsvFileToTable( + const std::string& path, + const std::shared_ptr& feature_schema) { + // read csv file + std::shared_ptr file; + SERVING_GET_ARROW_RESULT(arrow::io::ReadableFile::Open(path), file); + + arrow::csv::ConvertOptions convert_options; + + for (int i = 0; i < feature_schema->num_fields(); ++i) { + std::shared_ptr field = feature_schema->field(i); + + convert_options.include_columns.push_back(field->name()); + convert_options.column_types[field->name()] = field->type(); + } + + std::shared_ptr csv_reader; + 
SERVING_GET_ARROW_RESULT( + arrow::csv::TableReader::Make(arrow::io::default_io_context(), file, + arrow::csv::ReadOptions::Defaults(), + arrow::csv::ParseOptions::Defaults(), + convert_options), + csv_reader); + + std::shared_ptr table; + SERVING_GET_ARROW_RESULT(csv_reader->Read(), table); + return table; +} + +std::shared_ptr GetIdColumnFromFile( + const std::string& filename, const std::string& id_name) { + std::vector> fields; + fields.push_back(arrow::field(id_name, arrow::utf8())); + auto schema = arrow::schema(fields); + auto table = ReadCsvFileToTable(filename, schema); + auto id_column = table->GetColumnByName(id_name); + SERVING_ENFORCE(id_column, errors::ErrorCode::INVALID_ARGUMENT, + "column: {} is not in csv file: {}", id_name, filename); + + return id_column; +} + +} // namespace secretflow::serving::csv diff --git a/secretflow_serving/util/csv_util.h b/secretflow_serving/util/csv_util.h new file mode 100644 index 0000000..29759a3 --- /dev/null +++ b/secretflow_serving/util/csv_util.h @@ -0,0 +1,38 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "arrow/csv/api.h" +#include "arrow/ipc/writer.h" + +namespace secretflow::serving::csv { + +std::shared_ptr BuildStreamingReader( + const std::string& path, + std::unordered_map> col_types, + const arrow::csv::ReadOptions& read_opts = + arrow::csv::ReadOptions::Defaults()); + +std::shared_ptr BuildeStreamingWriter( + const std::string& path, const std::shared_ptr& schema); + +std::shared_ptr ReadCsvFileToTable( + const std::string& path, + const std::shared_ptr& feature_schema); + +std::shared_ptr GetIdColumnFromFile( + const std::string& filename, const std::string& id_name); + +} // namespace secretflow::serving::csv diff --git a/secretflow_serving/util/he_mgm.cc b/secretflow_serving/util/he_mgm.cc new file mode 100644 index 0000000..9e1e49f --- /dev/null +++ b/secretflow_serving/util/he_mgm.cc @@ -0,0 +1,134 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "secretflow_serving/util/he_mgm.h" + +#include "secretflow_serving/core/exception.h" + +namespace secretflow::serving::he { + +void HeKitMgm::InitLocalKit(yacl::ByteContainerView pk_buffer, + yacl::ByteContainerView sk_buffer, + int64_t encode_scale) { + encode_scale_ = encode_scale; + local_kit_ = std::make_unique(pk_buffer, sk_buffer); + local_matrix_kit_ = std::make_unique(*local_kit_); + default_encoder_ = std::make_shared( + local_kit_->GetSchemaType(), encode_scale_); +} + +void HeKitMgm::InitDstKit(const std::string& party, + yacl::ByteContainerView pk_buffer) { + heu_phe::DestinationHeKit dst_kit(pk_buffer); + heu_matrix::DestinationHeKit dst_matrix_kit(dst_kit); + dst_kit_map_.emplace( + party, + std::make_pair( + std::move(dst_kit), std::move(dst_matrix_kit))); +} + +int64_t HeKitMgm::GetEncodeScale() { + SERVING_ENFORCE(local_kit_, errors::ErrorCode::LOGIC_ERROR); + + return encode_scale_; +} + +heu_phe::SchemaType HeKitMgm::GetSchemaType() { + SERVING_ENFORCE(local_kit_, errors::ErrorCode::LOGIC_ERROR); + + return local_kit_->GetSchemaType(); +} + +std::shared_ptr HeKitMgm::GetEncoder(int64_t scale) { + SERVING_ENFORCE(local_kit_, errors::ErrorCode::LOGIC_ERROR); + + return std::make_shared(local_kit_->GetSchemaType(), + scale); +} + +std::shared_ptr HeKitMgm::GetEncoder() { + SERVING_ENFORCE(default_encoder_, errors::ErrorCode::LOGIC_ERROR); + return default_encoder_; +} + +const std::shared_ptr& HeKitMgm::GetLocalEncryptor() { + SERVING_ENFORCE(local_kit_, errors::ErrorCode::LOGIC_ERROR); + + return local_kit_->GetEncryptor(); +} + +const std::shared_ptr& HeKitMgm::GetLocalEvaluator() { + SERVING_ENFORCE(local_kit_, errors::ErrorCode::LOGIC_ERROR); + + return local_kit_->GetEvaluator(); +} + +const std::shared_ptr& HeKitMgm::GetLocalDecryptor() { + SERVING_ENFORCE(local_kit_, errors::ErrorCode::LOGIC_ERROR); + + return local_kit_->GetDecryptor(); +} + +const std::shared_ptr& HeKitMgm::GetDstEncryptor( + const std::string& party) { + if 
(auto it = dst_kit_map_.find(party); it != dst_kit_map_.end()) { + return it->second.first.GetEncryptor(); + } + SERVING_THROW(errors::ErrorCode::LOGIC_ERROR, + "can not find he kit for party: {}", party); +} + +const std::shared_ptr& HeKitMgm::GetDstEvaluator( + const std::string& party) { + if (auto it = dst_kit_map_.find(party); it != dst_kit_map_.end()) { + return it->second.first.GetEvaluator(); + } + SERVING_THROW(errors::ErrorCode::LOGIC_ERROR, + "can not find he kit for party: {}", party); +} + +const std::shared_ptr& +HeKitMgm::GetLocalMatrixEncryptor() { + return local_matrix_kit_->GetEncryptor(); +} + +const std::shared_ptr& +HeKitMgm::GetLocalMatrixEvaluator() { + return local_matrix_kit_->GetEvaluator(); +} + +const std::shared_ptr& +HeKitMgm::GetLocalMatrixDecryptor() { + return local_matrix_kit_->GetDecryptor(); +} + +const std::shared_ptr& HeKitMgm::GetDstMatrixEncryptor( + const std::string& party) { + if (auto it = dst_kit_map_.find(party); it != dst_kit_map_.end()) { + return it->second.second.GetEncryptor(); + } + SERVING_THROW(errors::ErrorCode::LOGIC_ERROR, + "can not find he kit for party: {}", party); +} + +const std::shared_ptr& HeKitMgm::GetDstMatrixEvaluator( + const std::string& party) { + if (auto it = dst_kit_map_.find(party); it != dst_kit_map_.end()) { + return it->second.second.GetEvaluator(); + } + SERVING_THROW(errors::ErrorCode::LOGIC_ERROR, + "can not find he kit for party: {}", party); +} + +} // namespace secretflow::serving::he diff --git a/secretflow_serving/util/he_mgm.h b/secretflow_serving/util/he_mgm.h new file mode 100644 index 0000000..8317c82 --- /dev/null +++ b/secretflow_serving/util/he_mgm.h @@ -0,0 +1,86 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "heu/library/numpy/numpy.h" +#include "heu/library/phe/encoding/plain_encoder.h" +#include "heu/library/phe/phe.h" + +#include "secretflow_serving/core/singleton.h" + +namespace heu_phe = ::heu::lib::phe; +namespace heu_matrix = ::heu::lib::numpy; + +namespace secretflow::serving::he { + +const int64_t kFeatureScale = 1e6; + +class HeKitMgm : public Singleton { + public: + void InitLocalKit(yacl::ByteContainerView pk_buffer, + yacl::ByteContainerView sk_buffer, int64_t encode_scale); + void InitDstKit(const std::string& party, yacl::ByteContainerView pk_buffer); + + int64_t GetEncodeScale(); + + heu_phe::SchemaType GetSchemaType(); + + std::shared_ptr GetEncoder(int64_t scale); + + std::shared_ptr GetEncoder(); + + // scalar + [[nodiscard]] const std::shared_ptr& GetLocalEncryptor(); + + [[nodiscard]] const std::shared_ptr& GetLocalEvaluator(); + + [[nodiscard]] const std::shared_ptr& GetLocalDecryptor(); + + [[nodiscard]] const std::shared_ptr& GetDstEncryptor( + const std::string& party); + + [[nodiscard]] const std::shared_ptr& GetDstEvaluator( + const std::string& party); + + // matrix + [[nodiscard]] const std::shared_ptr& + GetLocalMatrixEncryptor(); + + [[nodiscard]] const std::shared_ptr& + GetLocalMatrixEvaluator(); + + [[nodiscard]] const std::shared_ptr& + GetLocalMatrixDecryptor(); + + [[nodiscard]] const std::shared_ptr& + GetDstMatrixEncryptor(const std::string& party); + + [[nodiscard]] const std::shared_ptr& + GetDstMatrixEvaluator(const std::string& party); + + private: + int64_t encode_scale_ = 0; + + 
std::unique_ptr local_kit_; + std::unique_ptr local_matrix_kit_; + + std::map> + dst_kit_map_; + + std::shared_ptr default_encoder_; +}; + +} // namespace secretflow::serving::he diff --git a/secretflow_serving/util/network.cc b/secretflow_serving/util/network.cc index 579f5da..8e8b327 100644 --- a/secretflow_serving/util/network.cc +++ b/secretflow_serving/util/network.cc @@ -26,6 +26,9 @@ const std::string kHttpPrefix = "http://"; const std::string kHttpsPrefix = "https://"; const std::string kDefaultLoadBalancer = "rr"; +const int32_t kPeerConnectTimeoutMs = 500; +const int32_t kPeerRpcTimeoutMs = 2000; + std::string FillHttpPrefix(const std::string& addr, bool ssl_enabled) { if (absl::StartsWith(addr, kHttpPrefix) || absl::StartsWith(addr, kHttpsPrefix)) { @@ -104,9 +107,46 @@ std::unique_ptr CreateBrpcChannel( RetryPolicyFactory::GetInstance()->SetConfig(name, retry_policy_config); opts.retry_policy = RetryPolicyFactory::GetInstance()->GetRetryPolicy(name); + opts.max_retry = RetryPolicyFactory::GetInstance()->GetMaxRetryCount(name); return CreateBrpcChannel(endpoint, protocol, enable_lb, rpc_timeout_ms, connect_timeout_ms, tls_config, opts); } +std::shared_ptr< + std::map>> +BuildChannelsFromConfig(const ClusterConfig& cluster_config, + bool enable_peers_load_balancer) { + SERVING_ENFORCE(cluster_config.parties_size() > 1, + errors::ErrorCode::INVALID_ARGUMENT, + "too few parties params for cluster config, get: {}", + cluster_config.parties_size()); + + auto channels = std::make_shared< + std::map>>(); + for (const auto& party : cluster_config.parties()) { + if (party.id() == cluster_config.self_id()) { + continue; + } + const auto& channel_desc = cluster_config.channel_desc(); + channels->emplace( + party.id(), + CreateBrpcChannel( + party.id(), party.address(), channel_desc.protocol(), + enable_peers_load_balancer, + channel_desc.rpc_timeout_ms() > 0 ? channel_desc.rpc_timeout_ms() + : kPeerRpcTimeoutMs, + channel_desc.connect_timeout_ms() > 0 + ? 
channel_desc.connect_timeout_ms() + : kPeerConnectTimeoutMs, + channel_desc.has_tls_config() ? &channel_desc.tls_config() + : nullptr, + channel_desc.has_retry_policy_config() + ? &channel_desc.retry_policy_config() + : nullptr)); + } + + return channels; +} + } // namespace secretflow::serving diff --git a/secretflow_serving/util/network.h b/secretflow_serving/util/network.h index 09ca24f..8f3b789 100644 --- a/secretflow_serving/util/network.h +++ b/secretflow_serving/util/network.h @@ -21,6 +21,7 @@ #include "secretflow_serving/core/exception.h" +#include "secretflow_serving/config/cluster_config.pb.h" #include "secretflow_serving/config/retry_policy_config.pb.h" #include "secretflow_serving/config/tls_config.pb.h" @@ -37,4 +38,8 @@ std::unique_ptr CreateBrpcChannel( int32_t rpc_timeout_ms, int32_t connect_timeout_ms, const TlsConfig* tls_config); +std::shared_ptr< + std::map>> +BuildChannelsFromConfig(const ClusterConfig& cluster_config, + bool enable_peers_load_balancer = false); } // namespace secretflow::serving diff --git a/secretflow_serving/util/retry_policy.cc b/secretflow_serving/util/retry_policy.cc index 6590198..ade6839 100644 --- a/secretflow_serving/util/retry_policy.cc +++ b/secretflow_serving/util/retry_policy.cc @@ -29,7 +29,7 @@ const std::unordered_set KCustomRetryBrpcCode = {}; const std::unordered_set KCustomRetryHttpCode = {500, 502, 503, 504, 408, 429}; -constexpr int32_t kMaxRetryCount = 3; +constexpr int32_t kDefaultMaxRetryCount = 3; const char* KDefaultPolicyName = "__default__"; } // namespace @@ -122,7 +122,8 @@ RetryPolicyFactory::RetryPolicyFactory() { std::shared_ptr policy = std::make_shared>( false, false, static_cast(nullptr)); - retry_policies_.emplace(KDefaultPolicyName, Policy{policy, kMaxRetryCount}); + retry_policies_.emplace(KDefaultPolicyName, + Policy{policy, kDefaultMaxRetryCount}); } void RetryPolicyFactory::SetConfig(const std::string& name, @@ -137,11 +138,10 @@ void RetryPolicyFactory::SetConfig(const std::string& 
name, auto custom_retry = false; auto retry_aggressive = false; - int32_t max_retry_count = kMaxRetryCount; + int32_t max_retry_count = kDefaultMaxRetryCount; - if (config) { + if (config != nullptr) { custom_retry = config->retry_custom(); - retry_aggressive = config->retry_aggressive(); if (config->max_retry_count() != 0) { @@ -150,7 +150,6 @@ void RetryPolicyFactory::SetConfig(const std::string& name, switch (config->backoff_mode()) { case RetryPolicyBackOffMode::EXPONENTIAL_BACKOFF: - policy = MakeRetryPolicy(custom_retry, retry_aggressive, config->has_exponential_backoff_config(), config->exponential_backoff_config()); @@ -169,9 +168,7 @@ void RetryPolicyFactory::SetConfig(const std::string& name, break; } } - SPDLOG_INFO("Regist retry policy: name={}", name); - retry_policies_[name] = Policy{policy, max_retry_count}; } diff --git a/secretflow_serving/util/retry_policy_test.cc b/secretflow_serving/util/retry_policy_test.cc index 3ce4f6c..a834e21 100644 --- a/secretflow_serving/util/retry_policy_test.cc +++ b/secretflow_serving/util/retry_policy_test.cc @@ -153,7 +153,6 @@ void StartServerAddRequest(int http_code, int error_cnt, int retried_count, brpc::Controller cntl; EXPECT_EQ(3, policy.max_retry_count); - cntl.set_max_retry(policy.max_retry_count); cntl.http_request().uri() = server_addr + "/HttpService"; // 设置为待访问的URL channel.CallMethod(NULL, &cntl, NULL, NULL, NULL /*done*/); diff --git a/secretflow_serving/util/utils.cc b/secretflow_serving/util/utils.cc index 3959d5d..efead9e 100644 --- a/secretflow_serving/util/utils.cc +++ b/secretflow_serving/util/utils.cc @@ -112,4 +112,26 @@ size_t CountSampleNum( return predefined_row_num; } +std::string UnescapeJson(const std::string& json) { + std::string result; + bool escape = false; + for (char ch : json) { + if (escape) { + result += ch; + escape = false; + } else if (ch == '\\') { + escape = true; + } else { + result += ch; + } + } + return result; +} + +bool CheckContentEmpty(const std::string& str) { + 
return str.empty() || + std::all_of(str.begin(), str.end(), + [](unsigned char c) { return std::isspace(c); }); +} + } // namespace secretflow::serving diff --git a/secretflow_serving/util/utils.h b/secretflow_serving/util/utils.h index 68eae47..af36eda 100644 --- a/secretflow_serving/util/utils.h +++ b/secretflow_serving/util/utils.h @@ -14,6 +14,8 @@ #pragma once +#include + #include "secretflow_serving/core/exception.h" #include "secretflow_serving/apis/error_code.pb.h" @@ -77,4 +79,35 @@ void FeatureVisit(Func&& visitor, const Feature& f) { size_t CountSampleNum( const ::google::protobuf::RepeatedPtrField& features); +std::string UnescapeJson(const std::string& json); + +bool CheckContentEmpty(const std::string& str); + +class RetryRunner { + public: + RetryRunner(uint32_t retry_counts, uint32_t retry_interval_ms) + : retry_counts_(retry_counts), retry_interval_ms_(retry_interval_ms) {} + + template >>> + bool Run(Func&& f, Args&&... args) const { + auto runner_func = [&] { + return std::invoke(std::forward(f), std::forward(args)...); + }; + for (uint32_t i = 0; i != retry_counts_; ++i) { + if (!runner_func()) { + std::this_thread::sleep_for( + std::chrono::milliseconds(retry_interval_ms_)); + } else { + return true; + } + } + return false; + } + + private: + uint32_t retry_counts_; + uint32_t retry_interval_ms_; +}; } // namespace secretflow::serving