Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions src/gfdl/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,12 @@ class GFDLClassifier(ClassifierMixin, GFDL):
conditioning during Ridge regression. When set to zero or `None`,
model uses direct solve using Moore-Penrose Pseudo-Inverse.

rtol : float, default=None
Cutoff for small singular values for the Moore-Penrose
pseudo-inverse. Only applies when ``reg_alpha=None``.
When ``rtol=None``, the array API standard default for
``pinv`` is used.

Attributes
----------
n_features_in_ : int
Expand Down Expand Up @@ -253,14 +259,16 @@ def __init__(
weight_scheme: str = "uniform",
direct_links: bool = True,
seed: int = None,
reg_alpha: float = None
reg_alpha: float = None,
rtol: float = None
):
super().__init__(hidden_layer_sizes=hidden_layer_sizes,
activation=activation,
weight_scheme=weight_scheme,
direct_links=direct_links,
seed=seed,
reg_alpha=reg_alpha)
reg_alpha=reg_alpha,
rtol=rtol)

def fit(self, X, y):
"""
Expand Down
42 changes: 41 additions & 1 deletion src/gfdl/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import numpy as np
import pytest
from numpy.testing import assert_allclose
from sklearn.datasets import make_classification
from sklearn.datasets import load_digits, make_classification
from sklearn.metrics import accuracy_score, roc_auc_score
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.preprocessing import OneHotEncoder, StandardScaler
Expand Down Expand Up @@ -410,3 +410,43 @@ def test_classification_against_grafo(hidden_layer_sizes, n_classes, activation,
@parametrize_with_checks([GFDLClassifier(), EnsembleGFDLClassifier()])
def test_sklearn_api_conformance(estimator, check):
check(estimator)


@pytest.mark.parametrize("reg_alpha, rtol, expected_acc, expected_roc", [
(0.1, 1e-15, 0.9083333333333333, 0.9893414717354735),
(None, 1e-15, 0.2222222222222222, 0.5518850599798965),
(None, 1e-3, 0.8972222222222223, 0.9802912857599967),
])
def test_rtol_classifier(reg_alpha, rtol, expected_acc, expected_roc):
# For Moore-Penrose, a large singular value cutoff (rtol)
# may be required to achieve reasonable results. This test
# showcases that a default low cut off leads to almost random classification
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# showcases that a default low cut of leads to almost random classification
# showcases that a default low cut off leads to almost random classification

# output for the Digits dataset, which is alleviated by increasing the cut off.
# This cut off has no effect on ridge solver.
data = load_digits()
X, y = data.data, data.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2,
random_state=0)

scaler = StandardScaler().fit(X_train)
X_train_s = scaler.transform(X_train)
X_test_s = scaler.transform(X_test)

activation = "softmax"
weight_scheme = "normal"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might as well just specify the arguments directly in the call to the estimator below, since they're not used for anything else.

Since that's minor, I'll let you do that in a follow-up.

model = GFDLClassifier(hidden_layer_sizes=[800] * 10,
activation=activation,
weight_scheme=weight_scheme,
seed=0,
reg_alpha=reg_alpha,
rtol=rtol)
model.fit(X_train_s, y_train)

y_hat_cur = model.predict(X_test_s)
y_hat_cur_proba = model.predict_proba(X_test_s)

acc_cur = accuracy_score(y_test, y_hat_cur)
roc_cur = roc_auc_score(y_test, y_hat_cur_proba, multi_class="ovo")

np.testing.assert_allclose(acc_cur, expected_acc)
np.testing.assert_allclose(roc_cur, expected_roc)