Skip to content

Commit

Permalink
Merge pull request #158 from The-Blitz/feat/checkpoint
Browse files Browse the repository at this point in the history
Adding Model checkpoint
  • Loading branch information
rodrigo-arenas authored Oct 21, 2024
2 parents f9a643a + 39baa19 commit 0e9f030
Show file tree
Hide file tree
Showing 4 changed files with 228 additions and 52 deletions.
2 changes: 2 additions & 0 deletions sklearn_genetic/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
TimerStopping,
)
from .loggers import ProgressBar, LogbookSaver, TensorBoard
from .model_checkpoint import ModelCheckpoint

__all__ = [
"ProgressBar",
Expand All @@ -14,4 +15,5 @@
"TimerStopping",
"LogbookSaver",
"TensorBoard",
"ModelCheckpoint",
]
46 changes: 46 additions & 0 deletions sklearn_genetic/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import pickle
from .base import BaseCallback
from .loggers import LogbookSaver
from copy import deepcopy


class ModelCheckpoint(BaseCallback):
    """Callback that pickles the search configuration and logbook on every step.

    Parameters
    ----------
    checkpoint_path
        Location of the checkpoint file. It is handed to ``open`` and to
        ``LogbookSaver``.
    **dump_options
        Extra keyword arguments forwarded unchanged to ``LogbookSaver``.
    """

    # Attributes read from the estimator on every step. A missing one aborts
    # the checkpoint (the error is reported by ``on_step``).
    _REQUIRED_STATE = (
        "estimator",
        "cv",
        "scoring",
        "population_size",
        "generations",
        "crossover_probability",
        "mutation_probability",
        "algorithm",
    )
    # Attributes stored only when the estimator has them. ``param_grid`` is a
    # constructor argument of GASearchCV but not of GAFeatureSelectionCV, so
    # requiring it would make every feature-selection checkpoint fail.
    _OPTIONAL_STATE = ("param_grid",)

    def __init__(self, checkpoint_path, **dump_options):
        self.checkpoint_path = checkpoint_path
        self.dump_options = dump_options

    def on_step(self, record=None, logbook=None, estimator=None):
        """Write a checkpoint holding the estimator settings and the logbook.

        The checkpoint is a pickled dict with two keys: ``"estimator_state"``
        (selected estimator attributes) and ``"logbook"`` (a deep copy of
        ``logbook``). Any exception is caught and printed, so a failed
        checkpoint never interrupts the caller. Nothing is returned.

        Parameters
        ----------
        record
            Forwarded to ``LogbookSaver.on_step``; not used otherwise.
        logbook
            Logbook to store. When it is not ``None`` and not empty it is
            also passed to ``LogbookSaver``.
        estimator
            Object whose configuration attributes are captured.
        """
        try:
            if logbook is not None and len(logbook) > 0:
                # NOTE(review): LogbookSaver receives the same path that is
                # rewritten below, so whatever it stores there appears to be
                # replaced by the pickle - confirm this call is still needed.
                logbook_saver = LogbookSaver(self.checkpoint_path, **self.dump_options)  # noqa
                logbook_saver.on_step(record, logbook, estimator)

            estimator_state = {
                name: getattr(estimator, name) for name in self._REQUIRED_STATE
            }
            for name in self._OPTIONAL_STATE:
                if hasattr(estimator, name):
                    estimator_state[name] = getattr(estimator, name)

            checkpoint_data = {
                "estimator_state": estimator_state,
                "logbook": deepcopy(logbook),
            }
            # Serialize before opening the file: "wb" truncates on open, so
            # pickling straight into the handle would leave an empty or
            # partial file whenever the state is not picklable.
            serialized = pickle.dumps(checkpoint_data)
            with open(self.checkpoint_path, "wb") as f:
                f.write(serialized)
            print(f"Checkpoint save in {self.checkpoint_path}")

        except Exception as e:
            print(f"Error saving checkpoint: {e}")

    def load(self):
        """Load the model state from the checkpoint file.

        Returns
        -------
        The unpickled checkpoint object (a dict with ``"estimator_state"``
        and ``"logbook"`` when written by ``on_step``), or ``None`` when the
        file cannot be read or unpickled; the error is printed in that case.
        """
        try:
            # SECURITY: unpickling can execute arbitrary code. Only point
            # ``checkpoint_path`` at files from a trusted source.
            with open(self.checkpoint_path, "rb") as f:
                checkpoint_data = pickle.load(f)
            return checkpoint_data
        except Exception as e:
            print(f"Error loading checkpoint: {e}")
            return None
171 changes: 121 additions & 50 deletions sklearn_genetic/genetic_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@
from .utils.random import weighted_bool_individual
from .utils.tools import cxUniform, mutFlipBit, novelty_scorer

