
Commit 334e625

feat: add conformal quantile prediction (#25)
1 parent: eabe5a8

6 files changed: +820 −201 lines changed

README.md (+109 −75)
````diff
@@ -5,19 +5,22 @@
 Neo LS-SVM is a modern [Least-Squares Support Vector Machine](https://en.wikipedia.org/wiki/Least-squares_support_vector_machine) implementation in Python that offers several benefits over sklearn's classic `sklearn.svm.SVC` classifier and `sklearn.svm.SVR` regressor:
 
 1. ⚡ Linear complexity in the number of training examples with [Orthogonal Random Features](https://arxiv.org/abs/1610.09072).
-2. 🚀 Hyperparameter free: zero-cost optimization of the regularisation parameter γ and kernel parameter σ.
+2. 🚀 Hyperparameter free: zero-cost optimization of the [regularisation parameter γ](https://en.wikipedia.org/wiki/Ridge_regression#Tikhonov_regularization) and [kernel parameter σ](https://en.wikipedia.org/wiki/Radial_basis_function_kernel).
 3. 🏔️ Adds a new tertiary objective that minimizes the complexity of the prediction surface.
 4. 🎁 Returns the leave-one-out residuals and error for free after fitting.
 5. 🌀 Learns an affine transformation of the feature matrix to optimally separate the target's bins.
 6. 🪞 Can solve the LS-SVM both in the primal and dual space.
-7. 🌡️ Isotonically calibrated `predict_proba` based on the leave-one-out predictions.
-8. 🎲 Asymmetric conformal Bayesian confidence intervals for classification and regression.
+7. 🌡️ Isotonically calibrated `predict_proba`.
+8. ✅ Conformally calibrated `predict_quantiles` and `predict_interval`.
+9. 🔔 Bayesian estimation of the predictive standard deviation with `predict_std`.
+10. 🐼 Pandas DataFrame output when the input is a pandas DataFrame.
 
 ## Using
 
 ### Installing
 
 First, install this package with:
+
 ```bash
 pip install neo-ls-svm
 ```
````
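The hunks below call methods on a `model` fitted earlier in the README. For context, a minimal sketch of that setup, pieced together from the data-loading code visible as removed lines in the next hunk and from the hunk's own context line; the final `predict_std` call is an assumption (the method is named in the feature list above, and its signature is presumed to mirror `predict`):

```python
from neo_ls_svm import NeoLSSVM
from pandas import get_dummies
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split

# Load the Ames housing regression problem and split it into train and test sets.
X, y = fetch_openml("ames_housing", version=1, return_X_y=True, as_frame=True, parser="auto")
X_train, X_test, y_train, y_test = train_test_split(get_dummies(X), y, test_size=50, random_state=42)

# Fit a Neo LS-SVM model and compute its R² score on the test set.
model = NeoLSSVM().fit(X_train, y_train)
model.score(X_test, y_test)

# Assumed signature: the feature list advertises a Bayesian predictive standard deviation.
ŷ_test_std = model.predict_std(X_test)
```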
````diff
@@ -45,53 +48,61 @@ model = NeoLSSVM().fit(X_train, y_train)
 model.score(X_test, y_test) # 82.4% (compared to sklearn.svm.SVR's -11.8%)
 ```
 
-### Confidence intervals
+### Predicting quantiles
 
-Neo LS-SVM implements conformal prediction with a Bayesian nonconformity estimate to compute confidence intervals for both classification and regression. Example usage:
+Neo LS-SVM implements conformal prediction with a Bayesian nonconformity estimate to compute quantiles and prediction intervals for both classification and regression. Example usage:
 
 ```python
-from neo_ls_svm import NeoLSSVM
-from pandas import get_dummies
-from sklearn.datasets import fetch_openml
-from sklearn.model_selection import train_test_split
-
-# Load a regression problem and split in train and test.
-X, y = fetch_openml("ames_housing", version=1, return_X_y=True, as_frame=True, parser="auto")
-X_train, X_test, y_train, y_test = train_test_split(get_dummies(X), y, test_size=50, random_state=42)
+# Predict the house prices and their quantiles.
+ŷ_test = model.predict(X_test)
+ŷ_test_quantiles = model.predict_quantiles(X_test, quantiles=(0.025, 0.05, 0.1, 0.9, 0.95, 0.975))
+```
 
-# Fit a Neo LS-SVM model.
-model = NeoLSSVM().fit(X_train, y_train)
+When the input data is a pandas DataFrame, the output is also a pandas DataFrame. For example, printing the head of `ŷ_test_quantiles` yields:
 
-# Predict the house prices and confidence intervals on the test set.
-ŷ = model.predict(X_test)
-ŷ_conf = model.predict_proba(X_test, confidence_interval=True, confidence_level=0.95)
-# ŷ_conf[:, 0] and ŷ_conf[:, 1] are the lower and upper bound of the confidence interval for the predictions ŷ, respectively
-```
+| house_id | 0.025 | 0.05 | 0.1 | 0.9 | 0.95 | 0.975 |
+|-----------:|---------:|---------:|---------:|---------:|---------:|---------:|
+| 1357 | 114283.0 | 124767.6 | 133314.0 | 203162.0 | 220407.5 | 245655.3 |
+| 2367 | 85518.3 | 91787.2 | 93709.8 | 107464.3 | 108472.6 | 114482.3 |
+| 2822 | 147165.9 | 157462.8 | 167193.1 | 243646.5 | 263324.4 | 291963.3 |
+| 2126 | 81788.7 | 88738.1 | 91367.4 | 111944.9 | 114800.7 | 122874.5 |
+| 1544 | 94507.1 | 108288.2 | 120184.3 | 222630.5 | 248668.2 | 283703.4 |
 
-Let's visualize the confidence intervals on the test set:
+Let's visualize the predicted quantiles on the test set:
 
-<img src="https://github.com/lsorber/neo-ls-svm/assets/4543654/472bf358-34d7-4a1a-8b5c-595fe65dbf77" width="512">
+<img src="https://github.com/lsorber/neo-ls-svm/assets/4543654/cd24e739-e857-4045-8a70-07e92367a901" width="512">
 
 <details>
-<summary>Expand to see the code that generated the above graph.</summary>
+<summary>Expand to see the code that generated the graph above</summary>
 
 ```python
 import matplotlib.pyplot as plt
 import matplotlib.ticker as ticker
-import numpy as np
 
-idx = np.argsort(-ŷ)
-y_ticks = np.arange(1, len(X_test) + 1)
+%config InlineBackend.figure_format = "retina"
+plt.rcParams["font.size"] = 8
+idx = (-ŷ_test.sample(50, random_state=42)).sort_values().index
+y_ticks = list(range(1, len(idx) + 1))
 plt.figure(figsize=(4, 5))
-plt.barh(y_ticks, ŷ_conf[idx, 1] - ŷ_conf[idx, 0], left=ŷ_conf[idx, 0], label="95% Confidence interval", color="lightblue")
-plt.plot(y_test.iloc[idx], y_ticks, "s", markersize=3, markerfacecolor="none", markeredgecolor="cornflowerblue", label="Actual value")
-plt.plot(ŷ[idx], y_ticks, "s", color="mediumblue", markersize=0.6, label="Predicted value")
+for j in range(3):
+    end = ŷ_test_quantiles.shape[1] - 1 - j
+    coverage = round(100 * (ŷ_test_quantiles.columns[end] - ŷ_test_quantiles.columns[j]))
+    plt.barh(
+        y_ticks,
+        ŷ_test_quantiles.loc[idx].iloc[:, end] - ŷ_test_quantiles.loc[idx].iloc[:, j],
+        left=ŷ_test_quantiles.loc[idx].iloc[:, j],
+        label=f"{coverage}% Prediction interval",
+        color=["#b3d9ff", "#86bfff", "#4da6ff"][j],
+    )
+plt.plot(y_test.loc[idx], y_ticks, "s", markersize=3, markerfacecolor="none", markeredgecolor="#e74c3c", label="Actual value")
+plt.plot(ŷ_test.loc[idx], y_ticks, "s", color="blue", markersize=0.6, label="Predicted value")
 plt.xlabel("House price")
 plt.ylabel("Test house index")
+plt.xlim(0, 500e3)
 plt.yticks(y_ticks, y_ticks)
 plt.tick_params(axis="y", labelsize=6)
 plt.grid(axis="x", color="lightsteelblue", linestyle=":", linewidth=0.5)
-plt.gca().xaxis.set_major_formatter(ticker.StrMethodFormatter('${x:,.0f}'))
+plt.gca().xaxis.set_major_formatter(ticker.StrMethodFormatter("${x:,.0f}"))
 plt.gca().spines["top"].set_visible(False)
 plt.gca().spines["right"].set_visible(False)
 plt.legend()
````
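A quick sanity check one might run on the conformally calibrated quantiles introduced above: the share of actual prices at or below each predicted quantile should track the nominal level. A sketch, assuming `y_test` and `ŷ_test_quantiles` as in the hunk above, with the quantile levels as numeric column labels (which the plotting code's arithmetic on `ŷ_test_quantiles.columns` implies):

```python
# Empirical coverage of each predicted quantile on the test set.
for q in ŷ_test_quantiles.columns:
    empirical = (y_test <= ŷ_test_quantiles[q]).mean()
    print(f"quantile {q}: {empirical:.1%} of actual prices fall at or below it")
```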
````diff
@@ -100,6 +111,29 @@ plt.show()
 ```
 </details>
 
+### Predicting intervals
+
+In addition to quantile prediction, you can use `predict_interval` to predict conformally calibrated prediction intervals. Compared to quantiles, these focus on reliable coverage over quantile accuracy. Example usage:
+
+```python
+# Compute prediction intervals for the houses in the test set.
+ŷ_test_interval = model.predict_interval(X_test, coverage=0.95)
+
+# Measure the coverage of the prediction intervals on the test set
+coverage = ((ŷ_test_interval.iloc[:, 0] <= y_test) & (y_test <= ŷ_test_interval.iloc[:, 1])).mean()
+print(coverage) # 94.3%
+```
+
+When the input data is a pandas DataFrame, the output is also a pandas DataFrame. For example, printing the head of `ŷ_test_interval` yields:
+
+| house_id | 0.025 | 0.975 |
+|-----------:|---------:|---------:|
+| 1357 | 114283.0 | 245849.2 |
+| 2367 | 85518.3 | 114411.4 |
+| 2822 | 147165.9 | 292179.2 |
+| 2126 | 81788.7 | 122838.1 |
+| 1544 | 94507.1 | 284062.6 |
+
 ## Benchmarks
 
 We select all binary classification and regression datasets below 1M entries from the [AutoML Benchmark](https://arxiv.org/abs/2207.12560). Each dataset is split into 85% for training and 15% for testing. We apply `skrub.TableVectorizer` as a preprocessing step for `neo_ls_svm.NeoLSSVM` and `sklearn.svm.SVC,SVR` to vectorize the pandas DataFrame training data into a NumPy array. Models are fitted only once on each dataset, with their default settings and no hyperparameter tuning.
````
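Note that the interval bounds in the hunk above nearly, but not exactly, match the 0.025 and 0.975 columns of the quantile table (e.g. house 1357's upper bound is 245849.2 versus a 0.975 quantile of 245655.3), consistent with `predict_interval` being calibrated for coverage rather than quantile accuracy. A sketch of that side-by-side comparison, assuming both outputs from the examples above:

```python
# Compare the 95% prediction interval with the matching quantile pair.
# The two are calibrated differently, so small discrepancies are expected.
ŷ_test_quantile_pair = model.predict_quantiles(X_test, quantiles=(0.025, 0.975))
print((ŷ_test_interval.values - ŷ_test_quantile_pair.values).round(1))
```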
````diff
@@ -109,29 +143,29 @@ We select all binary classification and regression datasets below 1M entries fro
 
 ROC-AUC on 15% test set:
 
-| dataset | LGBMClassifier | NeoLSSVM | SVC |
-|---------------------------------:|-----------------:|----------------:|----------------:|
-| ada | 🥈 90.9% (0.1s) | 🥇 90.9% (0.8s) | 83.1% (1.0s) |
-| adult | 🥇 93.0% (0.5s) | 🥈 89.1% (6.0s) | / |
-| amazon_employee_access | 🥇 85.6% (0.5s) | 🥈 64.5% (2.8s) | / |
-| arcene | 🥈 78.0% (0.6s) | 70.0% (4.4s) | 🥇 82.0% (3.4s) |
-| australian | 🥇 88.3% (0.2s) | 79.9% (0.4s) | 🥈 81.9% (0.0s) |
-| bank-marketing | 🥇 93.5% (0.3s) | 🥈 91.0% (4.1s) | / |
-| blood-transfusion-service-center | 62.0% (0.1s) | 🥇 71.0% (0.5s) | 🥈 69.7% (0.0s) |
-| churn | 🥇 91.7% (0.4s) | 🥈 81.0% (0.8s) | 70.6% (0.8s) |
-| click_prediction_small | 🥇 67.7% (0.4s) | 🥈 66.6% (3.3s) | / |
-| jasmine | 🥇 86.1% (0.3s) | 79.5% (1.2s) | 🥈 85.3% (1.8s) |
-| kc1 | 🥇 78.9% (0.2s) | 🥈 76.6% (0.5s) | 45.7% (0.2s) |
-| kr-vs-kp | 🥇 100.0% (0.2s) | 99.2% (0.8s) | 🥈 99.4% (0.6s) |
-| madeline | 🥇 93.1% (0.4s) | 65.6% (0.8s) | 🥈 82.5% (4.5s) |
-| ozone-level-8hr | 🥈 91.2% (0.3s) | 🥇 91.6% (0.7s) | 72.8% (0.2s) |
-| pc4 | 🥇 95.3% (0.3s) | 🥈 90.9% (0.5s) | 25.7% (0.1s) |
-| phishingwebsites | 🥇 99.5% (0.3s) | 🥈 98.9% (1.3s) | 98.7% (2.6s) |
-| phoneme | 🥇 95.6% (0.2s) | 🥈 93.5% (0.8s) | 91.2% (0.7s) |
-| qsar-biodeg | 🥇 92.7% (0.2s) | 🥈 91.1% (1.2s) | 86.8% (0.1s) |
-| satellite | 🥈 98.7% (0.2s) | 🥇 99.5% (0.8s) | 98.5% (0.1s) |
-| sylvine | 🥇 98.5% (0.2s) | 🥈 97.1% (0.8s) | 96.5% (1.0s) |
-| wilt | 🥈 99.5% (0.2s) | 🥇 99.8% (0.9s) | 98.9% (0.2s) |
+| dataset | LGBMClassifier | NeoLSSVM | SVC |
+|---------------------------------:|-----------------:|-----------------:|-----------------:|
+| ada | 🥈 90.9% (0.1s) | 🥇 90.9% (1.9s) | 83.1% (4.5s) |
+| adult | 🥇 93.0% (0.5s) | 🥈 89.0% (15.7s) | / |
+| amazon_employee_access | 🥇 85.6% (0.5s) | 🥈 64.5% (9.0s) | / |
+| arcene | 🥈 78.0% (0.6s) | 70.0% (6.3s) | 🥇 82.0% (4.0s) |
+| australian | 🥇 88.3% (0.2s) | 79.9% (1.7s) | 🥈 81.9% (0.1s) |
+| bank-marketing | 🥇 93.5% (0.5s) | 🥈 91.0% (11.8s) | / |
+| blood-transfusion-service-center | 62.0% (0.3s) | 🥇 71.0% (2.2s) | 🥈 69.7% (0.1s) |
+| churn | 🥇 91.7% (0.6s) | 🥈 81.0% (2.1s) | 70.6% (2.9s) |
+| click_prediction_small | 🥇 67.7% (0.5s) | 🥈 66.6% (10.9s) | / |
+| jasmine | 🥇 86.1% (0.3s) | 79.5% (1.9s) | 🥈 85.3% (7.4s) |
+| kc1 | 🥇 78.9% (0.3s) | 🥈 76.6% (1.4s) | 45.7% (0.6s) |
+| kr-vs-kp | 🥇 100.0% (0.6s) | 99.2% (1.6s) | 🥈 99.4% (2.3s) |
+| madeline | 🥇 93.1% (0.5s) | 65.6% (1.9s) | 🥈 82.5% (19.8s) |
+| ozone-level-8hr | 🥈 91.2% (0.4s) | 🥇 91.6% (1.7s) | 72.9% (0.6s) |
+| pc4 | 🥇 95.3% (0.3s) | 🥈 90.9% (1.5s) | 25.7% (0.3s) |
+| phishingwebsites | 🥇 99.5% (0.5s) | 🥈 98.9% (3.6s) | 98.7% (10.0s) |
+| phoneme | 🥇 95.6% (0.3s) | 🥈 93.5% (2.1s) | 91.2% (2.0s) |
+| qsar-biodeg | 🥇 92.7% (0.4s) | 🥈 91.1% (5.2s) | 86.8% (0.3s) |
+| satellite | 🥈 98.7% (0.2s) | 🥇 99.5% (1.9s) | 98.5% (0.4s) |
+| sylvine | 🥇 98.5% (0.2s) | 🥈 97.1% (2.0s) | 96.5% (3.8s) |
+| wilt | 🥈 99.5% (0.2s) | 🥇 99.8% (1.8s) | 98.9% (0.5s) |
 
 </details>
 
````
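The preprocessing described in the benchmark setup can be sketched as a scikit-learn pipeline. This assumes `NeoLSSVM` is pipeline-compatible (its `fit`/`score`/`predict_proba` usage elsewhere in the README suggests so); the actual benchmark harness is not part of this diff:

```python
from sklearn.metrics import roc_auc_score
from sklearn.pipeline import make_pipeline
from skrub import TableVectorizer

from neo_ls_svm import NeoLSSVM

# Vectorize the pandas DataFrame into a numeric array, then fit NeoLSSVM with its defaults.
model = make_pipeline(TableVectorizer(), NeoLSSVM())
model.fit(X_train, y_train)

# ROC-AUC as reported in the classification table (assumes a binary target).
print(roc_auc_score(y_test, model.predict_proba(X_test)[:, 1]))
```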
````diff
@@ -140,28 +174,28 @@ ROC-AUC on 15% test set:
 
 R² on 15% test set:
 
-| dataset | LGBMRegressor | NeoLSSVM | SVR |
-|------------------------------:|----------------:|----------------:|-----------------:|
-| abalone | 🥈 56.2% (0.1s) | 🥇 59.5% (1.1s) | 51.3% (0.2s) |
-| boston | 🥇 91.7% (0.2s) | 🥈 89.3% (0.4s) | 35.1% (0.0s) |
-| brazilian_houses | 🥈 55.9% (0.4s) | 🥇 88.3% (1.5s) | 5.4% (2.0s) |
-| colleges | 🥇 58.5% (0.4s) | 🥈 43.7% (4.1s) | 40.2% (5.1s) |
-| diamonds | 🥇 98.2% (0.7s) | 🥈 95.2% (4.5s) | / |
-| elevators | 🥇 87.7% (0.4s) | 🥈 82.6% (2.6s) | / |
-| house_16h | 🥇 67.7% (0.3s) | 🥈 52.8% (2.4s) | / |
-| house_prices_nominal | 🥇 89.0% (0.6s) | 🥈 78.2% (1.3s) | -2.9% (0.3s) |
-| house_sales | 🥇 89.2% (1.3s) | 🥈 77.8% (2.2s) | / |
-| mip-2016-regression | 🥇 59.2% (0.4s) | 🥈 34.9% (2.6s) | -27.3% (0.1s) |
-| moneyball | 🥇 93.2% (0.2s) | 🥈 91.2% (0.6s) | 0.8% (0.1s) |
-| pol | 🥇 98.7% (0.3s) | 🥈 75.2% (1.7s) | / |
-| quake | -10.7% (0.2s) | 🥇 -0.1% (0.5s) | 🥈 -10.7% (0.0s) |
-| sat11-hand-runtime-regression | 🥇 78.3% (0.5s) | 🥈 61.7% (1.0s) | -56.3% (1.0s) |
-| sensory | 🥇 29.2% (0.2s) | 3.8% (0.4s) | 🥈 16.4% (0.0s) |
-| socmob | 🥇 79.6% (0.2s) | 🥈 72.5% (1.5s) | 30.8% (0.0s) |
-| space_ga | 🥇 70.3% (0.2s) | 🥈 43.7% (0.6s) | 35.9% (0.1s) |
-| tecator | 🥈 98.3% (0.1s) | 🥇 99.4% (0.2s) | 78.5% (0.0s) |
-| us_crime | 🥈 62.8% (0.4s) | 🥇 63.0% (0.8s) | 6.7% (0.2s) |
-| wine_quality | 🥇 45.6% (0.6s) | -8.0% (0.9s) | 🥈 16.4% (0.5s) |
+| dataset | LGBMRegressor | NeoLSSVM | SVR |
+|------------------------------:|----------------:|-----------------:|-----------------:|
+| abalone | 🥈 56.2% (0.1s) | 🥇 59.5% (2.5s) | 51.3% (0.7s) |
+| boston | 🥇 91.7% (0.2s) | 🥈 89.6% (1.1s) | 35.1% (0.0s) |
+| brazilian_houses | 🥈 55.9% (0.3s) | 🥇 88.4% (3.7s) | 5.4% (7.0s) |
+| colleges | 🥇 58.5% (0.4s) | 🥈 42.2% (6.6s) | 40.2% (15.1s) |
+| diamonds | 🥇 98.2% (0.3s) | 🥈 95.2% (13.7s) | / |
+| elevators | 🥇 87.7% (0.5s) | 🥈 82.6% (6.5s) | / |
+| house_16h | 🥇 67.7% (0.4s) | 🥈 52.8% (6.0s) | / |
+| house_prices_nominal | 🥇 89.0% (0.3s) | 🥈 78.3% (2.1s) | -2.9% (1.2s) |
+| house_sales | 🥇 89.2% (0.4s) | 🥈 77.8% (5.9s) | / |
+| mip-2016-regression | 🥇 59.2% (0.4s) | 🥈 34.9% (5.8s) | -27.3% (0.4s) |
+| moneyball | 🥇 93.2% (0.3s) | 🥈 91.3% (1.1s) | 0.8% (0.2s) |
+| pol | 🥇 98.7% (0.3s) | 🥈 74.9% (4.6s) | / |
+| quake | -10.7% (0.2s) | 🥇 -1.0% (1.6s) | 🥈 -10.7% (0.1s) |
+| sat11-hand-runtime-regression | 🥇 78.3% (0.4s) | 🥈 61.7% (2.1s) | -56.3% (5.1s) |
+| sensory | 🥇 29.2% (0.1s) | 3.0% (1.6s) | 🥈 16.4% (0.0s) |
+| socmob | 🥇 79.6% (0.2s) | 🥈 72.5% (6.6s) | 30.8% (0.1s) |
+| space_ga | 🥇 70.3% (0.3s) | 🥈 43.6% (1.5s) | 35.9% (0.2s) |
+| tecator | 🥈 98.3% (0.1s) | 🥇 99.4% (0.9s) | 78.5% (0.0s) |
+| us_crime | 🥈 62.8% (0.6s) | 🥇 63.0% (2.3s) | 6.7% (0.8s) |
+| wine_quality | 🥇 45.6% (0.2s) | 🥈 36.5% (2.8s) | 16.4% (1.6s) |
 
 </details>
 
````
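For the regression benchmarks, the reported metric is R², which is also what `score` returns for scikit-learn-style regressors. Equivalently, per model and dataset (a sketch, assuming a fitted regressor and the 15% test split described above):

```python
from sklearn.metrics import r2_score

# R² on the held-out 15% test set, the metric shown in the table above.
print(r2_score(y_test, model.predict(X_test)))
```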