Skip to content

Commit 5776e7e

Browse files
authored
Merge pull request #269 from juaml/add/target_generate
[ENH] Generate the target from features
2 parents d36bc82 + 56a3dc6 commit 5776e7e

File tree

18 files changed

+1343
-132
lines changed

18 files changed

+1343
-132
lines changed
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
"""
2+
Target Generation
3+
=================
4+
5+
This example uses the ``iris`` dataset and tests a regression model in which
6+
the target variable is generated from some features within the cross-validation
7+
procedure. We will use the Iris dataset and generate a target variable using
8+
PCA on the petal features. Then, we will evaluate if a regression model can
9+
predict the generated target from the sepal features
10+
11+
.. include:: ../../links.inc
12+
"""
13+
# Authors: Federico Raimondo <[email protected]>
14+
# License: AGPL
15+
16+
from seaborn import load_dataset
17+
from julearn import run_cross_validation
18+
from julearn.pipeline import PipelineCreator
19+
from julearn.utils import configure_logging
20+
21+
###############################################################################
22+
# Set the logging level to info to see extra information.
23+
configure_logging(level="DEBUG")
24+
25+
###############################################################################
26+
df_iris = load_dataset("iris")
27+
28+
29+
###############################################################################
30+
# As features, we will use the sepal length, width and petal length.
31+
# We will try to predict the species.
32+
33+
X = ["sepal_length", "sepal_width", "petal_length", "petal_width"]
34+
y = "__generated__" # to indicate to julearn that the target will be generated
35+
36+
37+
# Define our feature types
38+
X_types = {
39+
"sepal": ["sepal_length", "sepal_width"],
40+
"petal": ["petal_length", "petal_width"],
41+
}
42+
43+
###############################################################################
44+
# We now use a Pipeline Creator to create the pipeline that will generate the
45+
# features. This special pipeline should be configured to be a "transformer"
46+
# and apply to the "petal" feature types.
47+
48+
target_creator = PipelineCreator(problem_type="transformer", apply_to="petal")
49+
target_creator.add("pca", n_components=2)
50+
# Select only the first component
51+
target_creator.add("pick_columns", keep="pca__pca0")
52+
53+
54+
###############################################################################
55+
# We now create the pipeline that will be used to predict the target. This
56+
# pipeline will be a regression pipeline. The step previous to the model should
57+
# be the the `generate_target`, applying to the "petal" features and using the
58+
# target_creator pipeline as the transformer.
59+
creator = PipelineCreator(problem_type="regression")
60+
creator.add("zscore", apply_to="*")
61+
creator.add("generate_target", apply_to="petal", transformer=target_creator)
62+
creator.add("linreg", apply_to="sepal")
63+
64+
###############################################################################
65+
# We finally evaluate the model within the cross validation.
66+
scores, model = run_cross_validation(
67+
X=X,
68+
y=y,
69+
X_types=X_types,
70+
data=df_iris,
71+
model=creator,
72+
return_estimator="final",
73+
cv=2,
74+
)
75+
76+
print(scores["test_score"]) # type: ignore

