File tree 1 file changed +6
-3
lines changed
1 file changed +6
-3
lines changed Original file line number Diff line number Diff line change 5
5
6
6
class TestXGBoostIssues (unittest .TestCase ):
7
7
def test_issue_676 (self ):
8
+ import json
8
9
import onnxruntime
9
10
import xgboost
10
11
import numpy as np
@@ -15,13 +16,15 @@ def test_issue_676(self):
15
16
convert_xgboost ,
16
17
)
17
18
18
- def frozen_shape_calculator (operator ):
19
- operator .outputs [0 ].type .shape = [None , 2 ]
19
+ def xgbregressor_shape_calculator (operator ):
20
+ config = json .loads (operator .raw_operator .get_booster ().save_config ())
21
+ n_targets = int (config ["learner" ]["learner_model_param" ]["num_target" ])
22
+ operator .outputs [0 ].type .shape = [None , n_targets ]
20
23
21
24
update_registered_converter (
22
25
xgboost .XGBRegressor ,
23
26
"XGBoostXGBRegressor" ,
24
- frozen_shape_calculator ,
27
+ xgbregressor_shape_calculator ,
25
28
convert_xgboost ,
26
29
)
27
30
# Your data and labels
You can’t perform that action at this time.
0 commit comments