Skip to content

Commit

Permalink
add feature bagging
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbonet committed Feb 24, 2024
1 parent 0b74f04 commit c7c8236
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 3 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ print(f"Accuracy: {accuracy * 100:.2f}%")
* The current default hyperparameters provide medium speed-accuracy performance.
* :rocket: For the fastest inference (but less accurate) set ``n_ensemble=1`` and ``optimization=None``.
* :bar_chart: If you are dealing with an **imbalanced dataset**, consider setting ``stratify_sampling=True`` with ``n_ensemble`` > 1.
* :globe_with_meridians: If you are dealing with a very **high-dimensional dataset** (e.g., >3000 features), consider setting ``feature_bagging=True`` with ``n_ensemble`` > 1.
* :ok_hand: For slower but most accurate predictions, optimize the inference parameters of HyperFast for each dataset. In this case, we recommend the following search space:

```python
Expand Down
27 changes: 24 additions & 3 deletions hyperfast/hyperfast.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ class HyperFastClassifier(BaseEstimator, ClassifierMixin):
torch_pca (bool): Whether to use PyTorch-based PCA optimized for GPU (fast) or scikit-learn PCA (slower).
seed (int): Random seed for reproducibility.
custom_path (str or None): If str, this custom path will be used to load the Hyperfast model instead of the default path.
stratify_sampling (bool): Determines whether to use stratified sampling for creating the batch.
stratify_sampling (bool): Determines whether to use stratified sampling for creating the batch.
feature_bagging (bool): Indicates whether feature bagging should be performed when ensembling.
feature_bagging_size (int): Size of the feature subset when performing feature bagging.
cat_features (list or None): List of indices of categorical features.
"""

Expand All @@ -62,6 +64,8 @@ def __init__(
seed: int = 3,
custom_path: str | None = None,
stratify_sampling: bool = False,
feature_bagging: bool = False,
feature_bagging_size: int = 3000,
cat_features: List[int] | None = None,
) -> None:
self.device = device
Expand All @@ -74,6 +78,8 @@ def __init__(
self.seed = seed
self.custom_path = custom_path
self.stratify_sampling = stratify_sampling
self.feature_bagging = feature_bagging
self.feature_bagging_size = feature_bagging_size
self.cat_features = cat_features

seed_everything(self.seed)
Expand Down Expand Up @@ -242,8 +248,17 @@ def _initialize_fit_attributes(self) -> None:
self._main_networks = []
self._X_preds = []
self._y_preds = []
if self.feature_bagging:
self.selected_features = []

def _sample_data(self, X: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]:
if self.feature_bagging:
print("Performing feature bagging")
stds = torch.std(X, dim=0)
feature_idxs = torch.multinomial(stds, self.feature_bagging_size, replacement=False)
self.selected_features.append(feature_idxs)
X = X[:, feature_idxs]

if self.stratify_sampling:
# Stratified sampling
print("Using stratified sampling")
Expand Down Expand Up @@ -356,9 +371,15 @@ def predict_proba(self, X: np.ndarray | pd.DataFrame) -> np.ndarray:
pca = self._pcas[jj]
X_pred = self._X_preds[jj]
y_pred = self._y_preds[jj]
if self.feature_bagging:
X_ = X[:, self.selected_features[jj]]
orig_X_ = orig_X[:, self.selected_features[jj]]
else:
X_ = X
orig_X_ = orig_X

X_transformed = transform_data_for_main_network(
X=X, cfg=self._cfg, rf=rf, pca=pca
X=X_, cfg=self._cfg, rf=rf, pca=pca
)
outputs, intermediate_activations = forward_main_network(
X_transformed, main_network
Expand All @@ -374,7 +395,7 @@ def predict_proba(self, X: np.ndarray | pd.DataFrame) -> np.ndarray:
for bb, bias in enumerate(self._model.nn_bias):
if bb == 0:
outputs = nn_bias_logits(
outputs, orig_X, X_pred, y_pred, bias, self.n_classes_
outputs, orig_X_, X_pred, y_pred, bias, self.n_classes_
)
elif bb == 1:
outputs = nn_bias_logits(
Expand Down

0 comments on commit c7c8236

Please sign in to comment.