Skip to content

Commit 31ae566

Browse files
committed
fix shape calcultator
Signed-off-by: Xavier Dupre <[email protected]>
1 parent 390dad7 commit 31ae566

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

tests/xgboost/test_xgboost_issues.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
class TestXGBoostIssues(unittest.TestCase):
77
def test_issue_676(self):
8+
import json
89
import onnxruntime
910
import xgboost
1011
import numpy as np
@@ -15,13 +16,15 @@ def test_issue_676(self):
1516
convert_xgboost,
1617
)
1718

18-
def frozen_shape_calculator(operator):
19-
operator.outputs[0].type.shape = [None, 2]
19+
def xgbregressor_shape_calculator(operator):
20+
config = json.loads(operator.raw_operator.get_booster().save_config())
21+
n_targets = int(config["learner"]["learner_model_param"]["num_target"])
22+
operator.outputs[0].type.shape = [None, n_targets]
2023

2124
update_registered_converter(
2225
xgboost.XGBRegressor,
2326
"XGBoostXGBRegressor",
24-
frozen_shape_calculator,
27+
xgbregressor_shape_calculator,
2528
convert_xgboost,
2629
)
2730
# Your data and labels

0 commit comments

Comments
 (0)