import pickle
import os
from .callbacks.model_checkpoint import ModelCheckpoint


class GASearchCV(BaseSearchCV):
"""
Expand Down Expand Up @@ -218,29 +222,29 @@ class GASearchCV(BaseSearchCV):
"""

def __init__(
self,
estimator,
cv=3,
param_grid=None,
scoring=None,
population_size=50,
generations=80,
crossover_probability=0.2,
mutation_probability=0.8,
tournament_size=3,
elitism=True,
verbose=True,
keep_top_k=1,
criteria="max",
algorithm="eaMuPlusLambda",
refit=True,
n_jobs=1,
pre_dispatch="2*n_jobs",
error_score=np.nan,
return_train_score=False,
log_config=None,
use_cache=True,
warm_start_configs=None,
self,
estimator,
cv=3,
param_grid=None,
scoring=None,
population_size=50,
generations=80,
crossover_probability=0.2,
mutation_probability=0.8,
tournament_size=3,
elitism=True,
verbose=True,
keep_top_k=1,
criteria="max",
algorithm="eaMuPlusLambda",
refit=True,
n_jobs=1,
pre_dispatch="2*n_jobs",
error_score=np.nan,
return_train_score=False,
log_config=None,
use_cache=True,
warm_start_configs=None,
):
self.estimator = estimator
self.cv = cv
Expand Down Expand Up @@ -311,7 +315,6 @@ def _register(self):
creator.create("FitnessMax", base.Fitness, weights=[self.criteria_sign, 1.0])
creator.create("Individual", list, fitness=creator.FitnessMax)


attributes = []
# Assign all the parameters defined in the param_grid
# It uses the distribution parameter to set the sampling function
Expand Down Expand Up @@ -358,7 +361,6 @@ def _register(self):
self._stats.register("fitness_max", np.max, axis=0)
self._stats.register("fitness_min", np.min, axis=0)


self.logbook = tools.Logbook()

def _initialize_population(self):
Expand Down Expand Up @@ -490,7 +492,7 @@ def evaluate(self, individual):
# Store the fitness result and the current generation parameters in the cache
self.fitness_cache[individual_key] = {
"fitness": fitness_result,
"current_generation_params": current_generation_params
"current_generation_params": current_generation_params,
}

return fitness_result
Expand Down Expand Up @@ -524,6 +526,16 @@ def fit(self, X, y, callbacks=None):
# Make sure the callbacks are valid
self.callbacks = check_callback(callbacks)

# Load state if a checkpoint exists
for callback in self.callbacks:
if isinstance(callback, ModelCheckpoint):
if os.path.exists(callback.checkpoint_path):
checkpoint_data = callback.load()
if checkpoint_data:
self.__dict__.update(checkpoint_data["estimator_state"]) # noqa
self.logbook = checkpoint_data["logbook"]
break

if callable(self.scoring):
self.scorer_ = self.scoring
self.metrics_list = [self.refit_metric]
Expand Down Expand Up @@ -601,6 +613,30 @@ def fit(self, X, y, callbacks=None):

return self

def save(self, filepath):
    """Save the current state of the GASearchCV instance to a file.

    The instance ``__dict__`` and its ``logbook`` (``None`` when the
    attribute does not exist) are pickled into ``filepath``. Failures are
    reported with ``print`` instead of being raised, and nothing is returned.

    Parameters
    ----------
    filepath
        Destination file, opened in binary write mode.
    """
    try:
        checkpoint_data = {
            "estimator_state": self.__dict__,
            "logbook": getattr(self, "logbook", None),
        }
        # Serialize before opening the file: "wb" truncates on open, so
        # pickling straight into the handle would wipe an existing save
        # whenever the state turns out not to be picklable.
        serialized = pickle.dumps(checkpoint_data)
        with open(filepath, "wb") as f:
            f.write(serialized)
        print(f"GASearchCV model successfully saved to {filepath}")
    except Exception as e:
        print(f"Error saving GASearchCV: {e}")

def load(self, filepath):
    """Restore attributes on this GASearchCV instance from a pickled file.

    The file must hold a dict with an ``"estimator_state"`` mapping and a
    ``"logbook"`` entry. Each state item is applied with ``setattr`` and the
    logbook is assigned afterwards. Nothing is returned; any failure is
    printed instead of raised.

    NOTE: unpickling can execute arbitrary code - only load trusted files.
    """
    try:
        with open(filepath, "rb") as source:
            restored = pickle.load(source)
        saved_state = restored["estimator_state"]
        for attr_name, attr_value in saved_state.items():
            setattr(self, attr_name, attr_value)
        self.logbook = restored["logbook"]
        print(f"GASearchCV model successfully loaded from {filepath}")
    except Exception as e:
        print(f"Error loading GASearchCV: {e}")

