Skip to content

Commit

Permalink
Fixed missing feature names for XGBoost (#93)
Browse files Browse the repository at this point in the history
  • Loading branch information
akhvorov authored and krinart committed Jul 30, 2019
1 parent d788566 commit 41ab5dc
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 2 deletions.
5 changes: 3 additions & 2 deletions m2cgen/assemblers/boosting.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ class XGBoostModelAssembler(BaseBoostingAssembler):
def __init__(self, model):
feature_names = model.get_booster().feature_names
self._feature_name_to_idx = {
name: idx for idx, name in enumerate(feature_names)
name: idx for idx, name in enumerate(feature_names or [])
}

model_dump = model.get_booster().get_dump(dump_format="json")
Expand All @@ -103,7 +103,8 @@ def _assemble_tree(self, tree):
return ast.NumVal(tree["leaf"])

threshold = ast.NumVal(tree["split_condition"])
feature_idx = self._feature_name_to_idx[tree["split"]]
split = tree["split"]
feature_idx = self._feature_name_to_idx.get(split, split)
feature_ref = ast.FeatureRef(feature_idx)

# Since comparison with NaN (missing) value always returns false we
Expand Down
40 changes: 40 additions & 0 deletions tests/assemblers/test_xgboost.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import xgboost
import numpy as np
import os
from tests import utils
from m2cgen import assemblers, ast

Expand Down Expand Up @@ -228,3 +229,42 @@ def test_multi_class_best_ntree_limit():
])

assert utils.cmp_exprs(actual, expected)


def test_regression_saved_without_feature_names():
base_score = 0.6
estimator = xgboost.XGBRegressor(n_estimators=2, random_state=1,
max_depth=1, base_score=base_score)
utils.train_model_regression(estimator)

with utils.tmp_dir() as tmp_dirpath:
filename = os.path.join(tmp_dirpath, "tmp.file")
estimator.save_model(filename)
estimator = xgboost.XGBRegressor(base_score=base_score)
estimator.load_model(filename)

assembler = assemblers.XGBoostModelAssembler(estimator)
actual = assembler.assemble()

expected = ast.SubroutineExpr(
ast.BinNumExpr(
ast.BinNumExpr(
ast.NumVal(base_score),
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(12),
ast.NumVal(9.72500038),
ast.CompOpType.GTE),
ast.NumVal(1.67318344),
ast.NumVal(2.92757893)),
ast.BinNumOpType.ADD),
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(5),
ast.NumVal(6.94099998),
ast.CompOpType.GTE),
ast.NumVal(3.3400948),
ast.NumVal(1.72118247)),
ast.BinNumOpType.ADD))

assert utils.cmp_exprs(actual, expected)

0 comments on commit 41ab5dc

Please sign in to comment.