Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] refactor _clone to a plugin structure #381

Merged
merged 34 commits into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
b470a75
start
fkiraly Nov 9, 2024
585e086
Update _clone_base.py
fkiraly Nov 9, 2024
4d4ecf5
finished
fkiraly Nov 9, 2024
bb74f63
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 9, 2024
0684f01
linting, private sklearn cloner
fkiraly Nov 9, 2024
9398271
Merge branch 'clone_plugins' of https://github.com/sktime/skbase into…
fkiraly Nov 9, 2024
1a2fcbd
safe import
fkiraly Nov 9, 2024
ef8c0ff
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 9, 2024
7d89862
linting
fkiraly Nov 9, 2024
d90aa44
Merge branch 'clone_plugins' of https://github.com/sktime/skbase into…
fkiraly Nov 9, 2024
c5527ca
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 9, 2024
1dfc0be
isolate imports
fkiraly Nov 9, 2024
6c4e974
Merge branch 'clone_plugins' of https://github.com/sktime/skbase into…
fkiraly Nov 9, 2024
22b0e64
fix
fkiraly Nov 9, 2024
91a876d
linting
fkiraly Nov 9, 2024
a8dd391
Revert "linting"
fkiraly Nov 9, 2024
35f686a
Reapply "linting"
fkiraly Nov 9, 2024
8efb6ed
Revert "Reapply "linting""
fkiraly Nov 9, 2024
ae08d1b
Reapply "Reapply "linting""
fkiraly Nov 9, 2024
a2c4820
Update _clone_plugins.py
fkiraly Nov 9, 2024
a218815
Update conftest.py
fkiraly Nov 9, 2024
cd5699a
conftest
fkiraly Nov 9, 2024
e167991
allow raise from clone
fkiraly Nov 9, 2024
337cd0c
Update _clone_base.py
fkiraly Nov 9, 2024
dd03809
Update conftest.py
fkiraly Nov 9, 2024
cf42711
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 9, 2024
01f3dba
Update conftest.py
fkiraly Nov 10, 2024
91a3aee
Merge branch 'clone_plugins' of https://github.com/sktime/skbase into…
fkiraly Nov 10, 2024
1c944d1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 10, 2024
c12bfc1
Update conftest.py
fkiraly Nov 10, 2024
88feb5d
Merge branch 'clone_plugins' of https://github.com/sktime/skbase into…
fkiraly Nov 10, 2024
6aa2221
Update conftest.py
fkiraly Nov 10, 2024
894c694
test for retaining sklearn config
fkiraly Nov 10, 2024
99b5a1f
fix test
fkiraly Nov 10, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 2 additions & 105 deletions skbase/base/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ class name: BaseEstimator
from typing import List

from skbase._exceptions import NotFittedError
from skbase.base._clone_base import _check_clone, _clone
from skbase.base._pretty_printing._object_html_repr import _object_html_repr
from skbase.base._tagmanager import _FlagManager

Expand Down Expand Up @@ -175,7 +176,7 @@ def clone(self):
------
RuntimeError if the clone is non-conforming, due to faulty ``__init__``.
"""
self_clone = _clone(self)
self_clone = _clone(self, base_cls=BaseObject)
if self.get_config()["check_clone"]:
_check_clone(original=self, clone=self_clone)
return self_clone
Expand Down Expand Up @@ -1653,107 +1654,3 @@ def _get_fitted_params(self):
fitted parameters, keyed by names of fitted parameter
"""
return self._get_fitted_params_default()


