Exploring Accuracy and Interpretability trade-off in Tabular Learning with Novel Attention-Based Models
Apart from predictive performance, interpretability is essential for:
- uncovering hidden patterns in the data
- providing meaningful justification of decisions made by machine learning models
- ...
In this context, an important question arises: should one use inherently interpretable models, or explain full-complexity models such as XGBoost and Random Forest with post hoc tools?
In this repository, we provide concrete numerical results that can guide practitioners and researchers in choosing between inherently interpretable solutions and explained full-complexity models.
This study includes TabSRAs, attention-based inherently interpretable models that are proving to be a viable option for (i) generating stable or robust explanations and (ii) incorporating human knowledge during the training phase.
What is the actual performance gap between the full-complexity state-of-the-art models and their inherently interpretable counterparts in terms of accuracy?
| Model | Rank (min) | Rank (max) | Rank (mean) | Rank (median) | Test score (mean) | Test score (median) | Test score (std) | Running Time (mean) | Running Time (median) |
|---|---|---|---|---|---|---|---|---|---|
| DT | 2 | 12 | 10.476 | 11 | 0.868 | 0.907 | 0.163 | 0.294 | 0.032 |
| EBM_S | 1 | 11 | 7.692 | 8 | 0.931 | 0.955 | 0.087 | 23.997 | 5.144 |
| EBM | 1 | 10 | 5.477 | 5 | 0.959 | 0.982 | 0.067 | 97.837 | 19.737 |
| LR | 7 | 12 | 11.701 | 12 | 0.760 | 0.839 | 0.232 | 21.124 | 19.716 |
| TabSRALinear | 1 | 12 | 8.225 | 9 | 0.901 | 0.971 | 0.197 | 47.576 | 38.073 |
| MLP | 1 | 12 | 6.992 | 8 | 0.924 | 0.973 | 0.159 | 24.165 | 19.256 |
| ResNet | 1 | 12 | 7.120 | 8 | 0.909 | 0.975 | 0.195 | 95.123 | 53.212 |
| SAINT | 1 | 12 | 5.625 | 6 | 0.946 | 0.982 | 0.093 | 216.053 | 126.841 |
| FT-Transformer | 1 | 11 | 5.203 | 5 | 0.944 | 0.984 | 0.109 | 126.589 | 77.465 |
| Random Forest | 1 | 10 | 4.214 | 4 | 0.985 | 0.992 | 0.021 | 39.030 | 8.252 |
| XGBoost | 1 | 11 | 2.728 | 2 | 0.988 | 0.998 | 0.029 | 18.254 | 12.561 |
| CatBoost | 1 | 10 | 2.545 | 2 | 0.991 | 0.999 | 0.021 | 12.176 | 4.025 |
Predictive performance of models across a benchmark of 45 datasets (59 tasks) introduced in the paper "Why do tree-based models still outperform deep learning on typical tabular data?". We report the rank over all tasks, the relative test score (Accuracy/
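The rank columns above summarize, for each model, its rank on every task aggregated over all tasks (min, max, mean, median). The sketch below illustrates that aggregation on made-up scores. The helper names are hypothetical, ties are broken by order of appearance, and the relative score used here (each score divided by the best score on its task) is an assumed normalization, not necessarily the one used for the table.

```python
from statistics import mean, median


def rank_models(scores_per_task):
    """Rank models within each task (1 = best; higher score is better).

    scores_per_task maps task name -> {model name: test score}.
    Ties keep their order of appearance (sorted() is stable).
    Returns {model name: [rank on each task, in task order]}.
    """
    ranks = {}
    for scores in scores_per_task.values():
        ordered = sorted(scores, key=scores.get, reverse=True)
        for position, model in enumerate(ordered, start=1):
            ranks.setdefault(model, []).append(position)
    return ranks


def relative_scores(scores_per_task):
    """Assumed normalization: divide each score by the best score on its task."""
    rel = {}
    for scores in scores_per_task.values():
        best = max(scores.values())
        for model, score in scores.items():
            rel.setdefault(model, []).append(score / best)
    return rel


def summarize(values_per_model):
    """Aggregate per-task values into min / max / mean / median per model."""
    return {
        model: {"min": min(v), "max": max(v), "mean": mean(v), "median": median(v)}
        for model, v in values_per_model.items()
    }


# Toy example with two tasks and three models (made-up numbers)
scores = {
    "task_1": {"A": 0.9, "B": 0.8, "C": 0.7},
    "task_2": {"A": 0.6, "B": 0.9, "C": 0.5},
}
print(summarize(rank_models(scores)))
print(summarize(relative_scores(scores)))
```

Working with per-task ranks rather than raw scores keeps tasks with very different score scales from dominating the comparison, which is why both views are reported.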
The considered inherently interpretable models are:
- Decision Trees (DT)
- Explainable Boosting Machines (EBMs):
  - EBM: EBMs with pairwise interaction terms
  - EBM_S: EBMs without pairwise interaction terms
- Linear/Logistic Regression (LR): PyTorch is used for the implementation
- TabSRALinear: an instantiation of TabSRAs that imitates the formulation of classical linear models. More details are in the papers (ESANN, ECML@XKDD).
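As a rough illustration of how TabSRALinear mirrors a linear model, the sketch below assumes the SRA block has already produced one attention weight per feature. The raw input is reinforced element-wise by these weights and a linear model is applied on top, so the contribution of feature j is beta_j * a_j * x_j. The function name and the fixed attention values are hypothetical; in TabSRA the attention weights are learned by the Key/Query encoders, and the actual implementation may differ in its details.

```python
import numpy as np


def tabsra_linear_sketch(x, attention, beta, bias=0.0):
    """Linear model on a self-reinforced input (illustrative sketch only).

    x, attention: arrays of shape (n_features,) or (n_samples, n_features);
        `attention` holds per-feature weights assumed to come from an SRA block.
    beta: linear coefficients of shape (n_features,).
    Returns the logit and the per-feature contributions beta_j * a_j * x_j.
    """
    reinforced = attention * x          # element-wise self-reinforcement of the input
    contributions = beta * reinforced   # additive per-feature contributions
    logit = contributions.sum(axis=-1) + bias
    return logit, contributions


x = np.array([1.0, 2.0])
attention = np.array([0.5, 1.0])  # hypothetical attention weights
beta = np.array([2.0, -1.0])
logit, contributions = tabsra_linear_sketch(x, attention, beta, bias=0.5)
probability = 1.0 / (1.0 + np.exp(-logit))  # sigmoid for binary classification
print(logit, contributions, probability)
```

Because the downstream model is linear in the reinforced input, the contributions sum exactly to the logit minus the bias, which is what keeps the attributions directly readable, as with a classical linear model.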
Among full-complexity models, we considered:
- MultiLayer Perceptron (MLP): PyTorch is used for the implementation
- ResNet
- SAINT
- FT-Transformer
- Random Forest
- XGBoost
- CatBoost
What about the robustness of explanations: are the produced feature attributions similar for similar inputs?
Changes in feature attributions (the lower, the better) on the CreditCardFraud dataset.
LR = Logistic Regression, SRA=TabSRALinear, XGB_SHAP=XGBoost+TreeSHAP
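One common way to quantify "similar attributions for similar inputs" is a local Lipschitz estimate: sample small perturbations x' around x and record the largest ratio between the change in attributions and the change in input. The exact metric behind the figure is not reproduced here; this sketch is an assumption, with a hypothetical function name, uniform box perturbations, and a logistic-regression-style attribution (beta_j * x_j) used as the toy explainer.

```python
import numpy as np


def local_lipschitz_estimate(attribution_fn, x, radius=0.1, n_samples=200, seed=0):
    """Estimate max ||phi(x') - phi(x)|| / ||x' - x|| over random x' near x.

    attribution_fn maps a 1-D feature vector to a 1-D attribution vector.
    Perturbations are drawn uniformly in a box of half-width `radius`.
    Lower values indicate more stable (robust) explanations.
    """
    rng = np.random.default_rng(seed)
    phi_x = attribution_fn(x)
    worst = 0.0
    for _ in range(n_samples):
        x_prime = x + rng.uniform(-radius, radius, size=x.shape)
        step = np.linalg.norm(x_prime - x)
        if step == 0.0:
            continue  # skip a degenerate (zero) perturbation
        ratio = np.linalg.norm(attribution_fn(x_prime) - phi_x) / step
        worst = max(worst, ratio)
    return worst


# Usage with a logistic-regression-style attribution phi_j(x) = beta_j * x_j
beta = np.array([1.0, -3.0, 0.5])
x0 = np.array([0.2, 1.0, -0.4])
print(local_lipschitz_estimate(lambda z: beta * z, x0))
```

For a linear attribution the estimate is bounded by the largest absolute coefficient, so non-linear explainers (such as TreeSHAP on XGBoost) can be compared against that baseline.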
Create a new Python environment and install the requirements.
- Clone this repository on your machine
- Download the random search results using the links:
- Copy and paste the downloaded files to
- Run the notebook to reproduce the results
NB: To use the notebook, you will need to install Jupyter Notebook in the Python environment you created (for example, using pip).
- Use the Notebook for the example on the Credit Card Fraud dataset
- Use the Notebook for the example on the Heloc Fico dataset
Please follow the instructions here to benchmark a new model depending on your budget.
We use the skorch framework to make our implementation more scikit-learn friendly. Here is the old version.
import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics import roc_auc_score
# InputShapeSetterTabSRA is TabSRA-specific and is not part of upstream skorch;
# this import assumes the skorch version distributed with this repository.
from skorch.callbacks import EarlyStopping, LRScheduler, Checkpoint, TrainEndCheckpoint, EpochScoring, InputShapeSetterTabSRA
from skorch.dataset import Dataset
from skorch.helper import predefined_split
from sramodels.SRAModels import TabSRALinearClassifier
# X_train_, Y_train_, X_val, Y_val are assumed to be prepared beforehand
# (features as pandas DataFrames, binary targets as NumPy arrays).
configs = {
    "criterion": nn.BCEWithLogitsLoss,
    # add the other hyperparameters here (see "Key parameters" below)
}
scoring = EpochScoring(scoring='roc_auc', lower_is_better=False)  # validation ROC AUC computed at each epoch
setter = InputShapeSetterTabSRA(regression=False)  # sets the input and output dimensions automatically
early_stop = EarlyStopping(monitor=scoring.scoring, patience=10, load_best=True, lower_is_better=False, threshold=0.0001, threshold_mode='abs')
# learning-rate scheduler (the original definition was missing; these settings are illustrative)
lr_scheduler = LRScheduler(policy='ReduceLROnPlateau', monitor='valid_loss', patience=5, factor=0.2)
callbacks = [scoring, setter, early_stop, lr_scheduler]
valid_dataset = Dataset(X_val.values.astype(np.float32), Y_val.astype(np.float32))  # custom validation dataset
TabClassifier = TabSRALinearClassifier(**configs, train_split=predefined_split(valid_dataset), callbacks=callbacks)
_ = TabClassifier.fit(X_train_.values.astype(np.float32), Y_train_.astype(np.float32))
# prediction: probability of the positive class is in column 1
Y_val_pred = TabClassifier.predict_proba(X_val.values.astype(np.float32))
best_aucroc = roc_auc_score(Y_val.astype(np.float32), Y_val_pred[:, 1])
# feature attribution
attributions_val = TabClassifier.get_feature_attribution(X_val.values.astype(np.float32))
# attention weights
attentions_val = TabClassifier.get_attention(X_val.values.astype(np.float32))
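Once attributions are computed, a typical next step is to inspect the most influential features. Assuming `get_feature_attribution` returns an array of shape (n_samples, n_features), the hypothetical helpers below list the top-k features per sample (by absolute attribution) and a simple global importance (mean absolute attribution per feature). The feature names and numbers in the usage lines are made up.

```python
import numpy as np


def top_k_features(attributions, feature_names, k=3):
    """Return, for each sample (row), the k feature names with the largest |attribution|."""
    order = np.argsort(-np.abs(attributions), axis=1)[:, :k]
    return [[feature_names[j] for j in row] for row in order]


def global_importance(attributions):
    """Mean absolute attribution of each feature across samples."""
    return np.abs(attributions).mean(axis=0)


# Usage on a toy attribution matrix (2 samples, 3 features)
attributions = np.array([[0.1, -0.5, 0.3],
                         [0.0, 0.2, -0.9]])
names = ["amount", "age", "balance"]
print(top_k_features(attributions, names, k=2))
print(global_importance(attributions))
```

Using absolute values treats strong negative and strong positive contributions as equally influential; keep the signed values when the direction of the effect matters.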
Key parameters
The model parameters are prefixed with `module__` (skorch convention).
: int (default=2) Number of SRA heads (ensemble size). Larger values give the model more capacity, at the cost of potentially less stable/robust explanations. Typical values are 1 or 2.
: int (default=8) The attention head dimension, $d_k$ in the paper. Typical values are {4, 8, 12}.
: int (default=1) The number of hidden layers in the Key/Query encoder. Typical values are {1, 2}.
: float (default=0.0) The neuron dropout rate used in the Key/Query encoder during training.
: bool (default=True) Whether to use a bias term in the downstream linear classifier.
: (default=torch.optim.Adam) The optimizer used for training.
: float (default=0.05) Learning rate used for training.
: int (default=100) Maximum number of training iterations.
: int (default=256)
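In skorch, keyword arguments are routed by their double-underscore prefix: keys starting with `module__` go to the module constructor, keys starting with `optimizer__` go to the optimizer, and unprefixed keys such as `max_epochs` or `batch_size` configure the training loop itself. The hypothetical helper below only mimics that routing on a plain dictionary to make the convention concrete. It is not skorch's internal implementation, and the `module__hidden_size` key is a made-up placeholder rather than a TabSRA parameter.

```python
def split_prefixed_params(params, prefix):
    """Illustrative only: collect entries whose key starts with `prefix + '__'`,
    stripping the prefix, in the spirit of skorch's parameter routing."""
    marker = prefix + "__"
    return {key[len(marker):]: value for key, value in params.items() if key.startswith(marker)}


configs = {
    "module__hidden_size": 8,   # hypothetical module parameter
    "optimizer__lr": 0.05,      # forwarded to the optimizer as lr=0.05
    "max_epochs": 100,          # plain NeuralNet argument (no prefix)
}
print(split_prefixed_params(configs, "module"))     # module kwargs
print(split_prefixed_params(configs, "optimizer"))  # optimizer kwargs
```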
TabSRA package with sklearn interface
This work has been done in collaboration between BPCE Group, the Laboratoire d'Informatique de Paris Nord (LIPN UMR 7030), and the DAVID Lab (UVSQ, Université Paris Saclay), and was supported by the program Convention Industrielle de Formation par la Recherche (CIFRE) of the Association Nationale de la Recherche et de la Technologie (ANRT).
If you find the code useful, please cite it by using the following BibTeX entry:
title={Exploring accuracy and interpretability trade-off in tabular learning with novel attention-based models},
author={Amekoe, Kodjo Mawuena and Azzag, Hanene and Dagdia, Zaineb Chelly and Lebbah, Mustapha and Jaffre, Gregoire},
journal={Neural Computing and Applications},
author = {Kodjo Mawuena Amekoe and
Mohamed Djallel Dilmi and
Hanene Azzag and
Zaineb Chelly Dagdia and
Mustapha Lebbah and
Gregoire Jaffre},
title = {TabSRA: An Attention based Self-Explainable Model for Tabular Learning},
booktitle = {The 31st European Symposium on Artificial Neural Networks, Computational Intelligence and Machine Learning (ESANN)},
year = {2023}
author = {Kodjo Mawuena Amekoe and
Hanene Azzag and
Mustapha Lebbah and
Zaineb Chelly Dagdia and
Gregoire Jaffre},
title = {A New Class of Intelligible Models for Tabular Learning},
booktitle = {The 5th International Workshop on eXplainable Knowledge Discovery in Data Mining (PKDD)-ECML-PKDD},
year = {2023}