From ef7c18b45c3b1b3bb66e973c0ba9333fecab6171 Mon Sep 17 00:00:00 2001 From: Joe Date: Tue, 4 Jun 2024 11:02:59 +0800 Subject: [PATCH] repo-sync-2024-05-30T19:36:59+0800 (#67) --- build_wheel_entrypoint.sh | 2 +- .../topics/deployment/serving_on_kuscia.po | 8 +- .../LC_MESSAGES/topics/graph/operator_list.po | 135 +++++++++------- .../topics/deployment/serving_on_kuscia.rst | 108 ++++++------- docs/source/topics/graph/operator_list.md | 10 +- secretflow_serving/core/link_func.cc | 34 ++++ secretflow_serving/core/link_func.h | 6 + secretflow_serving/core/singleton.h | 11 +- secretflow_serving/core/types.h | 17 +- .../feature_adapter/feature_adapter_factory.h | 2 +- .../feature_adapter/http_adapter.cc | 2 +- .../feature_adapter/http_adapter.h | 2 +- .../framework/execute_context.cc | 146 +++++++++++++----- .../framework/execute_context.h | 104 +++---------- .../framework/execute_context_test.cc | 8 +- .../framework/model_info_collector.cc | 2 +- .../framework/model_info_collector.h | 4 +- secretflow_serving/framework/predictor.cc | 59 +++---- secretflow_serving/framework/predictor.h | 11 +- .../framework/predictor_test.cc | 14 +- secretflow_serving/ops/arrow_processing.cc | 5 +- secretflow_serving/ops/graph.cc | 2 +- secretflow_serving/ops/merge_y.cc | 14 +- secretflow_serving/ops/merge_y.h | 2 + secretflow_serving/ops/merge_y_test.cc | 37 +++-- .../ops/tree_ensemble_predict.cc | 10 +- .../ops/tree_ensemble_predict.h | 2 + .../ops/tree_ensemble_predict_test.cc | 4 + secretflow_serving/protos/link_function.proto | 3 + secretflow_serving/source/http_source.h | 2 +- secretflow_serving/util/arrow_helper.cc | 11 +- secretflow_serving/util/arrow_helper.h | 11 +- secretflow_serving/util/network.cc | 6 +- secretflow_serving/util/network.h | 4 +- secretflow_serving/util/thread_pool.h | 2 +- 35 files changed, 446 insertions(+), 354 deletions(-) diff --git a/build_wheel_entrypoint.sh b/build_wheel_entrypoint.sh index 4a8d8cc..e9f3433 100644 --- a/build_wheel_entrypoint.sh +++ b/build_wheel_entrypoint.sh @@ -17,4 +17,4 @@ set -e rm -rf dist python setup.py bdist_wheel -python3 -m pip install dist/*.whl --force-reinstall +python -m pip install dist/*.whl --force-reinstall diff --git a/docs/locales/zh_CN/LC_MESSAGES/topics/deployment/serving_on_kuscia.po b/docs/locales/zh_CN/LC_MESSAGES/topics/deployment/serving_on_kuscia.po index bf09796..874ee98 100644 --- a/docs/locales/zh_CN/LC_MESSAGES/topics/deployment/serving_on_kuscia.po +++ b/docs/locales/zh_CN/LC_MESSAGES/topics/deployment/serving_on_kuscia.po @@ -8,7 +8,7 @@ msgid "" msgstr "" "Project-Id-Version: SecretFlow-Serving \n" "Report-Msgid-Bugs-To: \n" -"POT-Creation-Date: 2024-03-13 19:29+0800\n" +"POT-Creation-Date: 2024-05-22 16:27+0800\n" "PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" "Last-Translator: FULL NAME \n" "Language-Team: LANGUAGE \n" @@ -314,11 +314,9 @@ msgstr "" #: ../../source/topics/deployment/serving_on_kuscia.rst:176 msgid "" "Model data source type, options include: ST_FILE: In this case, the " -"sourcePath should be a file path accessible to Serving. ST_OSS: In this " -"case, the sourcePath should be the path to the model package in OSS." +"sourcePath should be a file path accessible to Serving." msgstr "" -"模型数据源类型,可选内容: ST_FILE: 此时`sourcePath`应为文件系统路径。 ST_OSS: " -"此时`sourcePath`应为OSS存储中数据的路径。" +"模型数据源类型,可选内容: ST_FILE: 此时`sourcePath`应为文件系统路径。" #: ../../source/topics/deployment/serving_on_kuscia.rst:178 msgid "PartyConfig.featureSourceConfig" diff --git a/docs/locales/zh_CN/LC_MESSAGES/topics/graph/operator_list.po b/docs/locales/zh_CN/LC_MESSAGES/topics/graph/operator_list.po index 643e082..572373e 100644 --- a/docs/locales/zh_CN/LC_MESSAGES/topics/graph/operator_list.po +++ b/docs/locales/zh_CN/LC_MESSAGES/topics/graph/operator_list.po @@ -8,21 +8,21 @@ msgid "" msgstr "" "Project-Id-Version: SecretFlow-Serving \n" "Report-Msgid-Bugs-To: \n" -"POT-Creation-Date: 2024-03-01 17:29+0800\n" +"POT-Creation-Date: 2024-05-29 20:16+0800\n" "PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" "Last-Translator: FULL NAME \n" "Language-Team: LANGUAGE \n" "MIME-Version: 1.0\n" "Content-Type: text/plain; charset=utf-8\n" "Content-Transfer-Encoding: 8bit\n" -"Generated-By: Babel 2.14.0\n" +"Generated-By: Babel 2.15.0\n" #: ../../source/topics/graph/operator_list.md:5 msgid "SecretFlow-Serving Operator List" msgstr "SecretFlow-Serving 算子列表" #: ../../source/topics/graph/operator_list.md:9 -msgid "Last update: Fri Mar 1 17:28:55 2024" +msgid "Last update: Wed May 29 20:14:58 2024" msgstr "" #: ../../source/topics/graph/operator_list.md:10 @@ -30,8 +30,7 @@ msgid "MERGE_Y" msgstr "" #: ../../source/topics/graph/operator_list.md:13 -#: ../../source/topics/graph/operator_list.md:51 -msgid "Operator version: 0.0.2" +msgid "Operator version: 0.0.3" msgstr "" #: ../../source/topics/graph/operator_list.md:15 @@ -39,11 +38,11 @@ msgid "Merge all partial y(score) and apply link function" msgstr "" #: ../../source/topics/graph/operator_list.md:16 -#: ../../source/topics/graph/operator_list.md:54 -#: ../../source/topics/graph/operator_list.md:85 -#: ../../source/topics/graph/operator_list.md:122 -#: ../../source/topics/graph/operator_list.md:158 -#: ../../source/topics/graph/operator_list.md:194 +#: ../../source/topics/graph/operator_list.md:55 +#: ../../source/topics/graph/operator_list.md:86 +#: ../../source/topics/graph/operator_list.md:123 +#: ../../source/topics/graph/operator_list.md:159 +#: ../../source/topics/graph/operator_list.md:195 msgid "Attrs" msgstr "" @@ -67,6 +66,28 @@ msgstr "" msgid "Notes" msgstr "" +#: ../../source/topics/graph/operator_list.md +msgid "exp_iters" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "" +"Number of iterations of `exp` approximation, valid when `link_function` " +"set `LF_EXP_TAYLOR`" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "Integer32" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "N" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "Default: 0." +msgstr "" + #: ../../source/topics/graph/operator_list.md msgid "output_col_name" msgstr "" @@ -91,11 +112,11 @@ msgstr "" msgid "" "Type of link function, defined in " "`secretflow_serving/protos/link_function.proto`. Optional value: LF_EXP, " -"LF_RECIPROCAL, LF_IDENTITY, LF_SIGMOID_RAW, LF_SIGMOID_MM1, " -"LF_SIGMOID_MM3, LF_SIGMOID_GA, LF_SIGMOID_T1, LF_SIGMOID_T3, " -"LF_SIGMOID_T5, LF_SIGMOID_T7, LF_SIGMOID_T9, LF_SIGMOID_LS7, " -"LF_SIGMOID_SEG3, LF_SIGMOID_SEG5, LF_SIGMOID_DF, LF_SIGMOID_SR, " -"LF_SIGMOID_SEGLS" +"LF_EXP_TAYLOR, LF_RECIPROCAL, LF_IDENTITY, LF_SIGMOID_RAW, " +"LF_SIGMOID_MM1, LF_SIGMOID_MM3, LF_SIGMOID_GA, LF_SIGMOID_T1, " +"LF_SIGMOID_T3, LF_SIGMOID_T5, LF_SIGMOID_T7, LF_SIGMOID_T9, " +"LF_SIGMOID_LS7, LF_SIGMOID_SEG3, LF_SIGMOID_SEG5, LF_SIGMOID_DF, " +"LF_SIGMOID_SR, LF_SIGMOID_SEGLS" msgstr "" #: ../../source/topics/graph/operator_list.md @@ -121,18 +142,14 @@ msgstr "" msgid "Double" msgstr "" -#: ../../source/topics/graph/operator_list.md -msgid "N" -msgstr "" - #: ../../source/topics/graph/operator_list.md msgid "Default: 1.0." msgstr "" -#: ../../source/topics/graph/operator_list.md:26 -#: ../../source/topics/graph/operator_list.md:95 -#: ../../source/topics/graph/operator_list.md:167 -#: ../../source/topics/graph/operator_list.md:204 +#: ../../source/topics/graph/operator_list.md:27 +#: ../../source/topics/graph/operator_list.md:96 +#: ../../source/topics/graph/operator_list.md:168 +#: ../../source/topics/graph/operator_list.md:206 msgid "Tags" msgstr "" @@ -154,12 +171,12 @@ msgid "" "and will somehow merge them." msgstr "" -#: ../../source/topics/graph/operator_list.md:34 -#: ../../source/topics/graph/operator_list.md:65 -#: ../../source/topics/graph/operator_list.md:102 -#: ../../source/topics/graph/operator_list.md:138 -#: ../../source/topics/graph/operator_list.md:174 -#: ../../source/topics/graph/operator_list.md:211 +#: ../../source/topics/graph/operator_list.md:35 +#: ../../source/topics/graph/operator_list.md:66 +#: ../../source/topics/graph/operator_list.md:103 +#: ../../source/topics/graph/operator_list.md:139 +#: ../../source/topics/graph/operator_list.md:175 +#: ../../source/topics/graph/operator_list.md:213 msgid "Inputs" msgstr "" @@ -171,12 +188,12 @@ msgstr "" msgid "The list of partial y, data type: `double`" msgstr "" -#: ../../source/topics/graph/operator_list.md:41 -#: ../../source/topics/graph/operator_list.md:72 -#: ../../source/topics/graph/operator_list.md:109 -#: ../../source/topics/graph/operator_list.md:145 -#: ../../source/topics/graph/operator_list.md:181 -#: ../../source/topics/graph/operator_list.md:218 +#: ../../source/topics/graph/operator_list.md:42 +#: ../../source/topics/graph/operator_list.md:73 +#: ../../source/topics/graph/operator_list.md:110 +#: ../../source/topics/graph/operator_list.md:146 +#: ../../source/topics/graph/operator_list.md:182 +#: ../../source/topics/graph/operator_list.md:220 msgid "Output" msgstr "" @@ -188,11 +205,16 @@ msgstr "" msgid "The merge result of `partial_ys`, data type: `double`" msgstr "" -#: ../../source/topics/graph/operator_list.md:48 +#: ../../source/topics/graph/operator_list.md:49 msgid "DOT_PRODUCT" msgstr "" -#: ../../source/topics/graph/operator_list.md:53 +#: ../../source/topics/graph/operator_list.md:52 +#: ../../source/topics/graph/operator_list.md:192 +msgid "Operator version: 0.0.2" +msgstr "" + +#: ../../source/topics/graph/operator_list.md:54 msgid "Calculate the dot product of feature weights and values" msgstr "" @@ -259,18 +281,17 @@ msgstr "" msgid "The calculation results, they have a data type of `double`." msgstr "" -#: ../../source/topics/graph/operator_list.md:79 +#: ../../source/topics/graph/operator_list.md:80 msgid "ARROW_PROCESSING" msgstr "" -#: ../../source/topics/graph/operator_list.md:82 -#: ../../source/topics/graph/operator_list.md:119 -#: ../../source/topics/graph/operator_list.md:155 -#: ../../source/topics/graph/operator_list.md:191 +#: ../../source/topics/graph/operator_list.md:83 +#: ../../source/topics/graph/operator_list.md:120 +#: ../../source/topics/graph/operator_list.md:156 msgid "Operator version: 0.0.1" msgstr "" -#: ../../source/topics/graph/operator_list.md:84 +#: ../../source/topics/graph/operator_list.md:85 msgid "Replay secretflow compute functions" msgstr "" @@ -326,11 +347,11 @@ msgstr "" msgid "output" msgstr "" -#: ../../source/topics/graph/operator_list.md:116 +#: ../../source/topics/graph/operator_list.md:117 msgid "TREE_SELECT" msgstr "" -#: ../../source/topics/graph/operator_list.md:121 +#: ../../source/topics/graph/operator_list.md:122 msgid "" "Obtaining the local prediction path information of the decision tree " "using input features." @@ -394,14 +415,6 @@ msgstr "" msgid "The id of the root tree node" msgstr "" -#: ../../source/topics/graph/operator_list.md -msgid "Integer32" -msgstr "" - -#: ../../source/topics/graph/operator_list.md -msgid "Default: 0." -msgstr "" - #: ../../source/topics/graph/operator_list.md msgid "Column name of tree select" msgstr "" @@ -437,11 +450,11 @@ msgstr "" msgid "The local prediction path information of the decision tree." msgstr "" -#: ../../source/topics/graph/operator_list.md:152 +#: ../../source/topics/graph/operator_list.md:153 msgid "TREE_MERGE" msgstr "" -#: ../../source/topics/graph/operator_list.md:157 +#: ../../source/topics/graph/operator_list.md:158 msgid "" "Merge the `TREE_SELECT` output from multiple parties to obtain a unique " "prediction path and return the result weights." @@ -485,17 +498,25 @@ msgstr "" msgid "The prediction result of tree." msgstr "" -#: ../../source/topics/graph/operator_list.md:188 +#: ../../source/topics/graph/operator_list.md:189 msgid "TREE_ENSEMBLE_PREDICT" msgstr "" -#: ../../source/topics/graph/operator_list.md:193 +#: ../../source/topics/graph/operator_list.md:194 msgid "" "Accept the weighted results from multiple trees (`TREE_SELECT` + " "`TREE_MERGE`), merge them, and obtain the final prediction result of the " "tree ensemble." msgstr "" +#: ../../source/topics/graph/operator_list.md +msgid "base_score" +msgstr "" + +#: ../../source/topics/graph/operator_list.md +msgid "The initial prediction score, global bias." +msgstr "" + #: ../../source/topics/graph/operator_list.md msgid "num_trees" msgstr "" diff --git a/docs/source/topics/deployment/serving_on_kuscia.rst b/docs/source/topics/deployment/serving_on_kuscia.rst index e9d52a9..086e674 100644 --- a/docs/source/topics/deployment/serving_on_kuscia.rst +++ b/docs/source/topics/deployment/serving_on_kuscia.rst @@ -57,16 +57,16 @@ To deploy SecretFlow-Serving in Kusica, you first need to register the template readinessProbe: httpGet: path: /health - port: 53511 + port: brpc-builtin livenessProbe: httpGet: path: /health - port: 53511 + port: brpc-builtin startupProbe: failureThreshold: 30 httpGet: path: /health - port: 53511 + port: brpc-builtin periodSeconds: 10 successThreshold: 1 timeoutSeconds: 1 @@ -154,54 +154,54 @@ The launch and management of SecretFlow-Serving can be performed using the `Kusc **Field description**: -+-----------------------------------------------------------+-----------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ -| Name | Type | Description | Required | -+===========================================================+=======================+==================================================================================================================================================================================================================+========================================================================+ -| partyConfigs | map | Dictionary of startup parameters for each participant. Key: Participant Unique ID; Value: PartyConfig (Json Object). | Yes | -+-----------------------------------------------------------+-----------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ -| PartyConfig.serverConfig | str | :ref:`ServerConfig ` | Yes | -+-----------------------------------------------------------+-----------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ -| PartyConfig.serverConfig.featureMapping | map | Feature name mapping rules. Key: source or predefined feature name; Value: model feature name | No | -+-----------------------------------------------------------+-----------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ -| PartyConfig.modelConfig | Object | :ref:`ModelConfig ` | Yes | -+-----------------------------------------------------------+-----------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ -| PartyConfig.modelConfig.modelId | str | Unique id of the model package | Yes | -+-----------------------------------------------------------+-----------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ -| PartyConfig.modelConfig.basePath | str | The local path used to cache and load model package | Yes | -+-----------------------------------------------------------+-----------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ -| PartyConfig.modelConfig.sourcePath | str | The path to the model package in the data source, where the content format may vary depending on the `sourceType`. | Yes | -+-----------------------------------------------------------+-----------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ -| PartyConfig.modelConfig.source_sha256 | str | The expected SHA256 hash of the model package. When provided, the fetched model package will be verified against it. | No | -+-----------------------------------------------------------+-----------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ -| PartyConfig.modelConfig.sourceType | str | Model data source type, options include: ST_FILE: In this case, the sourcePath should be a file path accessible to Serving. ST_OSS: In this case, the sourcePath should be the path to the model package in OSS. | Yes | -+-----------------------------------------------------------+-----------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ -| PartyConfig.featureSourceConfig | Object | :ref:`FeatureSourceConfig ` | Yes | -+-----------------------------------------------------------+-----------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ -| PartyConfig.featureSourceConfig.mockOpts | Object | :ref:`MockOptions ` | No(One of `csvOpts`, `mockOpts`, or `httpOpts` needs to be configured) | -+-----------------------------------------------------------+-----------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ -| PartyConfig.featureSourceConfig.mockOpts.type | str | The method for generating mock feature values, options: "MDT_RANDOM" for random values, and "MDT_FIXED" for fixed values. Default: "MDT_FIXED". | No | -+-----------------------------------------------------------+-----------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ -| PartyConfig.featureSourceConfig.httpOpts | Object | :ref:`HttpOptions ` | No(One of `csvOpts`, `mockOpts`, or `httpOpts` needs to be configured) | -+-----------------------------------------------------------+-----------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ -| PartyConfig.featureSourceConfig.httpOpts.endpoint | str | Feature service address | Yes | -+-----------------------------------------------------------+-----------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ -| PartyConfig.featureSourceConfig.httpOpts.enableLb | bool | Whether to enable round robin load balancer, Default: False | No | -+-----------------------------------------------------------+-----------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ -| PartyConfig.featureSourceConfig.httpOpts.connectTimeoutMs | int32 | Max duration for a connect. -1 means wait indefinitely. Default: 500 (ms) | No | -+-----------------------------------------------------------+-----------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ -| PartyConfig.featureSourceConfig.httpOpts.timeoutMs | int32 | Max duration of http request. -1 means wait indefinitely. Default: 1000 (ms) | No | -+-----------------------------------------------------------+-----------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ -| PartyConfig.featureSourceConfig.csvOpts | Object | :ref:`CsvOptions ` | No(One of `csvOpts`, `mockOpts`, or `httpOpts` needs to be configured) | -+-----------------------------------------------------------+-----------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ -| PartyConfig.featureSourceConfig.csvOpts.file_path | Object | Input file path, specifies where to load data. Note that this will load all of the data into memory at once | Yes | -+-----------------------------------------------------------+-----------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ -| PartyConfig.featureSourceConfig.csvOpts.id_name | Object | Id column name, associated with `FeatureParam::query_datas`. `query_datas` is a subset of id column | Yes | -+-----------------------------------------------------------+-----------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ -| PartyConfig.channelDesc | Object | :ref:`ChannelDesc ` | Yes | -+-----------------------------------------------------------+-----------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ -| PartyConfig.channelDesc.protocol | str | Communication protocol, for optional value, see `here `_ | Yes | -+-----------------------------------------------------------+-----------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ -| PartyConfig.channelDesc.rpcTimeoutMs | int32 | Max duration of RPC. -1 means wait indefinitely. Default: 2000 (ms) | No | -+-----------------------------------------------------------+-----------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ -| PartyConfig.channelDesc.connectTimeoutMs | int32 | Max duration for a connect. -1 means wait indefinitely. Default: 500 (ms) | No | -+-----------------------------------------------------------+-----------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ ++-----------------------------------------------------------+-----------------------+-------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ +| Name | Type | Description | Required | ++===========================================================+=======================+=================================================================================================================================================+========================================================================+ +| partyConfigs | map | Dictionary of startup parameters for each participant. Key: Participant Unique ID; Value: PartyConfig (Json Object). | Yes | ++-----------------------------------------------------------+-----------------------+-------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ +| PartyConfig.serverConfig | str | :ref:`ServerConfig ` | Yes | ++-----------------------------------------------------------+-----------------------+-------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ +| PartyConfig.serverConfig.featureMapping | map | Feature name mapping rules. Key: source or predefined feature name; Value: model feature name | No | ++-----------------------------------------------------------+-----------------------+-------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ +| PartyConfig.modelConfig | Object | :ref:`ModelConfig ` | Yes | ++-----------------------------------------------------------+-----------------------+-------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ +| PartyConfig.modelConfig.modelId | str | Unique id of the model package | Yes | ++-----------------------------------------------------------+-----------------------+-------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ +| PartyConfig.modelConfig.basePath | str | The local path used to cache and load model package | Yes | ++-----------------------------------------------------------+-----------------------+-------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ +| PartyConfig.modelConfig.sourcePath | str | The path to the model package in the data source, where the content format may vary depending on the `sourceType`. | Yes | ++-----------------------------------------------------------+-----------------------+-------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ +| PartyConfig.modelConfig.source_sha256 | str | The expected SHA256 hash of the model package. When provided, the fetched model package will be verified against it. | No | ++-----------------------------------------------------------+-----------------------+-------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ +| PartyConfig.modelConfig.sourceType | str | Model data source type, options include: ST_FILE: In this case, the sourcePath should be a file path accessible to Serving. | Yes | ++-----------------------------------------------------------+-----------------------+-------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ +| PartyConfig.featureSourceConfig | Object | :ref:`FeatureSourceConfig ` | Yes | ++-----------------------------------------------------------+-----------------------+-------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ +| PartyConfig.featureSourceConfig.mockOpts | Object | :ref:`MockOptions ` | No(One of `csvOpts`, `mockOpts`, or `httpOpts` needs to be configured) | ++-----------------------------------------------------------+-----------------------+-------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ +| PartyConfig.featureSourceConfig.mockOpts.type | str | The method for generating mock feature values, options: "MDT_RANDOM" for random values, and "MDT_FIXED" for fixed values. Default: "MDT_FIXED". | No | ++-----------------------------------------------------------+-----------------------+-------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ +| PartyConfig.featureSourceConfig.httpOpts | Object | :ref:`HttpOptions ` | No(One of `csvOpts`, `mockOpts`, or `httpOpts` needs to be configured) | ++-----------------------------------------------------------+-----------------------+-------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ +| PartyConfig.featureSourceConfig.httpOpts.endpoint | str | Feature service address | Yes | ++-----------------------------------------------------------+-----------------------+-------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ +| PartyConfig.featureSourceConfig.httpOpts.enableLb | bool | Whether to enable round robin load balancer, Default: False | No | ++-----------------------------------------------------------+-----------------------+-------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ +| PartyConfig.featureSourceConfig.httpOpts.connectTimeoutMs | int32 | Max duration for a connect. -1 means wait indefinitely. Default: 500 (ms) | No | ++-----------------------------------------------------------+-----------------------+-------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ +| PartyConfig.featureSourceConfig.httpOpts.timeoutMs | int32 | Max duration of http request. -1 means wait indefinitely. Default: 1000 (ms) | No | ++-----------------------------------------------------------+-----------------------+-------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ +| PartyConfig.featureSourceConfig.csvOpts | Object | :ref:`CsvOptions ` | No(One of `csvOpts`, `mockOpts`, or `httpOpts` needs to be configured) | ++-----------------------------------------------------------+-----------------------+-------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ +| PartyConfig.featureSourceConfig.csvOpts.file_path | Object | Input file path, specifies where to load data. Note that this will load all of the data into memory at once | Yes | ++-----------------------------------------------------------+-----------------------+-------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ +| PartyConfig.featureSourceConfig.csvOpts.id_name | Object | Id column name, associated with `FeatureParam::query_datas`. `query_datas` is a subset of id column | Yes | ++-----------------------------------------------------------+-----------------------+-------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ +| PartyConfig.channelDesc | Object | :ref:`ChannelDesc ` | Yes | ++-----------------------------------------------------------+-----------------------+-------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ +| PartyConfig.channelDesc.protocol | str | Communication protocol, for optional value, see `here `_ | Yes | ++-----------------------------------------------------------+-----------------------+-------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ +| PartyConfig.channelDesc.rpcTimeoutMs | int32 | Max duration of RPC. -1 means wait indefinitely. Default: 2000 (ms) | No | ++-----------------------------------------------------------+-----------------------+-------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ +| PartyConfig.channelDesc.connectTimeoutMs | int32 | Max duration for a connect. -1 means wait indefinitely. Default: 500 (ms) | No | ++-----------------------------------------------------------+-----------------------+-------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------+ diff --git a/docs/source/topics/graph/operator_list.md b/docs/source/topics/graph/operator_list.md index 75184b5..95c91a9 100644 --- a/docs/source/topics/graph/operator_list.md +++ b/docs/source/topics/graph/operator_list.md @@ -6,11 +6,11 @@ SecretFlow-Serving Operator List ================================ -Last update: Fri Mar 1 17:28:55 2024 +Last update: Wed May 29 20:14:58 2024 ## MERGE_Y -Operator version: 0.0.2 +Operator version: 0.0.3 Merge all partial y(score) and apply link function ### Attrs @@ -18,8 +18,9 @@ Merge all partial y(score) and apply link function |Name|Description|Type|Required|Notes| | :--- | :--- | :--- | :--- | :--- | +|exp_iters|Number of iterations of `exp` approximation, valid when `link_function` set `LF_EXP_TAYLOR`|Integer32|N|Default: 0.| |output_col_name|The column name of merged score|String|Y|| -|link_function|Type of link function, defined in `secretflow_serving/protos/link_function.proto`. Optional value: LF_EXP, LF_RECIPROCAL, LF_IDENTITY, LF_SIGMOID_RAW, LF_SIGMOID_MM1, LF_SIGMOID_MM3, LF_SIGMOID_GA, LF_SIGMOID_T1, LF_SIGMOID_T3, LF_SIGMOID_T5, LF_SIGMOID_T7, LF_SIGMOID_T9, LF_SIGMOID_LS7, LF_SIGMOID_SEG3, LF_SIGMOID_SEG5, LF_SIGMOID_DF, LF_SIGMOID_SR, LF_SIGMOID_SEGLS|String|Y|| +|link_function|Type of link function, defined in `secretflow_serving/protos/link_function.proto`. Optional value: LF_EXP, LF_EXP_TAYLOR, LF_RECIPROCAL, LF_IDENTITY, LF_SIGMOID_RAW, LF_SIGMOID_MM1, LF_SIGMOID_MM3, LF_SIGMOID_GA, LF_SIGMOID_T1, LF_SIGMOID_T3, LF_SIGMOID_T5, LF_SIGMOID_T7, LF_SIGMOID_T9, LF_SIGMOID_LS7, LF_SIGMOID_SEG3, LF_SIGMOID_SEG5, LF_SIGMOID_DF, LF_SIGMOID_SR, LF_SIGMOID_SEGLS|String|Y|| |input_col_name|The column name of partial_y|String|Y|| |yhat_scale|In order to prevent value overflow, GLM training is performed on the scaled y label. So in the prediction process, you need to enlarge yhat back to get the real predicted value, `yhat = yhat_scale * link(X * W)`|Double|N|Default: 1.0.| @@ -188,7 +189,7 @@ Merge the `TREE_SELECT` output from multiple parties to obtain a unique predicti ## TREE_ENSEMBLE_PREDICT -Operator version: 0.0.1 +Operator version: 0.0.2 Accept the weighted results from multiple trees (`TREE_SELECT` + `TREE_MERGE`), merge them, and obtain the final prediction result of the tree ensemble. ### Attrs @@ -196,6 +197,7 @@ Accept the weighted results from multiple trees (`TREE_SELECT` + `TREE_MERGE`), |Name|Description|Type|Required|Notes| | :--- | :--- | :--- | :--- | :--- | +|base_score|The initial prediction score, global bias.|Double|N|Default: 0.0.| |num_trees|The number of ensemble's tree|Integer32|Y|| |output_col_name|The column name of tree ensemble predict score|String|Y|| |algo_func|Optional value: LF_SIGMOID_RAW, LF_SIGMOID_MM1, LF_SIGMOID_MM3, LF_SIGMOID_GA, LF_SIGMOID_T1, LF_SIGMOID_T3, LF_SIGMOID_T5, LF_SIGMOID_T7, LF_SIGMOID_T9, LF_SIGMOID_LS7, LF_SIGMOID_SEG3, LF_SIGMOID_SEG5, LF_SIGMOID_DF, LF_SIGMOID_SR, LF_SIGMOID_SEGLS|String|N|Default: LF_IDENTITY.| diff --git a/secretflow_serving/core/link_func.cc b/secretflow_serving/core/link_func.cc index aa80e15..6ac2ee1 100644 --- a/secretflow_serving/core/link_func.cc +++ b/secretflow_serving/core/link_func.cc @@ -54,6 +54,18 @@ T SRSig(T x) { return 0.5F * (x / std::sqrt(1.0F + std::pow(x, 2))) + 0.5F; } +// see https://lvdmaaten.github.io/publications/papers/crypten.pdf +// exp(x) = (1 + x / n) ^ n, when n is infinite large. +template +T ExpTaylor(T x, int32_t n) { + SERVING_ENFORCE_GT(n, 0); + return std::pow(1.0F + x / n, n); +} + +void ExpItersValidator(int32_t n) { + SERVING_ENFORCE_GT(n, 0, "exp_iters should be greater than 0"); +} + } // namespace LinkFunctionType ParseLinkFuncType(const std::string& type) { @@ -64,6 +76,16 @@ LinkFunctionType ParseLinkFuncType(const std::string& type) { return lf_type; } +template +void CheckLinkFuncAragsValid(LinkFunctionType lf_type, ARGS&&... args) { + if (lf_type == LF_EXP_TAYLOR) { + ExpItersValidator(std::forward(args)...); + } +} + +template void CheckLinkFuncAragsValid(LinkFunctionType, int32_t&); +template void CheckLinkFuncAragsValid(LinkFunctionType, int32_t&&); + template T ApplyLinkFunc(T x, LinkFunctionType lf_type) { auto ls7 = [](T x) -> T { @@ -164,7 +186,19 @@ T ApplyLinkFunc(T x, LinkFunctionType lf_type) { } } +template +T ApplyLinkFunc(T x, LinkFunctionType lf_type, ARGS&&... args) { + if (lf_type == LinkFunctionType::LF_EXP_TAYLOR) { + return ExpTaylor(x, std::forward(args)...); + } + return ApplyLinkFunc(x, lf_type); +} + template float ApplyLinkFunc(float, LinkFunctionType); +template float ApplyLinkFunc(float, LinkFunctionType, int32_t&&); +template float ApplyLinkFunc(float, LinkFunctionType, int32_t&); template double ApplyLinkFunc(double, LinkFunctionType); +template double ApplyLinkFunc(double, LinkFunctionType, int32_t&&); +template double ApplyLinkFunc(double, LinkFunctionType, int32_t&); } // namespace secretflow::serving diff --git a/secretflow_serving/core/link_func.h b/secretflow_serving/core/link_func.h index 976146b..2baec4d 100644 --- a/secretflow_serving/core/link_func.h +++ b/secretflow_serving/core/link_func.h @@ -25,4 +25,10 @@ LinkFunctionType ParseLinkFuncType(const std::string& type); template T ApplyLinkFunc(T x, LinkFunctionType type); +template +T ApplyLinkFunc(T x, LinkFunctionType lf_type, ARGS&&... args); + +template +void CheckLinkFuncAragsValid(LinkFunctionType lf_type, ARGS&&... args); + } // namespace secretflow::serving diff --git a/secretflow_serving/core/singleton.h b/secretflow_serving/core/singleton.h index c3c0050..25c14ac 100644 --- a/secretflow_serving/core/singleton.h +++ b/secretflow_serving/core/singleton.h @@ -19,18 +19,19 @@ namespace secretflow::serving { template class Singleton { public: - static T *GetInstance() { + static T* GetInstance() { static T t; return &t; } + Singleton(const Singleton&) = delete; + Singleton& operator=(const Singleton&) = delete; + Singleton(Singleton&&) = delete; + Singleton& operator=(Singleton&&) = delete; + protected: Singleton() {} virtual ~Singleton() {} - Singleton(const Singleton &) = delete; - Singleton &operator=(const Singleton &) = delete; - Singleton(Singleton &&) = delete; - Singleton &operator=(Singleton &&) = delete; }; } // namespace secretflow::serving diff --git a/secretflow_serving/core/types.h b/secretflow_serving/core/types.h index a15f7db..ce7fa52 100644 --- a/secretflow_serving/core/types.h +++ b/secretflow_serving/core/types.h @@ -20,18 +20,13 @@ namespace secretflow::serving { template struct TypeTrait { - typedef T Scalar; - typedef Eigen::Matrix Matrix; - typedef Eigen::Matrix ColVec; - typedef Eigen::Matrix RowVec; + using Scalar = T; + using Matrix = Eigen::Matrix; + using ColVec = Eigen::Matrix; + using RowVec = Eigen::Matrix; }; -typedef TypeTrait Float; -typedef TypeTrait Double; - -enum class AlgorithmType { - kRegression, - kClassification, -}; +using Float = TypeTrait; +using Double = TypeTrait; } // namespace secretflow::serving diff --git a/secretflow_serving/feature_adapter/feature_adapter_factory.h b/secretflow_serving/feature_adapter/feature_adapter_factory.h index cc306b8..d83d4ed 100644 --- a/secretflow_serving/feature_adapter/feature_adapter_factory.h +++ b/secretflow_serving/feature_adapter/feature_adapter_factory.h @@ -62,7 +62,7 @@ class FeatureAdapterFactory : public Singleton { template class Register { public: - Register(FeatureSourceConfig::OptionsCase opts_case) { + explicit Register(FeatureSourceConfig::OptionsCase opts_case) { FeatureAdapterFactory::GetInstance()->Register(opts_case); } }; diff --git a/secretflow_serving/feature_adapter/http_adapter.cc b/secretflow_serving/feature_adapter/http_adapter.cc index a5c379d..c077026 100644 --- a/secretflow_serving/feature_adapter/http_adapter.cc +++ b/secretflow_serving/feature_adapter/http_adapter.cc @@ -147,7 +147,7 @@ void HttpFeatureAdapter::OnFetchFeature(const Request& request, SetSpanAttrs(span, span_option); SERVING_ENFORCE(span_option.code == errors::ErrorCode::OK, span_option.code, - span_option.msg); + "{}", span_option.msg); response->header->mutable_data()->swap( *spi_response.mutable_header()->mutable_data()); response->features = diff --git a/secretflow_serving/feature_adapter/http_adapter.h b/secretflow_serving/feature_adapter/http_adapter.h index 6505dc7..e63a7bc 100644 --- a/secretflow_serving/feature_adapter/http_adapter.h +++ b/secretflow_serving/feature_adapter/http_adapter.h @@ -38,7 +38,7 @@ class HttpFeatureAdapter : public FeatureAdapter { const Request& request); protected: - std::shared_ptr channel_; + std::unique_ptr channel_; std::vector feature_fields_; }; diff --git a/secretflow_serving/framework/execute_context.cc b/secretflow_serving/framework/execute_context.cc index e221b9d..ff68f3d 100644 --- a/secretflow_serving/framework/execute_context.cc +++ b/secretflow_serving/framework/execute_context.cc @@ -19,65 +19,60 @@ namespace secretflow::serving { -void ExecuteContext::CheckAndUpdateResponse() { - CheckAndUpdateResponse(exec_res_); -} - -void ExecuteContext::CheckAndUpdateResponse( - const apis::ExecuteResponse& exec_res) { - if (!CheckStatusOk(exec_res.status())) { - SERVING_THROW( - exec_res.status().code(), - fmt::format("{} exec failed: code({}), {}", target_id_, - exec_res.status().code(), exec_res.status().msg())); - } - MergeResonseHeader(exec_res); -} - -void ExecuteContext::MergeResonseHeader() { MergeResonseHeader(exec_res_); } - -void ExecuteContext::MergeResonseHeader(const apis::ExecuteResponse& exec_res) { - response_->mutable_header()->mutable_data()->insert( - exec_res.header().data().begin(), exec_res.header().data().end()); -} - void ExeResponseToIoMap( apis::ExecuteResponse& exec_res, - std::unordered_map>* - node_io_map) { - auto result = exec_res.mutable_result(); + std::unordered_map* node_io_map) { + auto* result = exec_res.mutable_result(); for (int i = 0; i < result->nodes_size(); ++i) { - auto result_node_io = result->mutable_nodes(i); + auto* result_node_io = result->mutable_nodes(i); auto prev_insert_iter = node_io_map->find(result_node_io->name()); if (prev_insert_iter != node_io_map->end()) { // found node, merge ios auto& target_node_io = prev_insert_iter->second; - SERVING_ENFORCE(target_node_io->ios_size() == result_node_io->ios_size(), + SERVING_ENFORCE(target_node_io.ios_size() == result_node_io->ios_size(), errors::ErrorCode::LOGIC_ERROR); - for (int io_index = 0; io_index < target_node_io->ios_size(); - ++io_index) { - auto target_io = target_node_io->mutable_ios(io_index); - auto io = result_node_io->mutable_ios(io_index); + for (int io_index = 0; io_index < target_node_io.ios_size(); ++io_index) { + auto* target_io = target_node_io.mutable_ios(io_index); + auto* io = result_node_io->mutable_ios(io_index); for (int data_index = 0; data_index < io->datas_size(); ++data_index) { target_io->add_datas(std::move(*(io->mutable_datas(data_index)))); } } } else { auto node_name = result_node_io->name(); - node_io_map->emplace(node_name, std::make_shared( - std::move(*result_node_io))); + node_io_map->emplace(node_name, std::move(*result_node_io)); } } } +void ExecuteContext::CheckAndUpdateResponse() { + CheckAndUpdateResponse(exec_res_); +} + +void ExecuteContext::CheckAndUpdateResponse( + const apis::ExecuteResponse& exec_res) { + if (!CheckStatusOk(exec_res.status())) { + SERVING_THROW(exec_res.status().code(), "{} exec failed: code({}), {}", + target_id_, exec_res.status().code(), + exec_res.status().msg()); + } + MergeResonseHeader(exec_res); +} + +void ExecuteContext::MergeResonseHeader() { MergeResonseHeader(exec_res_); } + +void ExecuteContext::MergeResonseHeader(const apis::ExecuteResponse& exec_res) { + response_->mutable_header()->mutable_data()->insert( + exec_res.header().data().begin(), exec_res.header().data().end()); +} + void ExecuteContext::GetResultNodeIo( - std::unordered_map>* - node_io_map) { + std::unordered_map* node_io_map) { ExeResponseToIoMap(exec_res_, node_io_map); } void ExecuteContext::SetFeatureSource() { - auto feature_source = exec_req_.mutable_feature_source(); + auto* feature_source = exec_req_.mutable_feature_source(); if (execution_->IsEntry()) { // entry execution need features // get target_id's feature param @@ -115,17 +110,36 @@ ExecuteContext::ExecuteContext(const apis::PredictRequest* request, SetFeatureSource(); } -void ExecuteContext::Execute( - std::shared_ptr<::google::protobuf::RpcChannel> channel, - brpc::Controller* cntl) { - apis::ExecutionService_Stub stub(channel.get()); +void ExecuteContext::Execute(::google::protobuf::RpcChannel* channel, + brpc::Controller* cntl) { + apis::ExecutionService_Stub stub(channel); stub.Execute(cntl, &exec_req_, &exec_res_, brpc::DoNothing()); } -void ExecuteContext::Execute(std::shared_ptr execution_core) { +void ExecuteContext::Execute(std::shared_ptr& execution_core) { execution_core->Execute(&exec_req_, &exec_res_); } +RemoteExecute::RemoteExecute(const apis::PredictRequest* request, + apis::PredictResponse* response, + const std::shared_ptr& execution, + std::string target_id, std::string local_id, + ::google::protobuf::RpcChannel* channel) + : ExecuteBase{request, response, execution, std::move(target_id), + std::move(local_id)}, + channel_(channel) { + span_option.cntl = &cntl_; + span_option.is_client = true; + span_option.party_id = local_id; + span_option.service_id = exec_ctx_.ServiceId(); +} + +RemoteExecute::~RemoteExecute() { + if (executing_) { + Cancel(); + } +} + void RemoteExecute::Run() { if (executing_) { SPDLOG_ERROR("Run should only be called once."); @@ -144,4 +158,54 @@ void RemoteExecute::Run() { executing_ = true; } +void RemoteExecute::Cancel() { + if (!executing_) { + return; + } + + executing_ = false; + brpc::StartCancel(cntl_.call_id()); + + span_option.code = errors::ErrorCode::UNEXPECTED_ERROR; + span_option.msg = "remote execute task is canceled."; + SetSpanAttrs(span_, span_option); + span_->End(); +} + +void RemoteExecute::WaitToFinish() { + if (!executing_) { + return; + } + + span_option.code = errors::ErrorCode::OK; + span_option.msg = fmt::format("call ({}) from ({}) execute seccessfully", + exec_ctx_.TargetId(), exec_ctx_.LocalId()); + + brpc::Join(cntl_.call_id()); + + executing_ = false; + + if (cntl_.Failed()) { + span_option.msg = fmt::format("call ({}) from ({}) network error, msg:{}", + exec_ctx_.TargetId(), exec_ctx_.LocalId(), + cntl_.ErrorText()); + span_option.code = errors::ErrorCode::NETWORK_ERROR; + } else if (exec_ctx_.ResponseStatus().code() != errors::ErrorCode::OK) { + span_option.msg = fmt::format(fmt::format( + "call ({}) from ({}) execute failed: code({}), {}", + exec_ctx_.TargetId(), exec_ctx_.LocalId(), + exec_ctx_.ResponseStatus().code(), exec_ctx_.ResponseStatus().msg())); + span_option.code = errors::ErrorCode::NETWORK_ERROR; + } + + SetSpanAttrs(span_, span_option); + span_->End(); + + if (span_option.code == errors::ErrorCode::OK) { + exec_ctx_.MergeResonseHeader(); + } else { + SERVING_THROW(span_option.code, "{}", span_option.msg); + } +} + } // namespace secretflow::serving diff --git a/secretflow_serving/framework/execute_context.h b/secretflow_serving/framework/execute_context.h index 8150670..b0ef0ee 100644 --- a/secretflow_serving/framework/execute_context.h +++ b/secretflow_serving/framework/execute_context.h @@ -27,8 +27,7 @@ namespace secretflow::serving { void ExeResponseToIoMap( apis::ExecuteResponse& exec_res, - std::unordered_map>* - node_io_map); + std::unordered_map* node_io_map); class ExecuteContext { public: @@ -40,17 +39,16 @@ class ExecuteContext { template < typename T, typename = std::enable_if_t, - std::unordered_map>>>> + std::decay_t, std::unordered_map>>> void SetEntryNodesInputs(T&& node_io_map) { if (node_io_map.empty()) { return; } - auto task = exec_req_.mutable_task(); + auto* task = exec_req_.mutable_task(); task->set_execution_id(execution_->id()); auto entry_nodes = execution_->GetEntryNodes(); for (const auto& n : entry_nodes) { - auto entry_node_io = task->add_nodes(); + auto* entry_node_io = task->add_nodes(); entry_node_io->set_name(n->GetName()); for (const auto& e : n->in_edges()) { auto iter = node_io_map.find(e->src_node()); @@ -58,7 +56,7 @@ class ExecuteContext { errors::ErrorCode::LOGIC_ERROR, "Input of {} cannot be found in ctx(size:{})", e->src_node(), node_io_map.size()); - for (auto& io : *(iter->second->mutable_ios())) { + for (auto& io : *(iter->second.mutable_ios())) { if constexpr (std::is_lvalue_reference_v) { *(entry_node_io->mutable_ios()->Add()) = io; } else { @@ -69,13 +67,11 @@ class ExecuteContext { } } - void Execute(std::shared_ptr<::google::protobuf::RpcChannel> channel, - brpc::Controller* cntl); - void Execute(std::shared_ptr execution_core); + void Execute(::google::protobuf::RpcChannel* channel, brpc::Controller* cntl); + void Execute(std::shared_ptr& execution_core); void GetResultNodeIo( - std::unordered_map>* - node_io_map); + std::unordered_map* node_io_map); void CheckAndUpdateResponse(const apis::ExecuteResponse& exec_res); void CheckAndUpdateResponse(); @@ -118,18 +114,14 @@ class ExecuteBase { std::move(local_id)} {} virtual ~ExecuteBase() = default; - void SetInputs(std::unordered_map>& - node_io_map) { + void SetInputs(std::unordered_map& node_io_map) { exec_ctx_.SetEntryNodesInputs(node_io_map); } - void SetInputs( - std::unordered_map>&& - node_io_map) { + void SetInputs(std::unordered_map&& node_io_map) { exec_ctx_.SetEntryNodesInputs(std::move(node_io_map)); } virtual void GetOutputs( - std::unordered_map>* - node_io_map) { + std::unordered_map* node_io_map) { exec_ctx_.GetResultNodeIo(node_io_map); } @@ -139,84 +131,28 @@ class ExecuteBase { ExecuteContext exec_ctx_; }; -class RemoteExecute : public ExecuteBase, - public std::enable_shared_from_this { +class RemoteExecute : public ExecuteBase { public: RemoteExecute(const apis::PredictRequest* request, apis::PredictResponse* response, const std::shared_ptr& execution, std::string target_id, std::string local_id, - std::shared_ptr<::google::protobuf::RpcChannel> channel) - : ExecuteBase{request, response, execution, std::move(target_id), - std::move(local_id)}, - channel_(std::move(channel)) { - span_option.cntl = &cntl_; - span_option.is_client = true; - span_option.party_id = local_id; - span_option.service_id = exec_ctx_.ServiceId(); - } + ::google::protobuf::RpcChannel* channel); - virtual ~RemoteExecute() { - if (executing_) { - Cancel(); - } - } + virtual ~RemoteExecute(); - virtual void Run() override; - virtual void Cancel() { - if (!executing_) { - return; - } + void Run() override; - executing_ = false; - brpc::StartCancel(cntl_.call_id()); + virtual void Cancel(); - span_option.code = errors::ErrorCode::UNEXPECTED_ERROR; - span_option.msg = "remote execute task is canceled."; - SetSpanAttrs(span_, span_option); - span_->End(); - } - - virtual void WaitToFinish() { - if (!executing_) { - return; - } - - span_option.code = errors::ErrorCode::OK; - span_option.msg = fmt::format("call ({}) from ({}) execute seccessfully", - exec_ctx_.TargetId(), exec_ctx_.LocalId()); - - brpc::Join(cntl_.call_id()); - - executing_ = false; - - if (cntl_.Failed()) { - span_option.msg = fmt::format("call ({}) from ({}) network error, msg:{}", - exec_ctx_.TargetId(), exec_ctx_.LocalId(), - cntl_.ErrorText()); - span_option.code = errors::ErrorCode::NETWORK_ERROR; - } else if (exec_ctx_.ResponseStatus().code() != errors::ErrorCode::OK) { - span_option.msg = fmt::format(fmt::format( - "call ({}) from ({}) execute failed: code({}), {}", - exec_ctx_.TargetId(), exec_ctx_.LocalId(), - exec_ctx_.ResponseStatus().code(), exec_ctx_.ResponseStatus().msg())); - span_option.code = errors::ErrorCode::NETWORK_ERROR; - } - - SetSpanAttrs(span_, span_option); - span_->End(); - - if (span_option.code == errors::ErrorCode::OK) { - exec_ctx_.MergeResonseHeader(); - } else { - SERVING_THROW(span_option.code, span_option.msg); - } - } + virtual void WaitToFinish(); protected: - std::shared_ptr<::google::protobuf::RpcChannel> channel_; + ::google::protobuf::RpcChannel* channel_; brpc::Controller cntl_; + bool executing_{false}; + opentelemetry::nostd::shared_ptr span_; SpanAttrOption span_option; }; diff --git a/secretflow_serving/framework/execute_context_test.cc b/secretflow_serving/framework/execute_context_test.cc index 2d578df..f52f0c6 100644 --- a/secretflow_serving/framework/execute_context_test.cc +++ b/secretflow_serving/framework/execute_context_test.cc @@ -242,20 +242,20 @@ TEST_F(ExecuteContextTest, BuildExecCtx) { { auto& exec_response = ctx_bob->ExeResponse(); exec_response.mutable_result()->set_execution_id(0); - auto node = exec_response.mutable_result()->add_nodes(); + auto* node = exec_response.mutable_result()->add_nodes(); node->set_name("mock_node_1"); - auto io = node->add_ios(); + auto* io = node->add_ios(); io->add_datas("mock_bob_data"); } { auto& exec_response = ctx_alice->ExeResponse(); exec_response.mutable_result()->set_execution_id(0); - auto node_1 = exec_response.mutable_result()->add_nodes(); + auto* node_1 = exec_response.mutable_result()->add_nodes(); node_1->set_name("mock_node_1"); node_1->add_ios()->add_datas("mock_alice_data"); } - std::unordered_map> node_io_map; + std::unordered_map node_io_map; ctx_bob->GetResultNodeIo(&node_io_map); ctx_alice->GetResultNodeIo(&node_io_map); diff --git a/secretflow_serving/framework/model_info_collector.cc b/secretflow_serving/framework/model_info_collector.cc index cf69067..5c60984 100644 --- a/secretflow_serving/framework/model_info_collector.cc +++ b/secretflow_serving/framework/model_info_collector.cc @@ -119,7 +119,7 @@ void ModelInfoCollector::DoCollect() { bool ModelInfoCollector::TryCollect( const std::string& remote_party_id, - const std::shared_ptr<::google::protobuf::RpcChannel>& channel) { + const std::unique_ptr<::google::protobuf::RpcChannel>& channel) { brpc::Controller cntl; apis::GetModelInfoResponse response; apis::GetModelInfoRequest request; diff --git a/secretflow_serving/framework/model_info_collector.h b/secretflow_serving/framework/model_info_collector.h index 70f8af3..6ef49e5 100644 --- a/secretflow_serving/framework/model_info_collector.h +++ b/secretflow_serving/framework/model_info_collector.h @@ -33,7 +33,7 @@ class ModelInfoCollector { std::shared_ptr model_bundle; std::shared_ptr< - std::map>> + std::map>> remote_channel_map; }; @@ -58,7 +58,7 @@ class ModelInfoCollector { private: bool TryCollect( const std::string& remote_party_id, - const std::shared_ptr<::google::protobuf::RpcChannel>& channel); + const std::unique_ptr<::google::protobuf::RpcChannel>& channel); void CheckAndSetSpecificMap(); diff --git a/secretflow_serving/framework/predictor.cc b/secretflow_serving/framework/predictor.cc index 1e96730..9269b71 100644 --- a/secretflow_serving/framework/predictor.cc +++ b/secretflow_serving/framework/predictor.cc @@ -28,52 +28,52 @@ Predictor::Predictor(Options opts) : opts_(std::move(opts)) {} void Predictor::Predict(const apis::PredictRequest* request, apis::PredictResponse* response) { - std::unordered_map> - prev_node_io_map; - std::vector> async_running_execs; + std::unordered_map prev_node_io_map; + std::vector> async_running_execs; async_running_execs.reserve(opts_.channels->size()); auto execute_locally = [&](const std::shared_ptr& execution, - std::unordered_map>& - prev_io_map, - std::unordered_map>& - cur_io_map) { + std::unordered_map&& prev_io_map, + std::unordered_map* cur_io_map) { // exec locally auto local_exec = BuildLocalExecute(request, response, execution); local_exec->SetInputs(std::move(prev_io_map)); local_exec->Run(); - local_exec->GetOutputs(&cur_io_map); + local_exec->GetOutputs(cur_io_map); }; for (const auto& e : opts_.executions) { async_running_execs.clear(); - std::unordered_map> - new_node_io_map; + std::unordered_map new_node_io_map; if (e->GetDispatchType() == DispatchType::DP_ALL) { + // remote exec for (const auto& [party_id, channel] : *opts_.channels) { auto ctx = BuildRemoteExecute(request, response, e, party_id, channel); - ctx->SetInputs(prev_node_io_map); + ctx->SetInputs(std::move(prev_node_io_map)); ctx->Run(); - async_running_execs.emplace_back(ctx); + + async_running_execs.emplace_back(std::move(ctx)); } // exec locally if (execution_core_) { - execute_locally(e, prev_node_io_map, new_node_io_map); - for (auto& exec : async_running_execs) { - exec->WaitToFinish(); - exec->GetOutputs(&new_node_io_map); - } + execute_locally(e, std::move(prev_node_io_map), &new_node_io_map); } else { // TODO: support no execution core scene SERVING_THROW(errors::ErrorCode::NOT_IMPLEMENTED, "not implemented"); } + // join async exec + for (const auto& exec : async_running_execs) { + exec->WaitToFinish(); + exec->GetOutputs(&new_node_io_map); + } + } else if (e->GetDispatchType() == DispatchType::DP_ANYONE) { // exec locally if (execution_core_) { - execute_locally(e, prev_node_io_map, new_node_io_map); + execute_locally(e, std::move(prev_node_io_map), &new_node_io_map); } else { // TODO: support no execution core scene SERVING_THROW(errors::ErrorCode::NOT_IMPLEMENTED, "not implemented"); @@ -81,7 +81,7 @@ void Predictor::Predict(const apis::PredictRequest* request, } else if (e->GetDispatchType() == DispatchType::DP_SPECIFIED) { if (e->SpecificToThis()) { SERVING_ENFORCE(execution_core_, errors::ErrorCode::UNEXPECTED_ERROR); - execute_locally(e, prev_node_io_map, new_node_io_map); + execute_locally(e, std::move(prev_node_io_map), &new_node_io_map); } else { auto iter = opts_.specific_party_map.find(e->id()); SERVING_ENFORCE(iter != opts_.specific_party_map.end(), @@ -89,7 +89,7 @@ void Predictor::Predict(const apis::PredictRequest* request, "{} execution assign to no party", e->id()); auto ctx = BuildRemoteExecute(request, response, e, iter->second, opts_.channels->at(iter->second)); - ctx->SetInputs(prev_node_io_map); + ctx->SetInputs(std::move(prev_node_io_map)); ctx->Run(); ctx->WaitToFinish(); ctx->GetOutputs(&new_node_io_map); @@ -105,29 +105,30 @@ void Predictor::Predict(const apis::PredictRequest* request, DealFinalResult(prev_node_io_map, response); } -std::shared_ptr Predictor::BuildRemoteExecute( +std::unique_ptr Predictor::BuildRemoteExecute( const apis::PredictRequest* request, apis::PredictResponse* response, const std::shared_ptr& execution, std::string target_id, - std::shared_ptr<::google::protobuf::RpcChannel> channel) { - return std::make_shared(request, response, execution, - target_id, opts_.party_id, channel); + const std::unique_ptr<::google::protobuf::RpcChannel>& channel) { + return std::make_unique(request, response, execution, + std::move(target_id), opts_.party_id, + channel.get()); } -std::shared_ptr Predictor::BuildLocalExecute( +std::unique_ptr Predictor::BuildLocalExecute( const apis::PredictRequest* request, apis::PredictResponse* response, const std::shared_ptr& execution) { - return std::make_shared(request, response, execution, + return std::make_unique(request, response, execution, opts_.party_id, opts_.party_id, execution_core_); } void Predictor::DealFinalResult( - std::unordered_map>& node_io_map, + std::unordered_map& node_io_map, apis::PredictResponse* response) { SERVING_ENFORCE(node_io_map.size() == 1, errors::ErrorCode::LOGIC_ERROR); auto& node_io = node_io_map.begin()->second; - SERVING_ENFORCE(node_io->ios_size() == 1, errors::ErrorCode::LOGIC_ERROR); - auto& ios = node_io->ios(0); + SERVING_ENFORCE(node_io.ios_size() == 1, errors::ErrorCode::LOGIC_ERROR); + const auto& ios = node_io.ios(0); SERVING_ENFORCE(ios.datas_size() == 1, errors::ErrorCode::LOGIC_ERROR); std::shared_ptr record_batch = DeserializeRecordBatch(ios.datas(0)); diff --git a/secretflow_serving/framework/predictor.h b/secretflow_serving/framework/predictor.h index 64f7581..4f6f1d0 100644 --- a/secretflow_serving/framework/predictor.h +++ b/secretflow_serving/framework/predictor.h @@ -31,7 +31,7 @@ namespace secretflow::serving { // key: node_id // value: channel to the executor using PartyChannelMap = - std::map>; + std::map>; class Predictor { public: @@ -57,20 +57,19 @@ class Predictor { } protected: - virtual std::shared_ptr BuildRemoteExecute( + virtual std::unique_ptr BuildRemoteExecute( const apis::PredictRequest* request, apis::PredictResponse* response, const std::shared_ptr& execution, std::string target_id, - std::shared_ptr<::google::protobuf::RpcChannel> channel); + const std::unique_ptr<::google::protobuf::RpcChannel>& channel); - virtual std::shared_ptr BuildLocalExecute( + virtual std::unique_ptr BuildLocalExecute( const apis::PredictRequest* request, apis::PredictResponse* response, const std::shared_ptr& execution); void BuildExecCtx(); void DealFinalResult( - std::unordered_map>& - node_io_map, + std::unordered_map& node_io_map, apis::PredictResponse* response); protected: diff --git a/secretflow_serving/framework/predictor_test.cc b/secretflow_serving/framework/predictor_test.cc index 682add1..588d49c 100644 --- a/secretflow_serving/framework/predictor_test.cc +++ b/secretflow_serving/framework/predictor_test.cc @@ -94,8 +94,7 @@ class MockRemoteExecute : public RemoteExecute { exec_ctx_.CheckAndUpdateResponse(mock_exec_res); } void GetOutputs( - std::unordered_map>* - node_io_map) override { + std::unordered_map* node_io_map) override { ExeResponseToIoMap(mock_exec_res, node_io_map); } apis::ExecuteResponse mock_exec_res; @@ -103,13 +102,14 @@ class MockRemoteExecute : public RemoteExecute { class MockPredictor : public Predictor { public: - MockPredictor(const Options& options) : Predictor(options) {} - std::shared_ptr BuildRemoteExecute( + explicit MockPredictor(const Options& options) : Predictor(options) {} + + std::unique_ptr BuildRemoteExecute( const apis::PredictRequest* request, apis::PredictResponse* response, const std::shared_ptr& execution, std::string target_id, - std::shared_ptr<::google::protobuf::RpcChannel> channel) override { - auto exec = std::make_shared( - request, response, execution, target_id, opts_.party_id, channel); + const std::unique_ptr<::google::protobuf::RpcChannel>& channel) override { + auto exec = std::make_unique( + request, response, execution, target_id, opts_.party_id, channel.get()); exec->mock_exec_res = remote_exec_res_; return exec; } diff --git a/secretflow_serving/ops/arrow_processing.cc b/secretflow_serving/ops/arrow_processing.cc index 4d6d57e..8c084c5 100644 --- a/secretflow_serving/ops/arrow_processing.cc +++ b/secretflow_serving/ops/arrow_processing.cc @@ -338,9 +338,10 @@ ArrowProcessing::ArrowProcessing(OpKernelOptions opts) arrow::compute::GetFunctionRegistry()->GetFunction(func.name()), arrow_func); - // Noticed, we only allowed scalar type arrow compute function + // Noticed, we only allowed scalar arrow compute function or cast function SERVING_ENFORCE( - arrow_func->kind() == arrow::compute::Function::Kind::SCALAR, + arrow_func->kind() == arrow::compute::Function::Kind::SCALAR || + func.name() == "cast", errors::ErrorCode::LOGIC_ERROR, "unsupported arrow compute func:{}", func.name()); diff --git a/secretflow_serving/ops/graph.cc b/secretflow_serving/ops/graph.cc index 4f96101..9a2e5e7 100644 --- a/secretflow_serving/ops/graph.cc +++ b/secretflow_serving/ops/graph.cc @@ -153,7 +153,7 @@ Graph::Graph(GraphDef graph_def) : def_(std::move(graph_def)) { // and execution_defs graph_view_.set_version(def_.version()); - for (auto& node : def_.node_list()) { + for (const auto& node : def_.node_list()) { NodeView view; *(view.mutable_name()) = node.name(); *(view.mutable_op()) = node.op(); diff --git a/secretflow_serving/ops/merge_y.cc b/secretflow_serving/ops/merge_y.cc index e71f004..dc7e2da 100644 --- a/secretflow_serving/ops/merge_y.cc +++ b/secretflow_serving/ops/merge_y.cc @@ -39,6 +39,9 @@ MergeY::MergeY(OpKernelOptions opts) : OpKernel(std::move(opts)) { output_col_name_ = GetNodeAttr(opts_.node_def, "output_col_name"); + exp_iters_ = GetNodeAttr(opts_.node_def, *opts_.op_def, "exp_iters"); + CheckLinkFuncAragsValid(link_function_, exp_iters_); + BuildInputSchema(); BuildOutputSchema(); } @@ -64,7 +67,8 @@ void MergeY::DoCompute(ComputeContext* ctx) { SERVING_CHECK_ARROW_STATUS(builder.Resize(merged_array->length())); for (int64_t i = 0; i < merged_array->length(); ++i) { auto score = - ApplyLinkFunc(merged_array->Value(i), link_function_) * yhat_scale_; + ApplyLinkFunc(merged_array->Value(i), link_function_, exp_iters_) * + yhat_scale_; SERVING_CHECK_ARROW_STATUS(builder.Append(score)); } std::shared_ptr res_array; @@ -87,7 +91,7 @@ void MergeY::BuildOutputSchema() { } REGISTER_OP_KERNEL(MERGE_Y, MergeY) -REGISTER_OP(MERGE_Y, "0.0.2", +REGISTER_OP(MERGE_Y, "0.0.3", "Merge all partial y(score) and apply link function") .Returnable() .Mergeable() @@ -102,7 +106,7 @@ REGISTER_OP(MERGE_Y, "0.0.2", "link_function", "Type of link function, defined in " "`secretflow_serving/protos/link_function.proto`. Optional value: " - "LF_EXP, " + "LF_EXP, LF_EXP_TAYLOR, " "LF_RECIPROCAL, " "LF_IDENTITY, LF_SIGMOID_RAW, LF_SIGMOID_MM1, LF_SIGMOID_MM3, " "LF_SIGMOID_GA, " @@ -114,6 +118,10 @@ REGISTER_OP(MERGE_Y, "0.0.2", .StringAttr("input_col_name", "The column name of partial_y", false, false) .StringAttr("output_col_name", "The column name of merged score", false, false) + .Int32Attr("exp_iters", + "Number of iterations of `exp` approximation, valid when " + "`link_function` set `LF_EXP_TAYLOR`", + false, true, 0) .Input("partial_ys", "The list of partial y, data type: `double`") .Output("scores", "The merge result of `partial_ys`, data type: `double`"); diff --git a/secretflow_serving/ops/merge_y.h b/secretflow_serving/ops/merge_y.h index 51f772f..f519c26 100644 --- a/secretflow_serving/ops/merge_y.h +++ b/secretflow_serving/ops/merge_y.h @@ -37,6 +37,8 @@ class MergeY : public OpKernel { LinkFunctionType link_function_; std::string input_col_name_; std::string output_col_name_; + + int32_t exp_iters_ = 0; }; } // namespace secretflow::serving::op diff --git a/secretflow_serving/ops/merge_y_test.cc b/secretflow_serving/ops/merge_y_test.cc index de4fb6e..22294e3 100644 --- a/secretflow_serving/ops/merge_y_test.cc +++ b/secretflow_serving/ops/merge_y_test.cc @@ -27,6 +27,7 @@ namespace secretflow::serving::op { struct Param { std::string link_func; double yhat_scale; + int32_t exp_iters = 0; }; class MergeYParamTest : public ::testing::TestWithParam { @@ -58,14 +59,17 @@ TEST_P(MergeYParamTest, Works) { { AttrValue link_func_value; link_func_value.set_s(param.link_func); - node_def.mutable_attr_values()->insert( - {"link_function", std::move(link_func_value)}); + node_def.mutable_attr_values()->insert({"link_function", link_func_value}); } { AttrValue scale_value; scale_value.set_d(param.yhat_scale); - node_def.mutable_attr_values()->insert( - {"yhat_scale", std::move(scale_value)}); + node_def.mutable_attr_values()->insert({"yhat_scale", scale_value}); + } + { + AttrValue exp_iters_value; + exp_iters_value.set_i32(param.exp_iters); + node_def.mutable_attr_values()->insert({"exp_iters", exp_iters_value}); } // mock input values @@ -74,10 +78,12 @@ TEST_P(MergeYParamTest, Works) { // expect result double expect_score_0 = - ApplyLinkFunc(0.1 + 0.1 + 0.1, ParseLinkFuncType(param.link_func)) * + ApplyLinkFunc(0.1 + 0.1 + 0.1, ParseLinkFuncType(param.link_func), + param.exp_iters) * param.yhat_scale; double expect_score_1 = - ApplyLinkFunc(0.11 + 0.12 + 0.13, ParseLinkFuncType(param.link_func)) * + ApplyLinkFunc(0.11 + 0.12 + 0.13, ParseLinkFuncType(param.link_func), + param.exp_iters) * param.yhat_scale; double epsilon = 1E-13; @@ -146,15 +152,16 @@ TEST_P(MergeYParamTest, Works) { INSTANTIATE_TEST_SUITE_P( MergeYParamTestSuite, MergeYParamTest, ::testing::Values( - Param{"LF_EXP", 1.0}, Param{"LF_RECIPROCAL", 1.1}, - Param{"LF_IDENTITY", 1.2}, Param{"LF_SIGMOID_RAW", 1.3}, - Param{"LF_SIGMOID_MM1", 1.4}, Param{"LF_SIGMOID_MM3", 1.5}, - Param{"LF_SIGMOID_GA", 1.6}, Param{"LF_SIGMOID_T1", 1.7}, - Param{"LF_SIGMOID_T3", 1.8}, Param{"LF_SIGMOID_T5", 1.9}, - Param{"LF_SIGMOID_T7", 1.01}, Param{"LF_SIGMOID_T9", 1.02}, - Param{"LF_SIGMOID_LS7", 1.03}, Param{"LF_SIGMOID_SEG3", 1.04}, - Param{"LF_SIGMOID_SEG5", 1.05}, Param{"LF_SIGMOID_DF", 1.06}, - Param{"LF_SIGMOID_SR", 1.07}, Param{"LF_SIGMOID_SEGLS", 1.08})); + Param{"LF_EXP", 1.0}, Param{"LF_EXP_TAYLOR", 1.0, 8}, + Param{"LF_RECIPROCAL", 1.1}, Param{"LF_IDENTITY", 1.2}, + Param{"LF_SIGMOID_RAW", 1.3}, Param{"LF_SIGMOID_MM1", 1.4}, + Param{"LF_SIGMOID_MM3", 1.5}, Param{"LF_SIGMOID_GA", 1.6}, + Param{"LF_SIGMOID_T1", 1.7}, Param{"LF_SIGMOID_T3", 1.8}, + Param{"LF_SIGMOID_T5", 1.9}, Param{"LF_SIGMOID_T7", 1.01}, + Param{"LF_SIGMOID_T9", 1.02}, Param{"LF_SIGMOID_LS7", 1.03}, + Param{"LF_SIGMOID_SEG3", 1.04}, Param{"LF_SIGMOID_SEG5", 1.05}, + Param{"LF_SIGMOID_DF", 1.06}, Param{"LF_SIGMOID_SR", 1.07}, + Param{"LF_SIGMOID_SEGLS", 1.08})); // TODO: exception case diff --git a/secretflow_serving/ops/tree_ensemble_predict.cc b/secretflow_serving/ops/tree_ensemble_predict.cc index be32b57..0ec1c15 100644 --- a/secretflow_serving/ops/tree_ensemble_predict.cc +++ b/secretflow_serving/ops/tree_ensemble_predict.cc @@ -40,6 +40,9 @@ TreeEnsemblePredict::TreeEnsemblePredict(OpKernelOptions opts) GetNodeAttr(opts_.node_def, *opts_.op_def, "algo_func"); func_type_ = ParseLinkFuncType(func_type_str); + base_score_ = + GetNodeAttr(opts_.node_def, *opts_.op_def, "base_score"); + BuildInputSchema(); BuildOutputSchema(); } @@ -65,7 +68,8 @@ void TreeEnsemblePredict::DoCompute(ComputeContext* ctx) { arrow::DoubleBuilder builder; SERVING_CHECK_ARROW_STATUS(builder.Resize(merged_array->length())); for (int64_t i = 0; i < merged_array->length(); ++i) { - auto score = ApplyLinkFunc(merged_array->Value(i), func_type_); + auto score = merged_array->Value(i) + base_score_; + score = ApplyLinkFunc(score, func_type_); SERVING_CHECK_ARROW_STATUS(builder.Append(score)); } std::shared_ptr res_array; @@ -90,7 +94,7 @@ void TreeEnsemblePredict::BuildOutputSchema() { } REGISTER_OP_KERNEL(TREE_ENSEMBLE_PREDICT, TreeEnsemblePredict) -REGISTER_OP(TREE_ENSEMBLE_PREDICT, "0.0.1", +REGISTER_OP(TREE_ENSEMBLE_PREDICT, "0.0.2", "Accept the weighted results from multiple trees (`TREE_SELECT` + " "`TREE_MERGE`), merge them, and obtain the final prediction result " "of the tree ensemble.") @@ -101,6 +105,8 @@ REGISTER_OP(TREE_ENSEMBLE_PREDICT, "0.0.1", .StringAttr("output_col_name", "The column name of tree ensemble predict score", false, false) .Int32Attr("num_trees", "The number of ensemble's tree", false, false) + .DoubleAttr("base_score", "The initial prediction score, global bias.", + false, true, 0.0) .StringAttr( "algo_func", "Optional value: " diff --git a/secretflow_serving/ops/tree_ensemble_predict.h b/secretflow_serving/ops/tree_ensemble_predict.h index 7d250f9..58f1b5a 100644 --- a/secretflow_serving/ops/tree_ensemble_predict.h +++ b/secretflow_serving/ops/tree_ensemble_predict.h @@ -37,6 +37,8 @@ class TreeEnsemblePredict : public OpKernel { int32_t num_trees_; LinkFunctionType func_type_; + + double base_score_ = 0.0; }; } // namespace secretflow::serving::op diff --git a/secretflow_serving/ops/tree_ensemble_predict_test.cc b/secretflow_serving/ops/tree_ensemble_predict_test.cc index aef5b60..1adf4bf 100644 --- a/secretflow_serving/ops/tree_ensemble_predict_test.cc +++ b/secretflow_serving/ops/tree_ensemble_predict_test.cc @@ -47,6 +47,9 @@ TEST_P(TreeEnsemblePredictParamTest, Works) { }, "output_col_name": { "s": "scores" + }, + "base_score": { + "d": 0.1 } } } @@ -96,6 +99,7 @@ TEST_P(TreeEnsemblePredictParamTest, Works) { for (size_t col = 1; col < param.tree_weights.size(); ++col) { score += param.tree_weights[col][row]; } + score += 0.1; SERVING_CHECK_ARROW_STATUS(expect_res_builder.Append( ApplyLinkFunc(score, ParseLinkFuncType(param.algo_func)))); } diff --git a/secretflow_serving/protos/link_function.proto b/secretflow_serving/protos/link_function.proto index 9865b5c..298db37 100644 --- a/secretflow_serving/protos/link_function.proto +++ b/secretflow_serving/protos/link_function.proto @@ -59,4 +59,7 @@ enum LinkFunctionType { LF_SIGMOID_SR = 24; // LS7 if |x| <= 5.87, else SR. LF_SIGMOID_SEGLS = 25; + + // Exp + LF_EXP_TAYLOR = 30; } diff --git a/secretflow_serving/source/http_source.h b/secretflow_serving/source/http_source.h index bbdcdc6..251330c 100644 --- a/secretflow_serving/source/http_source.h +++ b/secretflow_serving/source/http_source.h @@ -36,7 +36,7 @@ class HttpSource : public Source { protected: std::string endpoint_; - std::shared_ptr channel_; + std::unique_ptr channel_; }; } // namespace secretflow::serving diff --git a/secretflow_serving/util/arrow_helper.cc b/secretflow_serving/util/arrow_helper.cc index c5c9967..66a0d3e 100644 --- a/secretflow_serving/util/arrow_helper.cc +++ b/secretflow_serving/util/arrow_helper.cc @@ -416,8 +416,9 @@ std::shared_ptr ReadCsvFileToTable( return table; } -arrow::Datum GetRowsFilter(const std::shared_ptr id_column, - const std::vector& ids) { +arrow::Datum GetRowsFilter( + const std::shared_ptr& id_column, + const std::vector& ids) { arrow::StringBuilder builder; SERVING_CHECK_ARROW_STATUS(builder.AppendValues(ids)); std::shared_ptr query_data_array; @@ -442,8 +443,8 @@ arrow::Datum GetRowsFilter(const std::shared_ptr id_column, return filter; } -std::shared_ptr GetIdColumnFromFile(std::string filename, - std::string id_name) { +std::shared_ptr GetIdColumnFromFile( + const std::string& filename, const std::string& id_name) { std::vector> fields; fields.push_back(arrow::field(id_name, arrow::utf8())); auto schema = arrow::schema(fields); @@ -456,7 +457,7 @@ std::shared_ptr GetIdColumnFromFile(std::string filename, } std::shared_ptr ExtractRowsFromTable( - std::shared_ptr table, arrow::Datum filter) { + const std::shared_ptr& table, const arrow::Datum& filter) { arrow::Datum filtered_table; SERVING_GET_ARROW_RESULT(arrow::compute::Filter(table, filter), filtered_table); diff --git a/secretflow_serving/util/arrow_helper.h b/secretflow_serving/util/arrow_helper.h index 3c052e8..2868c3a 100644 --- a/secretflow_serving/util/arrow_helper.h +++ b/secretflow_serving/util/arrow_helper.h @@ -107,13 +107,14 @@ std::shared_ptr ReadCsvFileToTable( const std::string& path, const std::shared_ptr& feature_schema); -arrow::Datum GetRowsFilter(const std::shared_ptr id_column, - const std::vector& ids); +arrow::Datum GetRowsFilter( + const std::shared_ptr& id_column, + const std::vector& ids); std::shared_ptr ExtractRowsFromTable( - std::shared_ptr table, arrow::Datum filter); + const std::shared_ptr& table, const arrow::Datum& filter); -std::shared_ptr GetIdColumnFromFile(std::string filename, - std::string id_name); +std::shared_ptr GetIdColumnFromFile( + const std::string& filename, const std::string& id_name); } // namespace secretflow::serving diff --git a/secretflow_serving/util/network.cc b/secretflow_serving/util/network.cc index f344560..5e5f22f 100644 --- a/secretflow_serving/util/network.cc +++ b/secretflow_serving/util/network.cc @@ -37,10 +37,10 @@ std::string FillHttpPrefix(const std::string& addr, bool ssl_enabled) { } } // namespace -std::shared_ptr CreateBrpcChannel( +std::unique_ptr CreateBrpcChannel( const std::string& endpoint, bool enable_lb, const brpc::ChannelOptions& opts) { - auto channel = std::make_shared(); + auto channel = std::make_unique(); std::string remote_url = endpoint; std::string load_balancer; if (enable_lb) { @@ -57,7 +57,7 @@ std::shared_ptr CreateBrpcChannel( return channel; } -std::shared_ptr CreateBrpcChannel( +std::unique_ptr CreateBrpcChannel( const std::string& endpoint, const std::string& protocol, bool enable_lb, int32_t rpc_timeout_ms, int32_t connect_timeout_ms, const TlsConfig* tls_config) { diff --git a/secretflow_serving/util/network.h b/secretflow_serving/util/network.h index b33fd26..69fcf84 100644 --- a/secretflow_serving/util/network.h +++ b/secretflow_serving/util/network.h @@ -25,11 +25,11 @@ namespace secretflow::serving { -std::shared_ptr CreateBrpcChannel( +std::unique_ptr CreateBrpcChannel( const std::string& endpoint, bool enable_lb, const brpc::ChannelOptions& opts); -std::shared_ptr CreateBrpcChannel( +std::unique_ptr CreateBrpcChannel( const std::string& endpoint, const std::string& protocol, bool enable_lb, int32_t rpc_timeout_ms, int32_t connect_timeout_ms, const TlsConfig* tls_config); diff --git a/secretflow_serving/util/thread_pool.h b/secretflow_serving/util/thread_pool.h index 305219e..fca8e9a 100644 --- a/secretflow_serving/util/thread_pool.h +++ b/secretflow_serving/util/thread_pool.h @@ -122,7 +122,7 @@ class ThreadPool : public Singleton { } } - size_t GetTaskSize() const { + [[nodiscard]] size_t GetTaskSize() const { return std::accumulate( task_queues_.begin(), task_queues_.end(), 0, [](int size, auto& queue) { return size + queue.size(); });