-
Notifications
You must be signed in to change notification settings - Fork 238
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* data_processor mvp * add more typings * one output missing typing * remove redundant fit_transform method * typecheck simplifications * add cols positional index support to DataProcessor * add base processor
- Loading branch information
1 parent
97b855a
commit 08d3cae
Showing
3 changed files
with
132 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,3 +7,4 @@ tensorflow==2.4.* | |
easydict==1.9 | ||
pmlb==1.0.* | ||
tqdm<5.0 | ||
typeguard==2.13.* |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
from typing import List, Union | ||
|
||
from numpy import concatenate, ndarray, split, zeros | ||
from pandas import concat, DataFrame | ||
from sklearn.base import BaseEstimator, TransformerMixin | ||
from typeguard import typechecked | ||
|
||
@typechecked
class BaseProcessor(BaseEstimator, TransformerMixin):
    """
    Base class for Data Preprocessing. It is a base version and should not be instantiated directly.
    It works like any other transformer in scikit learn with the methods fit, transform and inverse transform.

    Child processors are expected to assign ``num_pipeline`` and ``cat_pipeline``;
    this class leaves both set to ``None``.

    Args:
        num_cols (list of strings/list of ints):
            List of names of numerical columns or positional indexes (if pos_idx was set to True).
        cat_cols (list of strings/list of ints):
            List of names of categorical columns or positional indexes (if pos_idx was set to True).
        pos_idx (bool):
            Specifies if the passed col IDs are names or positional indexes (column numbers).
    """
    # ``None`` is spelled out in the Unions: relying on the ``= None`` default to
    # imply Optional stops working once ``typing.get_type_hints`` no longer adds it.
    def __init__(self, *, num_cols: Union[List[str], List[int], None] = None,
                 cat_cols: Union[List[str], List[int], None] = None,
                 pos_idx: bool = False):
        self.num_cols = [] if num_cols is None else num_cols
        self.cat_cols = [] if cat_cols is None else cat_cols

        # Split points of the transformed array; set by ``transform``.
        self.num_col_idx_ = None
        self.cat_col_idx_ = None

        self.num_pipeline = None  # To be overridden by child processors
        self.cat_pipeline = None  # To be overridden by child processors

        self._types = None  # dtypes of the DataFrame seen in ``fit``
        self.col_order_ = None  # selected columns, in the fitted DataFrame's order
        self.pos_idx = pos_idx

    def _positions_to_names(self, X: DataFrame, cols: list) -> list:
        """Returns the column names of X addressed by the positional indexes in cols.

        Lists holding only strings are taken to be names already (this is the
        state left behind by a previous ``fit``) and are returned as a copy,
        which keeps repeated calls to ``fit`` from failing.
        """
        if all(isinstance(col, str) for col in cols):
            return list(cols)
        return list(X.columns[cols])

    def fit(self, X: DataFrame):
        """Fits the DataProcessor to a passed DataFrame.

        When ``pos_idx`` is True, ``num_cols`` and ``cat_cols`` are replaced by
        the corresponding column names of X.

        Args:
            X (DataFrame):
                DataFrame used to fit the processor parameters.
                Should be aligned with the num/cat columns defined in initialization.
        Returns:
            self: The fitted processor.
        """
        if self.pos_idx:
            self.num_cols = self._positions_to_names(X, self.num_cols)
            self.cat_cols = self._positions_to_names(X, self.cat_cols)

        # Build the lookup once instead of concatenating both lists per column.
        selected = set(self.num_cols) | set(self.cat_cols)
        self.col_order_ = [c for c in X.columns if c in selected]
        self._types = X.dtypes

        # A pipeline is only fitted when it has columns to work on.
        if self.num_cols:
            self.num_pipeline.fit(X[self.num_cols])
        if self.cat_cols:
            self.cat_pipeline.fit(X[self.cat_cols])

        return self

    def transform(self, X: DataFrame) -> ndarray:
        """Transforms the passed DataFrame with the fit DataProcessor.

        Also records, in ``num_col_idx_`` and ``cat_col_idx_``, where the
        numerical and categorical parts end in the output; ``inverse_transform``
        needs these to split an array back.

        Args:
            X (DataFrame):
                DataFrame to be transformed.
                Should be aligned with the num/cat columns defined in initialization.
        Returns:
            transformed (ndarray):
                Processed version of the passed DataFrame, numerical part first.
        """
        n_rows = len(X)
        # An empty column group contributes a zero-width array so that the
        # concatenation below works unchanged.
        num_data = self.num_pipeline.transform(X[self.num_cols]) if self.num_cols else zeros([n_rows, 0])
        cat_data = self.cat_pipeline.transform(X[self.cat_cols]) if self.cat_cols else zeros([n_rows, 0])

        transformed = concatenate([num_data, cat_data], axis=1)

        self.num_col_idx_ = num_data.shape[1]
        self.cat_col_idx_ = self.num_col_idx_ + cat_data.shape[1]

        return transformed

    def inverse_transform(self, X: ndarray) -> DataFrame:
        """Inverts the data transformation pipelines on a passed array.

        ``transform`` has to be called beforehand, since it is what sets the
        split points used here.

        Args:
            X (ndarray):
                Numpy array to be brought back to the original data format.
                Should share the schema of data transformed by this DataProcessor.
        Returns:
            result (DataFrame):
                DataFrame with the transformations inverted, columns in the
                order seen in ``fit`` and cast to the dtypes seen in ``fit``.
        """
        # Splitting at two indexes yields three parts; anything past the
        # categorical part is discarded.
        num_data, cat_data, _ = split(X, [self.num_col_idx_, self.cat_col_idx_], axis=1)

        n_rows = len(X)
        num_data = self.num_pipeline.inverse_transform(num_data) if self.num_cols else zeros([n_rows, 0])
        cat_data = self.cat_pipeline.inverse_transform(cat_data) if self.cat_cols else zeros([n_rows, 0])

        result = concat([DataFrame(num_data, columns=self.num_cols),
                         DataFrame(cat_data, columns=self.cat_cols)], axis=1)

        result = result.loc[:, self.col_order_]

        # Restore the dtypes recorded in ``fit`` for every returned column.
        return result.astype({col: self._types[col] for col in result.columns})
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
from typing import List, Union | ||
|
||
from sklearn.pipeline import Pipeline | ||
from sklearn.preprocessing import MinMaxScaler, OneHotEncoder | ||
from typeguard import typechecked | ||
|
||
from ydata_synthetic.preprocessing.base_processor import BaseProcessor | ||
|
||
@typechecked
class RegularDataProcessor(BaseProcessor):
    """
    Main class for Regular/Tabular Data Preprocessing.
    It works like any other transformer in scikit learn with the methods fit, transform and inverse transform.

    Numerical columns go through a ``MinMaxScaler`` and categorical columns
    through a ``OneHotEncoder`` that returns dense output and ignores unknown
    categories.

    Args:
        num_cols (list of strings/list of ints):
            List of names of numerical columns or positional indexes (if pos_idx was set to True).
        cat_cols (list of strings/list of ints):
            List of names of categorical columns or positional indexes (if pos_idx was set to True).
        pos_idx (bool):
            Specifies if the passed col IDs are names or positional indexes (column numbers).
    """
    # ``None`` is spelled out in the Unions: relying on the ``= None`` default to
    # imply Optional stops working once ``typing.get_type_hints`` no longer adds it.
    def __init__(self, *, num_cols: Union[List[str], List[int], None] = None,
                 cat_cols: Union[List[str], List[int], None] = None,
                 pos_idx: bool = False):
        super().__init__(num_cols=num_cols, cat_cols=cat_cols, pos_idx=pos_idx)

        self.num_pipeline = Pipeline([
            ("scaler", MinMaxScaler()),
        ])

        # NOTE(review): newer scikit-learn releases rename ``sparse`` to
        # ``sparse_output``; confirm against the version this project pins.
        self.cat_pipeline = Pipeline([
            ("encoder", OneHotEncoder(sparse=False, handle_unknown='ignore')),
        ])