[ENH] refactor _clone to a plugin structure #381
Merged
Commits (34):
b470a75 start (fkiraly)
585e086 Update _clone_base.py (fkiraly)
4d4ecf5 finished (fkiraly)
bb74f63 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
0684f01 linting, private sklearn cloner (fkiraly)
9398271 Merge branch 'clone_plugins' of https://github.com/sktime/skbase into… (fkiraly)
1a2fcbd safe import (fkiraly)
ef8c0ff [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
7d89862 linting (fkiraly)
d90aa44 Merge branch 'clone_plugins' of https://github.com/sktime/skbase into… (fkiraly)
c5527ca [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
1dfc0be isolate imports (fkiraly)
6c4e974 Merge branch 'clone_plugins' of https://github.com/sktime/skbase into… (fkiraly)
22b0e64 fix (fkiraly)
91a876d linting (fkiraly)
a8dd391 Revert "linting" (fkiraly)
35f686a Reapply "linting" (fkiraly)
8efb6ed Revert "Reapply "linting"" (fkiraly)
ae08d1b Reapply "Reapply "linting"" (fkiraly)
a2c4820 Update _clone_plugins.py (fkiraly)
a218815 Update conftest.py (fkiraly)
cd5699a conftest (fkiraly)
e167991 allow raise from clone (fkiraly)
337cd0c Update _clone_base.py (fkiraly)
dd03809 Update conftest.py (fkiraly)
cf42711 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
01f3dba Update conftest.py (fkiraly)
91a3aee Merge branch 'clone_plugins' of https://github.com/sktime/skbase into… (fkiraly)
1c944d1 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
c12bfc1 Update conftest.py (fkiraly)
88feb5d Merge branch 'clone_plugins' of https://github.com/sktime/skbase into… (fkiraly)
6aa2221 Update conftest.py (fkiraly)
894c694 test for retaining sklearn config (fkiraly)
99b5a1f fix test (fkiraly)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
# -*- coding: utf-8 -*- | ||
# copyright: skbase developers, BSD-3-Clause License (see LICENSE file) | ||
# Elements of BaseObject reuse code developed in scikit-learn. These elements | ||
# are copyrighted by the scikit-learn developers, BSD-3-Clause License. For | ||
# conditions see https://github.com/scikit-learn/scikit-learn/blob/main/COPYING | ||
"""Logic and plugins for cloning objects. | ||
|
||
This module contains logic for cloning objects: | ||
|
||
_clone(estimator, *, safe=True, clone_plugins=None, base_cls=None) - central entry point for cloning | ||
_check_clone(original, clone) - validation utility to check clones | ||
|
||
Default plugins for _clone are stored in _clone_plugins: | ||
|
||
DEFAULT_CLONE_PLUGINS - list with default plugins for cloning | ||
|
||
Each element of DEFAULT_CLONE_PLUGINS inherits from BaseCloner, with methods: | ||
|
||
* check(obj) -> boolean - fast checker whether plugin applies | ||
* clone(obj) -> type(obj) - method to clone obj | ||
""" | ||
__all__ = ["_clone", "_check_clone"] | ||
|
||
from skbase.base._clone_plugins import DEFAULT_CLONE_PLUGINS | ||
|
||
|
||
# Adapted from sklearn's `_clone_parametrized()` | ||
def _clone(estimator, *, safe=True, clone_plugins=None, base_cls=None): | ||
"""Construct a new unfitted estimator with the same parameters. | ||
|
||
Clone does a deep copy of the model in an estimator | ||
without actually copying attached data. It returns a new estimator | ||
with the same parameters that has not been fitted on any data. | ||
|
||
Parameters | ||
---------- | ||
estimator : {list, tuple, set} of estimator instance or a single estimator instance | ||
The estimator or group of estimators to be cloned. | ||
safe : bool, default=True | ||
If ``safe`` is False, clone will fall back to a deep copy on objects | ||
that are not estimators. | ||
clone_plugins : list of BaseCloner clone plugins, concrete descendant classes. | ||
Must implement ``_check`` and ``_clone`` methods, see ``BaseCloner`` interface. | ||
If passed, will work through clone plugins in ``clone_plugins`` | ||
before working through ``DEFAULT_CLONE_PLUGINS``. To override | ||
a cloner in ``DEFAULT_CLONE_PLUGINS``, simply ensure a cloner with | ||
the same ``_check`` logic is present in ``clone_plugins``. | ||
base_cls : reference to BaseObject | ||
Reference to the BaseObject class from skbase.base._base. | ||
Present for easy reference, fast imports, and potential extensions. | ||
|
||
Returns | ||
------- | ||
estimator : object | ||
The deep copy of the input, an estimator if input is an estimator. | ||
|
||
Notes | ||
----- | ||
If the estimator's `random_state` parameter is an integer (or if the | ||
estimator doesn't have a `random_state` parameter), an *exact clone* is | ||
returned: the clone and the original estimator will give the exact same | ||
results. Otherwise, a *statistical clone* is returned: the clone might | ||
return different results from the original estimator. More details can be | ||
found in :ref:`randomness`. | ||
""" | ||
# handle cloning plugins: | ||
# if no plugins provided by user, work through the DEFAULT_CLONE_PLUGINS | ||
# if provided by user, work through user provided plugins first, then defaults | ||
if clone_plugins is not None: | ||
all_plugins = clone_plugins.copy() | ||
all_plugins.extend(DEFAULT_CLONE_PLUGINS) | ||
else: | ||
all_plugins = DEFAULT_CLONE_PLUGINS | ||
|
||
for cloner_plugin in all_plugins: | ||
cloner = cloner_plugin(safe=safe, clone_plugins=all_plugins, base_cls=base_cls) | ||
# we clone with the first plugin in the list that | ||
# claims it is applicable, via check; | ||
# exceptions raised by that plugin's clone are not caught and propagate | ||
if cloner.check(obj=estimator): | ||
return cloner.clone(obj=estimator) | ||
|
||
raise RuntimeError( | ||
"Error in skbase _clone, catch-all plugin did not catch all " | ||
"remaining cases. This is likely due to custom modification of the module." | ||
) | ||
|
||
|
||
def _check_clone(original, clone): | ||
"""Check that clone is a valid clone of original. | ||
|
||
Called from BaseObject.clone to validate the clone, if | ||
the config flag check_clone is set to True. | ||
|
||
Parameters | ||
---------- | ||
original : object | ||
The original object. | ||
clone : object | ||
The cloned object. | ||
|
||
Raises | ||
------ | ||
RuntimeError | ||
If the clone is not a valid clone of the original. | ||
""" | ||
from skbase.utils.deep_equals import deep_equals | ||
|
||
self_params = original.get_params(deep=False) | ||
|
||
# check that all attributes are written to the clone | ||
for attrname in self_params.keys(): | ||
if not hasattr(clone, attrname): | ||
raise RuntimeError( | ||
f"error in {original}.clone, __init__ must write all arguments " | ||
f"to self and not mutate them, but {attrname} was not found. " | ||
f"Please check __init__ of {original}." | ||
) | ||
|
||
clone_attrs = {attr: getattr(clone, attr) for attr in self_params.keys()} | ||
|
||
# check equality of parameters post-clone and pre-clone | ||
clone_attrs_valid, msg = deep_equals(self_params, clone_attrs, return_msg=True) | ||
if not clone_attrs_valid: | ||
raise RuntimeError( | ||
f"error in {original}.clone, __init__ must write all arguments " | ||
f"to self and not mutate them, but this is not the case. " | ||
f"Error on equality check of arguments (x) vs parameters (y): {msg}" | ||
) |
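The dispatch logic in `_clone` can be illustrated with a minimal, self-contained sketch. Everything in it is a stand-in written for illustration, not the skbase implementation: `BaseCloner` here is a toy version of the interface described in the module docstring, and `CatchAllCloner`, `KeepHandleCloner`, `Handle`, `DEFAULT_PLUGINS` and `dispatch_clone` are hypothetical names. It shows user plugins being tried before the defaults, and the first plugin whose `check` returns True doing the cloning.

```python
from copy import deepcopy


class BaseCloner:
    """Toy stand-in for the plugin interface (assumption, not skbase's class)."""

    def __init__(self, safe=True, clone_plugins=None, base_cls=None):
        self.safe = safe
        self.clone_plugins = clone_plugins
        self.base_cls = base_cls

    def check(self, obj):
        raise NotImplementedError

    def clone(self, obj):
        raise NotImplementedError


class CatchAllCloner(BaseCloner):
    """Hypothetical default plugin: applies to everything, deep-copies."""

    def check(self, obj):
        return True

    def clone(self, obj):
        return deepcopy(obj)


class Handle:
    """Toy resource that a user may want to share instead of copy."""


class KeepHandleCloner(BaseCloner):
    """Hypothetical user plugin: returns Handle instances unchanged."""

    def check(self, obj):
        return isinstance(obj, Handle)

    def clone(self, obj):
        return obj


DEFAULT_PLUGINS = [CatchAllCloner]


def dispatch_clone(obj, *, safe=True, clone_plugins=None, base_cls=None):
    """Clone obj with the first applicable plugin, user plugins first."""
    if clone_plugins is not None:
        all_plugins = list(clone_plugins)
        # extend keeps the list flat; append would nest the defaults
        # as a single list element, which cannot be instantiated below
        all_plugins.extend(DEFAULT_PLUGINS)
    else:
        all_plugins = DEFAULT_PLUGINS

    for plugin_cls in all_plugins:
        cloner = plugin_cls(safe=safe, clone_plugins=all_plugins, base_cls=base_cls)
        if cloner.check(obj=obj):
            return cloner.clone(obj=obj)

    raise RuntimeError("no plugin matched; the catch-all plugin is missing")


h = Handle()
default_copy = dispatch_clone(h)  # falls through to the catch-all
shared = dispatch_clone(h, clone_plugins=[KeepHandleCloner])  # user plugin wins
nested = dispatch_clone({"a": [1, 2]})
```

Because user plugins are placed ahead of the defaults, overriding a default amounts to supplying a plugin whose `check` accepts the same objects; the default is then never reached for them.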
Check notice (Code scanning / CodeQL): Cyclic import