julearn/api.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,9 @@ def _validata_api_params( # noqa: C901
209209

210210
wrap_score = False
211211
if isinstance(model, (PipelineCreator, list)):
212+
logger.debug(
213+
"Generating pipeline from PipelineCreator or list of them"
214+
)
212215
if preprocess is not None:
213216
raise_error(
214217
"If model is a PipelineCreator (or list of), "
@@ -242,6 +245,7 @@ def _validata_api_params( # noqa: C901
242245
expanded_models.extend(m.split())
243246

244247
has_target_transformer = expanded_models[-1]._added_target_transformer
248+
has_target_generator = expanded_models[-1]._added_target_generator
245249
all_pipelines = [
246250
model.to_pipeline(X_types=X_types, search_params=search_params)
247251
for model in expanded_models
@@ -255,12 +259,16 @@ def _validata_api_params( # noqa: C901
255259
pipeline = all_pipelines[0]
256260

257261
if has_target_transformer:
262+
logger.debug("Pipeline has target transformer")
258263
if isinstance(pipeline, BaseSearchCV):
259264
last_step = pipeline.estimator[-1] # type: ignore
260265
else:
261266
last_step = pipeline[-1]
262267
if not last_step.can_inverse_transform():
263268
wrap_score = True
269+
if has_target_generator:
270+
logger.debug("Pipeline has target generator")
271+
wrap_score = True
264272
problem_type = model[0].problem_type
265273

266274
elif not isinstance(model, (str, BaseEstimator)):
@@ -317,12 +325,15 @@ def _validata_api_params( # noqa: C901
317325
f"The following model_params are incorrect: {unused_params}"
318326
)
319327
has_target_transformer = pipeline_creator._added_target_transformer
328+
has_target_generator = pipeline_creator._added_target_generator
320329
pipeline = pipeline_creator.to_pipeline(
321330
X_types=X_types, search_params=search_params
322331
)
323332

324333
if has_target_transformer and not pipeline[-1].can_inverse_transform():
325334
wrap_score = True
335+
if has_target_generator:
336+
wrap_score = True
326337

327338
# Log some information
328339
logger.info("= Data Information =")

julearn/base/column_types.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,15 @@
44
# Sami Hamdan <[email protected]>
55
# License: AGPL
66

7-
from typing import Callable, Union
7+
from typing import Any, Callable, Union
88

99
from sklearn.compose import make_column_selector
1010

1111
from ..utils.logging import raise_error
1212

1313

1414
ColumnTypesLike = Union[list[str], set[str], str, "ColumnTypes"]
15+
ColumnTypesDict = dict[str, ColumnTypesLike]
1516

1617

1718
def change_column_type(column: str, new_type: str):
@@ -240,6 +241,42 @@ def __eq__(self, other: Union["ColumnTypes", str]):
240241
other = other if isinstance(other, ColumnTypes) else ColumnTypes(other)
241242
return self._column_types == other._column_types
242243

244+
def __and__(self, other: "ColumnTypes"):
245+
"""Get the intersection of the column_types.
246+
247+
Parameters
248+
----------
249+
other : ColumnTypes
250+
The other column_types to get the intersection with.
251+
252+
Returns
253+
-------
254+
ColumnTypes
255+
The intersection of the column_types.
256+
257+
"""
258+
return ColumnTypes(self._column_types & other._column_types)
259+
260+
def __or__(self, other: "ColumnTypes"):
261+
"""Get the union of the column_types.
262+
263+
Parameters
264+
----------
265+
other : ColumnTypes
266+
The other column_types to get the union with.
267+
268+
Returns
269+
-------
270+
ColumnTypes
271+
The union of the column_types.
272+
273+
"""
274+
return ColumnTypes(self._column_types | other._column_types)
275+
276+
def __len__(self):
277+
"""Get the number of column_types."""
278+
return len(self._column_types)
279+
243280
def __iter__(self):
244281
"""Iterate over the column_types."""
245282

@@ -251,6 +288,22 @@ def __repr__(self):
251288
f"ColumnTypes<types={self._column_types}; pattern={self.pattern}>"
252289
)
253290

291+
def filter(self, X_types: dict[str, Any]) -> dict[str, Any]: # noqa: N803
292+
"""Filter the X_types based on the column_types.
293+
294+
Parameters
295+
----------
296+
X_types : dict
297+
The types of the columns.
298+
299+
Returns
300+
-------
301+
dict:
302+
The filtered X_types.
303+
304+
"""
305+
return {k: v for k, v in X_types.items() if k in self._column_types}
306+
254307
def copy(self) -> "ColumnTypes":
255308
"""Get a copy of the ColumnTypes.
256309

julearn/base/tests/test_column_types.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,3 +253,71 @@ def test_ColumnTypes_add(
253253
"""
254254
summed = ColumnTypes(left).add(right)
255255
assert summed == ColumnTypes(result)
256+
257+
258+
@pytest.mark.parametrize(
259+
"left,right,result",
260+
[
261+
(
262+
["continuous"],
263+
["continuous"],
264+
["continuous"],
265+
),
266+
(
267+
["cont", "cat"],
268+
"cat",
269+
["cat"],
270+
),
271+
],
272+
)
273+
def test_ColumnTypes_and(
274+
left: ColumnTypesLike, right: ColumnTypesLike, result: ColumnTypesLike
275+
) -> None:
276+
"""Test the ColumnTypes addition.
277+
278+
Parameters
279+
----------
280+
left : ColumnTypesLike
281+
The left hand side of the addition.
282+
right : ColumnTypesLike
283+
The right hand side of the addition.
284+
result : ColumnTypes
285+
The expected result.
286+
287+
"""
288+
anded = ColumnTypes(left) & ColumnTypes(right)
289+
assert anded == ColumnTypes(result)
290+
291+
292+
@pytest.mark.parametrize(
293+
"left,right,result",
294+
[
295+
(
296+
["continuous"],
297+
["continuous"],
298+
["continuous"],
299+
),
300+
(
301+
["cont", "cat"],
302+
"cat",
303+
["cont", "cat"],
304+
),
305+
],
306+
)
307+
def test_ColumnTypes_or(
308+
left: ColumnTypesLike, right: ColumnTypesLike, result: ColumnTypesLike
309+
) -> None:
310+
"""Test the ColumnTypes addition.
311+
312+
Parameters
313+
----------
314+
left : ColumnTypesLike
315+
The left hand side of the addition.
316+
right : ColumnTypesLike
317+
The right hand side of the addition.
318+
result : ColumnTypes
319+
The expected result.
320+
321+
"""
322+
orred = ColumnTypes(left) | ColumnTypes(right)
323+
assert orred == ColumnTypes(result)

julearn/conftest.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,9 @@ def model(request: FixtureRequest) -> str:
243243
return request.param
244244

245245

246-
@fixture(params=["regression", "classification"], scope="function")
246+
@fixture(
247+
params=["regression", "classification", "transformer"], scope="function"
248+
)
247249
def problem_type(request: FixtureRequest) -> str:
248250
"""Return different problem types.
249251

0 commit comments

Comments
 (0)