# Adapted from sklearn's `_clone_parametrized()`
def _clone(estimator, *, safe=True):
    """Build a fresh, unfitted copy of ``estimator`` with identical parameters.

    Containers (dict, list, tuple, set, frozenset) are rebuilt with each of
    their elements cloned. Objects exposing ``get_params`` are re-constructed
    from their own class, using recursively cloned constructor parameters, so
    no attached data is carried over to the copy.

    Parameters
    ----------
    estimator : dict, list, tuple, set, frozenset, or single estimator instance
        The estimator, or container of estimators, to be cloned.
    safe : bool, default=True
        If False, objects that are not estimators (no ``get_params``, or
        classes rather than instances) are copied with ``deepcopy``
        instead of raising.

    Returns
    -------
    estimator : object
        The copy of the input, an estimator if the input is an estimator.

    Raises
    ------
    TypeError
        If ``safe`` is True and ``estimator`` is a class, or does not
        implement ``get_params``.
    RuntimeError
        If the constructor of the estimator does not store a parameter
        as the very same object it was passed.

    Notes
    -----
    If the estimator's `random_state` parameter is an integer (or if the
    estimator doesn't have a `random_state` parameter), an *exact clone* is
    returned. Otherwise a *statistical clone* is returned, which might give
    different results from the original. See :ref:`randomness`.
    """
    container_type = type(estimator)

    # containers: clone every element and rebuild the same container type
    if container_type is dict:
        return {key: _clone(val, safe=safe) for key, val in estimator.items()}
    if container_type in (list, tuple, set, frozenset):
        return container_type([_clone(item, safe=safe) for item in estimator])

    # anything that is not an estimator instance: deep copy or refuse
    lacks_get_params = not hasattr(estimator, "get_params")
    if lacks_get_params or isinstance(estimator, type):
        if not safe:
            return deepcopy(estimator)
        if isinstance(estimator, type):
            raise TypeError(
                "Cannot clone object. "
                + "You should provide an instance of "
                + "scikit-learn estimator instead of a class."
            )
        raise TypeError(
            "Cannot clone object '%s' (type %s): "
            "it does not seem to be a scikit-learn "
            "estimator as it does not implement a "
            "'get_params' method." % (repr(estimator), type(estimator))
        )

    # estimator instance: clone each constructor parameter in place,
    # parameters themselves need not be estimators, hence safe=False
    init_kwargs = estimator.get_params(deep=False)
    for key in init_kwargs:
        init_kwargs[key] = _clone(init_kwargs[key], safe=False)

    duplicate = estimator.__class__(**init_kwargs)
    stored_params = duplicate.get_params(deep=False)

    # quick sanity check: the constructor must keep each argument as passed
    for key, passed_value in init_kwargs.items():
        if passed_value is not stored_params[key]:
            raise RuntimeError(
                "Cannot clone object %s, as the constructor "
                "either does not set or modifies parameter %s" % (estimator, key)
            )

    # This is an extension to the original sklearn implementation:
    # optionally carry the config of the original over to the copy
    if isinstance(estimator, BaseObject) and estimator.get_config()["clone_config"]:
        duplicate.set_config(**estimator.get_config())

    return duplicate


def _check_clone(original, clone):
    """Validate that ``clone`` carries the same parameters as ``original``.

    Parameters
    ----------
    original : object
        Object that was cloned; must implement ``get_params``.
    clone : object
        Result of cloning ``original``.

    Raises
    ------
    RuntimeError
        If a parameter of ``original`` is not an attribute of ``clone``,
        or if the attribute values of ``clone`` are not equal to the
        parameters of ``original`` according to ``deep_equals``.
    """
    from skbase.utils.deep_equals import deep_equals

    expected = original.get_params(deep=False)

    # every parameter must be present as an attribute on the clone;
    # report the first one that is absent
    absent = next((name for name in expected if not hasattr(clone, name)), None)
    if absent is not None:
        raise RuntimeError(
            f"error in {original}.clone, __init__ must write all arguments "
            f"to self and not mutate them, but {absent} was not found. "
            f"Please check __init__ of {original}."
        )

    # collect the attribute values actually stored on the clone
    found = {}
    for name in expected:
        found[name] = getattr(clone, name)

    # parameters pre-clone must equal attributes post-clone
    is_equal, msg = deep_equals(expected, found, return_msg=True)
    if is_equal:
        return

    raise RuntimeError(
        f"error in {original}.clone, __init__ must write all arguments "
        f"to self and not mutate them, but this is not the case. "
        f"Error on equality check of arguments (x) vs parameters (y): {msg}"
    )
129 changes: 129 additions & 0 deletions skbase/base/_clone_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
# -*- coding: utf-8 -*-
# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
# Elements of BaseObject reuse code developed in scikit-learn. These elements
# are copyrighted by the scikit-learn developers, BSD-3-Clause License. For
# conditions see https://github.com/scikit-learn/scikit-learn/blob/main/COPYING
"""Logic and plugins for cloning objects.

This module contains logic for cloning objects:

_clone(estimator, *, safe=True, clone_plugins=None, base_cls=None) - central entry point for cloning
_check_clone(original, clone) - validation utility to check clones

Default plugins for _clone are stored in _clone_plugins:

DEFAULT_CLONE_PLUGINS - list with default plugins for cloning

Each element of DEFAULT_CLONE_PLUGINS inherits from BaseCloner, with methods:

* check(obj) -> boolean - fast checker whether plugin applies
* clone(obj) -> type(obj) - method to clone obj
"""
__all__ = ["_clone", "_check_clone"]

from skbase.base._clone_plugins import DEFAULT_CLONE_PLUGINS

Check notice

Code scanning / CodeQL

Cyclic import Note

Import of module
skbase.base._clone_plugins
begins an import cycle.


