Skip to content

Commit 572c586

Browse files
Merge pull request #33 from salesforce/fix_mace_refine
Fix a bug in refine.py
2 parents ef432ed + 622a5f9 commit 572c586

File tree

1 file changed

+3
-4
lines changed
  • omnixai/explainers/tabular/counterfactual/mace

1 file changed

+3
-4
lines changed

omnixai/explainers/tabular/counterfactual/mace/refine.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
# SPDX-License-Identifier: BSD-3-Clause
55
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
66
#
7-
import numpy as np
87
import pandas as pd
98
from typing import Dict, Callable, Union
109

@@ -45,7 +44,7 @@ def _refine(
4544

4645
for col, (a, b) in cont_features.items():
4746
gap, r = b - a, None
48-
while (b - a) / (gap + 1e-3) > 0.1:
47+
while (b - a) / gap > 0.1:
4948
z = (a + b) * 0.5
5049
y.iloc[0, column2loc[col]] = z
5150
scores = predict_function(Tabular(data=y, categorical_columns=instance.categorical_columns))[0]
@@ -83,8 +82,8 @@ def refine(
8382
cont_features = {}
8483
for col in self.cont_columns:
8584
a, b = float(x[col].values[0]), float(y[col].values[0])
86-
if a != b:
87-
cont_features[col] = (a, b) if a <= b else (b, a)
85+
if abs(a - b) > 1e-6:
86+
cont_features[col] = (a, b)
8887
if len(cont_features) == 0:
8988
results.append(y)
9089
else:

0 commit comments

Comments
 (0)