Skip to content

Commit

Permalink
fix base score and msg format (#66)
Browse files Browse the repository at this point in the history
* fix err msg formta

* fix base score miss

* update version & changelog
  • Loading branch information
oeqqwq authored May 24, 2024
1 parent e6b7d74 commit c00e886
Show file tree
Hide file tree
Showing 9 changed files with 26 additions and 10 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@

> please add your unreleased change here.
## 20240524 - 0.3.1b0

- [Bugfix] fix tree predict base score miss
- [Bugfix] fix http adapater error msg format failed

## 20240423 - 0.3.0b0

- [Feature] Add Trace function
Expand Down
2 changes: 1 addition & 1 deletion secretflow_serving/feature_adapter/http_adapter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ void HttpFeatureAdapter::OnFetchFeature(const Request& request,
SetSpanAttrs(span, span_option);

SERVING_ENFORCE(span_option.code == errors::ErrorCode::OK, span_option.code,
span_option.msg);
"{}", span_option.msg);
response->header->mutable_data()->swap(
*spi_response.mutable_header()->mutable_data());
response->features =
Expand Down
7 changes: 3 additions & 4 deletions secretflow_serving/framework/execute_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,9 @@ void ExecuteContext::CheckAndUpdateResponse() {
void ExecuteContext::CheckAndUpdateResponse(
const apis::ExecuteResponse& exec_res) {
if (!CheckStatusOk(exec_res.status())) {
SERVING_THROW(
exec_res.status().code(),
fmt::format("{} exec failed: code({}), {}", target_id_,
exec_res.status().code(), exec_res.status().msg()));
SERVING_THROW(exec_res.status().code(), "{} exec failed: code({}), {}",
target_id_, exec_res.status().code(),
exec_res.status().msg());
}
MergeResonseHeader(exec_res);
}
Expand Down
2 changes: 1 addition & 1 deletion secretflow_serving/framework/execute_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ class RemoteExecute : public ExecuteBase,
if (span_option.code == errors::ErrorCode::OK) {
exec_ctx_.MergeResonseHeader();
} else {
SERVING_THROW(span_option.code, span_option.msg);
SERVING_THROW(span_option.code, "{}", span_option.msg);
}
}

Expand Down
10 changes: 8 additions & 2 deletions secretflow_serving/ops/tree_ensemble_predict.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ TreeEnsemblePredict::TreeEnsemblePredict(OpKernelOptions opts)
GetNodeAttr<std::string>(opts_.node_def, *opts_.op_def, "algo_func");
func_type_ = ParseLinkFuncType(func_type_str);

base_score_ =
GetNodeAttr<double>(opts_.node_def, *opts_.op_def, "base_score");

BuildInputSchema();
BuildOutputSchema();
}
Expand All @@ -65,7 +68,8 @@ void TreeEnsemblePredict::DoCompute(ComputeContext* ctx) {
arrow::DoubleBuilder builder;
SERVING_CHECK_ARROW_STATUS(builder.Resize(merged_array->length()));
for (int64_t i = 0; i < merged_array->length(); ++i) {
auto score = ApplyLinkFunc(merged_array->Value(i), func_type_);
auto score = merged_array->Value(i) + base_score_;
score = ApplyLinkFunc(score, func_type_);
SERVING_CHECK_ARROW_STATUS(builder.Append(score));
}
std::shared_ptr<arrow::Array> res_array;
Expand All @@ -90,7 +94,7 @@ void TreeEnsemblePredict::BuildOutputSchema() {
}

REGISTER_OP_KERNEL(TREE_ENSEMBLE_PREDICT, TreeEnsemblePredict)
REGISTER_OP(TREE_ENSEMBLE_PREDICT, "0.0.1",
REGISTER_OP(TREE_ENSEMBLE_PREDICT, "0.0.2",
"Accept the weighted results from multiple trees (`TREE_SELECT` + "
"`TREE_MERGE`), merge them, and obtain the final prediction result "
"of the tree ensemble.")
Expand All @@ -101,6 +105,8 @@ REGISTER_OP(TREE_ENSEMBLE_PREDICT, "0.0.1",
.StringAttr("output_col_name",
"The column name of tree ensemble predict score", false, false)
.Int32Attr("num_trees", "The number of ensemble's tree", false, false)
.DoubleAttr("base_score", "The initial prediction score, global bias.",
false, true, 0.0)
.StringAttr(
"algo_func",
"Optional value: "
Expand Down
2 changes: 2 additions & 0 deletions secretflow_serving/ops/tree_ensemble_predict.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ class TreeEnsemblePredict : public OpKernel {

int32_t num_trees_;
LinkFunctionType func_type_;

double base_score_ = 0.0;
};

} // namespace secretflow::serving::op
4 changes: 4 additions & 0 deletions secretflow_serving/ops/tree_ensemble_predict_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ TEST_P(TreeEnsemblePredictParamTest, Works) {
},
"output_col_name": {
"s": "scores"
},
"base_score": {
"d": 0.1
}
}
}
Expand Down Expand Up @@ -96,6 +99,7 @@ TEST_P(TreeEnsemblePredictParamTest, Works) {
for (size_t col = 1; col < param.tree_weights.size(); ++col) {
score += param.tree_weights[col][row];
}
score += 0.1;
SERVING_CHECK_ARROW_STATUS(expect_res_builder.Append(
ApplyLinkFunc(score, ParseLinkFuncType(param.algo_func))));
}
Expand Down
2 changes: 1 addition & 1 deletion secretflow_serving_lib/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@
# limitations under the License.


__version__ = "0.3.0b0"
__version__ = "0.3.1b0"
2 changes: 1 addition & 1 deletion version.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@
# limitations under the License.


__version__ = "0.3.0b0"
__version__ = "0.3.1b0"

0 comments on commit c00e886

Please sign in to comment.