# Adapted from sklearn's `_clone_parametrized()`
def _clone(estimator, *, safe=True, clone_plugins=None, base_cls=None):
    """Construct a new unfitted estimator with the same parameters.

    Clone does a deep copy of the model in an estimator
    without actually copying attached data. It returns a new estimator
    with the same parameters that has not been fitted on any data.

    The actual cloning is delegated to the first plugin, in order, whose
    ``check`` returns a truthy value for ``estimator``.

    Parameters
    ----------
    estimator : {list, tuple, set} of estimator instance, or a single
        estimator instance
        The estimator or group of estimators to be cloned.
    safe : bool, default=True
        If ``safe`` is False, clone will fall back to a deep copy on objects
        that are not estimators.
    clone_plugins : list of BaseCloner clone plugins, concrete descendant classes.
        Each plugin class is instantiated with keyword arguments ``safe``,
        ``clone_plugins`` and ``base_cls``, and is then used through its
        ``check(obj=...)`` and ``clone(obj=...)`` methods,
        see ``BaseCloner`` interface.
        If passed, will work through clone plugins in ``clone_plugins``
        before working through ``DEFAULT_CLONE_PLUGINS``. To override
        a cloner in ``DEFAULT_CLONE_PLUGINS``, simply ensure a cloner with
        the same check logic is present in ``clone_plugins``.
        The passed list is not modified.
    base_cls : reference to BaseObject
        Reference to the BaseObject class from skbase.base._base.
        Present for easy reference, fast imports, and potential extensions.
        Passed on unchanged to every plugin.

    Returns
    -------
    estimator : object
        The deep copy of the input, an estimator if input is an estimator.

    Raises
    ------
    RuntimeError
        If no plugin claims to be applicable to ``estimator``.
    Exception
        Any exception raised by the selected plugin while cloning is
        propagated to the caller.

    Notes
    -----
    If the estimator's `random_state` parameter is an integer (or if the
    estimator doesn't have a `random_state` parameter), an *exact clone* is
    returned: the clone and the original estimator will give the exact same
    results. Otherwise, *statistical clone* is returned: the clone might
    return different results from the original estimator. More details can be
    found in :ref:`randomness`.
    """
    # handle cloning plugins:
    # if no plugins provided by user, work through the DEFAULT_CLONE_PLUGINS
    # if provided by user, work through user provided plugins first, then defaults
    if clone_plugins is not None:
        # concatenate into one flat list of plugin classes; appending the
        # defaults list as a single element would put a (non-callable) list
        # into the sequence of plugins and break the loop below
        all_plugins = list(clone_plugins) + list(DEFAULT_CLONE_PLUGINS)
    else:
        all_plugins = DEFAULT_CLONE_PLUGINS

    for cloner_plugin in all_plugins:
        cloner = cloner_plugin(safe=safe, clone_plugins=all_plugins, base_cls=base_cls)
        # we clone with the first plugin in the list that claims it is
        # applicable, via check; exceptions raised by that plugin's clone
        # are not caught here and propagate to the caller
        if cloner.check(obj=estimator):
            return cloner.clone(obj=estimator)

    raise RuntimeError(
        "Error in skbase _clone, catch-all plugin did not catch all "
        "remaining cases. This is likely due to custom modification of the module."
    )


def _check_clone(original, clone):
    """Verify that ``clone`` is a valid clone of ``original``.

    Called from BaseObject.clone to validate the clone, if
    the config flag check_clone is set to True.

    Parameters
    ----------
    original : object
        The object that was cloned; must implement ``get_params``.
    clone : object
        The object produced by cloning ``original``.

    Raises
    ------
    RuntimeError
        If a parameter of ``original`` is missing as an attribute of
        ``clone``, or if the attributes of ``clone`` do not compare equal
        to the parameters of ``original`` under ``deep_equals``.
    """
    from skbase.utils.deep_equals import deep_equals

    original_params = original.get_params(deep=False)
    param_names = list(original_params)

    # first pass: each parameter name must exist as attribute of the clone
    for param_name in param_names:
        if hasattr(clone, param_name):
            continue
        raise RuntimeError(
            f"error in {original}.clone, __init__ must write all arguments "
            f"to self and not mutate them, but {param_name} was not found. "
            f"Please check __init__ of {original}."
        )

    # second pass: read the values the clone actually holds
    clone_params = {}
    for param_name in param_names:
        clone_params[param_name] = getattr(clone, param_name)

    # compare parameters pre-clone against attributes post-clone
    params_match, msg = deep_equals(original_params, clone_params, return_msg=True)
    if params_match:
        return

    raise RuntimeError(
        f"error in {original}.clone, __init__ must write all arguments "
        f"to self and not mutate them, but this is not the case. "
        f"Error on equality check of arguments (x) vs parameters (y): {msg}"
    )
Loading
Loading