def _select_algorithm(self, pop, stats, hof):
"""
It selects the algorithm to run from the sklearn_genetic.algorithms module
Expand Down Expand Up @@ -895,28 +931,28 @@ class GAFeatureSelectionCV(MetaEstimatorMixin, SelectorMixin, BaseEstimator):
"""

def __init__(
self,
estimator,
cv=3,
scoring=None,
population_size=50,
generations=80,
crossover_probability=0.2,
mutation_probability=0.8,
tournament_size=3,
elitism=True,
max_features=None,
verbose=True,
keep_top_k=1,
criteria="max",
algorithm="eaMuPlusLambda",
refit=True,
n_jobs=1,
pre_dispatch="2*n_jobs",
error_score=np.nan,
return_train_score=False,
log_config=None,
use_cache=True,
self,
estimator,
cv=3,
scoring=None,
population_size=50,
generations=80,
crossover_probability=0.2,
mutation_probability=0.8,
tournament_size=3,
elitism=True,
max_features=None,
verbose=True,
keep_top_k=1,
criteria="max",
algorithm="eaMuPlusLambda",
refit=True,
n_jobs=1,
pre_dispatch="2*n_jobs",
error_score=np.nan,
return_train_score=False,
log_config=None,
use_cache=True,
):
self.estimator = estimator
self.cv = cv
Expand Down Expand Up @@ -1083,7 +1119,7 @@ def evaluate(self, individual):
# Penalize individuals with more features than the max_features parameter

if self.max_features and (
n_selected_features > self.max_features or n_selected_features == 0
n_selected_features > self.max_features or n_selected_features == 0
):
score = -self.criteria_sign * 100000

Expand All @@ -1094,7 +1130,7 @@ def evaluate(self, individual):
# Store the fitness result and the current generation features in the cache
self.fitness_cache[individual_key] = {
"fitness": fitness_result,
"current_generation_features": current_generation_features
"current_generation_features": current_generation_features,
}

return fitness_result
Expand Down Expand Up @@ -1131,6 +1167,16 @@ def fit(self, X, y, callbacks=None):
# Make sure the callbacks are valid
self.callbacks = check_callback(callbacks)

# Load state if a checkpoint exists
for callback in self.callbacks:
if isinstance(callback, ModelCheckpoint):
if os.path.exists(callback.checkpoint_path):
checkpoint_data = callback.load()
if checkpoint_data:
self.__dict__.update(checkpoint_data["estimator_state"]) # noqa
self.logbook = checkpoint_data["logbook"]
break

if callable(self.scoring):
self.scorer_ = self.scoring
self.metrics_list = [self.refit_metric]
Expand Down Expand Up @@ -1192,6 +1238,31 @@ def fit(self, X, y, callbacks=None):

return self

def save(self, filepath):
    """Save the current state of the GAFeatureSelectionCV instance to a file.

    The instance ``__dict__`` and its ``logbook`` (``None`` when the
    attribute does not exist) are pickled into ``filepath``. Failures are
    reported with ``print`` instead of being raised, and nothing is returned.

    Parameters
    ----------
    filepath
        Destination file, opened in binary write mode.
    """
    try:
        checkpoint_data = {
            "estimator_state": self.__dict__,
            "logbook": getattr(self, "logbook", None),
        }
        # Serialize before opening the file: "wb" truncates on open, so
        # pickling straight into the handle would wipe an existing save
        # whenever the state turns out not to be picklable.
        serialized = pickle.dumps(checkpoint_data)
        with open(filepath, "wb") as f:
            f.write(serialized)
        print(f"GAFeatureSelectionCV model successfully saved to {filepath}")
    except Exception as e:
        print(f"Error saving GAFeatureSelectionCV: {e}")

def load(self, filepath):
    """Restore attributes on this GAFeatureSelectionCV instance from a pickled file.

    The file must hold a dict with an ``"estimator_state"`` mapping and a
    ``"logbook"`` entry. Each state item is applied with ``setattr`` and the
    logbook is assigned afterwards. Nothing is returned; any failure is
    printed instead of raised.

    NOTE: unpickling can execute arbitrary code - only load trusted files.
    """
    try:
        with open(filepath, "rb") as source:
            restored = pickle.load(source)
        saved_state = restored["estimator_state"]
        for attr_name, attr_value in saved_state.items():
            setattr(self, attr_name, attr_value)
        self.logbook = restored["logbook"]
        print(f"GAFeatureSelectionCV model successfully loaded from {filepath}")  # noqa
    except Exception as e:
        print(f"Error loading GAFeatureSelectionCV: {e}")

def _select_algorithm(self, pop, stats, hof):
"""
It selects the algorithm to run from the sklearn_genetic.algorithms module
Expand Down
Loading

0 comments on commit 0e9f030

Please sign in to comment.