diff --git a/sdgx/data_models/inspectors/personal.py b/sdgx/data_models/inspectors/personal.py
index fb8d2454..3a8e84dd 100644
--- a/sdgx/data_models/inspectors/personal.py
+++ b/sdgx/data_models/inspectors/personal.py
@@ -74,6 +74,31 @@ def domain_verification(self, each_sample):
         return True
 
 
+class ChinaMainlandAddressInspector(RegexInspector):
+
+    # This regular expression is not used at runtime; it is kept only as a reference for developers.
+    # pattern = r"^[\u4e00-\u9fa5]{2,}(省|自治区|特别行政区|市)|[\u4e00-\u9fa5]{2,}(市|区|县|自治州|自治县|县级市|地区|盟|林区)?|[\u4e00-\u9fa5]{0,}(街道|镇|乡)?|[\u4e00-\u9fa5]{0,}(路|街|巷|弄)?|[\u4e00-\u9fa5]{0,}(号|弄)?$"
+
+    pattern = r"^[\u4e00-\u9fa5]{2,}(省|自治区|特别行政区|市|县|村|弄|乡|路|街)"
+
+    pii = True
+
+    data_type_name = "china_mainland_address"
+
+    _inspect_level = 30
+
+    def domain_verification(self, each_sample):
+        # a mainland China address should be between 8 and 30 characters
+        if len(each_sample) < 8:
+            return False
+        if len(each_sample) > 30:
+            return False
+        # exclude company names, which would otherwise match the address pattern
+        if each_sample.endswith("公司"):
+            return False
+        return True
+
+
 @hookimpl
 def register(manager):
     manager.register("EmailInspector", EmailInspector)
@@ -85,3 +110,5 @@ def register(manager):
     manager.register("ChinaMainlandPostCode", ChinaMainlandPostCode)
 
     manager.register("ChinaMainlandUnifiedSocialCreditCode", ChinaMainlandUnifiedSocialCreditCode)
+
+    manager.register("ChinaMainlandAddressInspector", ChinaMainlandAddressInspector)
diff --git a/sdgx/data_processors/formatters/base.py b/sdgx/data_processors/formatters/base.py
index 52cdf320..83b810b9 100644
--- a/sdgx/data_processors/formatters/base.py
+++ b/sdgx/data_processors/formatters/base.py
@@ -1,3 +1,7 @@
+from __future__ import annotations
+
+import pandas as pd
+
 from sdgx.data_processors.base import DataProcessor
 
 
@@ -16,3 +20,27 @@ class Formatter(DataProcessor):
 
     - :ref:`Transformer` sometimes implements some functions with the help of Formatter.
     """
+
+    def reverse_convert(self, processed_data: pd.DataFrame) -> pd.DataFrame:
+        """Convert processed data into raw data.
+
+        Args:
+            processed_data (pd.DataFrame): Processed data
+
+        Returns:
+            pd.DataFrame: Raw data
+        """
+        return self.post_processing(processed_data)
+
+    def post_processing(self, processed_data: pd.DataFrame) -> pd.DataFrame:
+        """
+        Convert processed data into raw data; Formatter subclasses should override this method.
+
+        Args:
+            processed_data (pd.DataFrame): Processed data
+
+        Returns:
+            pd.DataFrame: Raw data
+        """
+
+        return processed_data
diff --git a/sdgx/data_processors/formatters/int.py b/sdgx/data_processors/formatters/int.py
new file mode 100644
index 00000000..8e6e319f
--- /dev/null
+++ b/sdgx/data_processors/formatters/int.py
@@ -0,0 +1,62 @@
+from __future__ import annotations
+
+from typing import Any
+
+import pandas as pd
+
+from sdgx.data_models.metadata import Metadata
+from sdgx.data_processors.extension import hookimpl
+from sdgx.data_processors.formatters.base import Formatter
+from sdgx.utils import logger
+
+
+class IntValueFormatter(Formatter):
+    """
+    Formatter class for handling int values in a pd.DataFrame.
+    """
+
+    int_columns = None
+
+    def fit(self, metadata: Metadata | None = None, **kwargs: dict[str, Any]):
+        """
+        Fit method for the formatter.
+
+        The formatter uses metadata to record which columns are of int type, so that they can be converted back to int during post-processing.
+        """
+
+        # get the int columns from metadata
+        self.int_columns = metadata.get("int_columns")
+
+        logger.info("IntValueFormatter Fitted.")
+
+        return
+
+    def convert(self, raw_data: pd.DataFrame) -> pd.DataFrame:
+        """
+        Convert method for the formatter; int columns need no conversion here.
+        """
+
+        logger.info("Converting data using IntValueFormatter... Finished (No Action).")
+
+        return raw_data
+
+    def post_processing(self, processed_data: pd.DataFrame) -> pd.DataFrame:
+        """
+        post_processing method for the formatter.
+
+        Casts the columns recorded during fit back to the int type.
+        """
+
+        for col in self.int_columns:
+            processed_data[col] = processed_data[col].astype(int)
+
+        logger.info("Data reverse-converted by IntValueFormatter.")
+
+        return processed_data
+
+    pass
+
+
+@hookimpl
+def register(manager):
+    manager.register("IntValueFormatter", IntValueFormatter)
diff --git a/sdgx/data_processors/manager.py b/sdgx/data_processors/manager.py
index 3a2e8f8f..34105d00 100644
--- a/sdgx/data_processors/manager.py
+++ b/sdgx/data_processors/manager.py
@@ -10,19 +10,84 @@
 
 
 class DataProcessorManager(Manager):
+    """
+    This is a plugin management class for data processing components.
+
+    Properties:
+    - register_type: Specifies the type of data processors to register.
+    - project_name: Stores the project name from the extension module.
+    - hookspecs_model: Stores the hook specifications model from the extension module.
+    - preset_default_processors: Stores a list of default processor names in lowercase.
+    - registed_data_processors: Property that returns the registered data processors.
+    - registed_default_processor_list: Property that returns the registered default data processors.
+
+    Methods:
+    - load_all_local_model: Loads all local models for formatters, generators, samplers, and transformers.
+    - init_data_processor: Initializes a data processor with the given name and keyword arguments.
+    - init_all_processors: Initializes all registered data processors with optional keyword arguments.
+    - init_default_processors: Initializes default processors that are both registered and preset.
+
+    """
+
     register_type = DataProcessor
     project_name = PROJECT_NAME
     hookspecs_model = extension
 
+    preset_default_processors = [p.lower() for p in ["NonValueTransformer", "IntValueFormatter"]]
+
     @property
     def registed_data_processors(self):
+        """
+        This property returns all registered data processors.
+        """
         return self.registed_cls
 
+    @property
+    def registed_default_processor_list(self):
+        """
+        This property returns all registered default data processors.
+        """
+        registed_processor_list = self.registed_data_processors.keys()
+
+        # calculate the intersection of registered and preset default processors
+        target_processors = list(
+            set(registed_processor_list).intersection(self.preset_default_processors)
+        )
+
+        return target_processors
+
     def load_all_local_model(self):
+        """
+        Loads all local models for formatters, generators, samplers, and transformers.
+        """
         self._load_dir(data_processors.formatters)
         self._load_dir(data_processors.generators)
         self._load_dir(data_processors.samplers)
         self._load_dir(data_processors.transformers)
 
     def init_data_processor(self, processor_name, **kwargs: dict[str, Any]) -> DataProcessor:
+        """
+        Initializes a data processor with the given name and parameters.
+        """
         return self.init(processor_name, **kwargs)
+
+    def init_all_processors(self, **kwargs: Any) -> list[DataProcessor]:
+        """
+        Initializes all registered data processors.
+        """
+        return [
+            self.init(processor_name, **kwargs)
+            for processor_name in self.registed_data_processors.keys()
+        ]
+
+    def init_default_processors(self, **kwargs: Any) -> list[DataProcessor]:
+        """
+        Initializes all default data processors.
+        """
+
+        return [
+            self.init(processor_name, **kwargs)
+            for processor_name in self.registed_default_processor_list
+        ]
diff --git a/sdgx/data_processors/transformers/base.py b/sdgx/data_processors/transformers/base.py
index 0f4c1f96..5f40ff27 100644
--- a/sdgx/data_processors/transformers/base.py
+++ b/sdgx/data_processors/transformers/base.py
@@ -1,4 +1,11 @@
+from __future__ import annotations
+
+import pandas as pd
+
+from sdgx.data_loader import DataLoader
+from sdgx.data_models.metadata import Metadata
 from sdgx.data_processors.base import DataProcessor
+from sdgx.models.components.optimize.ndarray_loader import NDArrayLoader
 
 
 class Transformer(DataProcessor):
@@ -10,3 +17,20 @@ class Transformer(DataProcessor):
 
     To achieve that, Transformer can use :ref:`Formatter` and :ref:`Inspector` to help.
     """
+
+    def fit(self, metadata: Metadata | None = None, tabular_data: DataLoader | pd.DataFrame = None):
+        """
+        Fit method for the transformer.
+        """
+
+        return
+
+    @staticmethod
+    def delete_column(tabular_data, column_name):
+
+        pass
+
+    @staticmethod
+    def attach_columns(tabular_data, new_columns):
+
+        pass
diff --git a/sdgx/data_processors/transformers/column_order.py b/sdgx/data_processors/transformers/column_order.py
new file mode 100644
index 00000000..253dbe47
--- /dev/null
+++ b/sdgx/data_processors/transformers/column_order.py
@@ -0,0 +1,59 @@
+from __future__ import annotations
+
+from pandas import DataFrame
+
+from sdgx.data_models.metadata import Metadata
+from sdgx.data_processors.extension import hookimpl
+from sdgx.data_processors.transformers.base import Transformer
+from sdgx.utils import logger
+
+
+class ColumnOrderTransformer(Transformer):
+    """
+    Transformer class for recording and restoring the column order of tabular data.
+
+    This Transformer also serves as a simple reference that helps developers quickly understand the role of a Transformer.
+    """
+
+    column_list: list = None
+    """
+    The list of tabular data's columns.
+    """
+
+    def fit(self, metadata: Metadata | None = None):
+        """
+        Fit method for the transformer.
+
+        Remember the order of the columns.
+        """
+
+        self.column_list = list(metadata.column_list)
+
+        logger.info("ColumnOrderTransformer Fitted.")
+
+        return
+
+    def convert(self, raw_data: DataFrame) -> DataFrame:
+        """
+        Convert method for the transformer; no conversion of the input data is needed here.
+        """
+        logger.info("Converting data using ColumnOrderTransformer...")
+        logger.info("Converting data using ColumnOrderTransformer... Finished (No action).")
+
+        return raw_data
+
+    def reverse_convert(self, processed_data: DataFrame) -> DataFrame:
+        """
+        Reverse_convert method for the transformer.
+        """
+
+        logger.info("Data reverse-converted by ColumnOrderTransformer.")
+
+        return processed_data
+
+    pass
+
+
+@hookimpl
+def register(manager):
+    manager.register("ColumnOrderTransformer", ColumnOrderTransformer)
diff --git a/sdgx/data_processors/transformers/discrete.py b/sdgx/data_processors/transformers/discrete.py
new file mode 100644
index 00000000..ad77621e
--- /dev/null
+++ b/sdgx/data_processors/transformers/discrete.py
@@ -0,0 +1,116 @@
+from __future__ import annotations
+
+import pandas as pd
+from sklearn.preprocessing import OneHotEncoder
+
+from sdgx.data_loader import DataLoader
+from sdgx.data_models.metadata import Metadata
+from sdgx.data_processors.extension import hookimpl
+from sdgx.data_processors.transformers.base import Transformer
+from sdgx.models.components.optimize.ndarray_loader import NDArrayLoader
+from sdgx.utils import logger
+
+
+class DiscreteTransformer(Transformer):
+    """
+    DiscreteTransformer is an important component of sdgx, used to handle discrete columns.
+
+    By default, DiscreteTransformer will perform one-hot encoding of discrete columns, and issue a warning message when dimensionality explosion occurs.
+    """
+
+    discrete_columns: list = None
+    """
+    Record which columns are of discrete type.
+    """
+
+    encoders: dict = {}
+
+    onehot_encoder_handle_unknown = "ignore"
+
+    def fit(self, metadata: Metadata, tabular_data: DataLoader | pd.DataFrame):
+        """
+        Fit method for the DiscreteTransformer.
+        """
+
+        logger.info("Fitting using DiscreteTransformer...")
+
+        self.discrete_columns = metadata.get("discrete_columns")
+
+        # no discrete columns
+        if len(self.discrete_columns) == 0:
+            logger.info("Fitting using DiscreteTransformer... Finished (No Columns).")
+            return
+
+        # otherwise, there is at least one discrete column
+        for each_col in self.discrete_columns:
+            # fit each column
+            self._fit_column(each_col, tabular_data[[each_col]])
+
+        logger.info("Fitting using DiscreteTransformer... Finished.")
+
+        return
+
+    def _fit_column(self, column_name: str, column_data: pd.DataFrame):
+        """
+        Fit a single discrete column.
+
+        Args:
+            - column_name (str): column name.
+            - column_data (pd.DataFrame): A dataframe containing a single column.
+        """
+
+        self.encoders[column_name] = OneHotEncoder(
+            handle_unknown=self.onehot_encoder_handle_unknown
+        )
+        # fit the column data
+        self.encoders[column_name].fit(column_data)
+
+        logger.info(f"Discrete column {column_name} fitted.")
+
+    def convert(self, raw_data: pd.DataFrame) -> pd.DataFrame:
+        """
+        Convert method to handle discrete values in the input data.
+        """
+
+        logger.info("Converting data using DiscreteTransformer...")
+
+        # TODO
+        # transform every discrete column into its one-hot representation
+        if len(self.discrete_columns) == 0:
+            logger.info("Converting data using DiscreteTransformer... Finished (No column).")
+            return raw_data
+
+        for each_col in self.discrete_columns:
+            new_onehot_column_set = self.encoders[each_col].transform(raw_data[[each_col]])
+            # TODO 1- add new_onehot_column_set into the original dataframe
+            # TODO 2- delete the original column
+            logger.info(f"Column {each_col} converted.")
+
+        logger.info("Converting data using DiscreteTransformer... Finished.")
+
+        # return the result
+        return
+
+    def _transform_column(self, column_name: str, column_data: pd.DataFrame | pd.Series):
+        """
+        Transform a single discrete column.
+
+        Args:
+            - column_name (str): column name.
+            - column_data (pd.DataFrame): A dataframe containing a single column.
+
+        """
+        pass
+
+    def reverse_convert(self, processed_data: pd.DataFrame) -> pd.DataFrame:
+        """
+        Reverse_convert method for the transformer.
+
+
+        """
+
+        logger.info("Data reverse-converted by DiscreteTransformer.")
+
+        return processed_data
+
+    pass
diff --git a/sdgx/data_processors/transformers/nan.py b/sdgx/data_processors/transformers/nan.py
new file mode 100644
index 00000000..053062ab
--- /dev/null
+++ b/sdgx/data_processors/transformers/nan.py
@@ -0,0 +1,65 @@
+from __future__ import annotations
+
+from typing import Any
+
+from pandas import DataFrame
+
+from sdgx.data_models.metadata import Metadata
+from sdgx.data_processors.extension import hookimpl
+from sdgx.data_processors.transformers.base import Transformer
+from sdgx.utils import logger
+
+
+class NonValueTransformer(Transformer):
+    """
+    Transformer class for handling missing values in data.
+
+    This Transformer is mainly used as a reference that helps developers quickly understand the role of a Transformer.
+    """
+
+    fill_na_value = 0
+
+    drop_na = False
+
+    def fit(self, metadata: Metadata | None = None, **kwargs: dict[str, Any]):
+        """
+        Fit method for the transformer.
+
+        Does not require any action.
+        """
+        logger.info("NonValueTransformer Fitted.")
+
+        return
+
+    def convert(self, raw_data: DataFrame) -> DataFrame:
+        """
+        Convert method to handle missing values in the input data.
+        """
+
+        logger.info("Converting data using NonValueTransformer...")
+
+        if self.drop_na:
+            res = raw_data.dropna()
+        else:
+            res = raw_data.fillna(value=self.fill_na_value)
+
+        logger.info("Converting data using NonValueTransformer... Finished.")
+
+        return res
+
+    def reverse_convert(self, processed_data: DataFrame) -> DataFrame:
+        """
+        Reverse_convert method for the transformer.
+
+        Does not require any action.
+        """
+        logger.info("Data reverse-converted by NonValueTransformer (No Action).")
+
+        return processed_data
+
+    pass
+
+
+@hookimpl
+def register(manager):
+    manager.register("NonValueTransformer", NonValueTransformer)
diff --git a/sdgx/synthesizer.py b/sdgx/synthesizer.py
index d17b34f5..9b2daab8 100644
--- a/sdgx/synthesizer.py
+++ b/sdgx/synthesizer.py
@@ -99,9 +99,9 @@ def __init__(
         self.dataloader = None
 
         # Init data processors
-        if not data_processors:
-            data_processors = []
         self.data_processors_manager = DataProcessorManager()
+        if not data_processors:
+            data_processors = self.data_processors_manager.registed_default_processor_list
         self.data_processors = [
             (
                 d
@@ -288,8 +288,12 @@ def fit(
             self.metadata = metadata  # Ensure update metadata
 
         logger.info("Fitting data processors...")
+        start_time = time.time()
         for d in self.data_processors:
             d.fit(metadata)
+        logger.info(
+            f"Fitted {len(self.data_processors)} data processors in {time.time() - start_time}s."
+ ) def chunk_generator() -> Generator[pd.DataFrame, None, None]: for chunk in self.dataloader.iter(): diff --git a/tests/data_models/inspector/test_personal.py b/tests/data_models/inspector/test_personal.py index c7d27490..9f417ac9 100644 --- a/tests/data_models/inspector/test_personal.py +++ b/tests/data_models/inspector/test_personal.py @@ -6,7 +6,10 @@ import pytest from faker import Faker +fake = Faker(locale="zh_CN") + from sdgx.data_models.inspectors.personal import ( + ChinaMainlandAddressInspector, ChinaMainlandIDInspector, ChinaMainlandMobilePhoneInspector, ChinaMainlandPostCode, @@ -14,8 +17,6 @@ EmailInspector, ) -fake = Faker(locale="zh_CN") - def generate_uniform_credit_code(): # generate china mainland 统一社会信用代码 for test @@ -194,5 +195,26 @@ def test_chn_uscc_inspector_generated_data(chn_personal_test_df: pd.DataFrame): assert inspector_USCC.pii is True +# CHN address +def test_chn_address_inspector_demo_data(raw_data): + inspector_CHN_Address = ChinaMainlandAddressInspector() + inspector_CHN_Address.fit(raw_data) + assert not inspector_CHN_Address.regex_columns + assert sorted(inspector_CHN_Address.inspect()["china_mainland_address_columns"]) == sorted([]) + assert inspector_CHN_Address.inspect_level == 30 + assert inspector_CHN_Address.pii is True + + +def test_chn_address_inspector_generated_data(chn_personal_test_df: pd.DataFrame): + inspector_CHN_Address = ChinaMainlandAddressInspector() + inspector_CHN_Address.fit(chn_personal_test_df) + # assert inspector_CHN_Address.regex_columns + assert sorted(inspector_CHN_Address.inspect()["china_mainland_address_columns"]) == sorted( + ["chn_address"] + ) + assert inspector_CHN_Address.inspect_level == 30 + assert inspector_CHN_Address.pii is True + + if __name__ == "__main__": pytest.main(["-vv", "-s", __file__])
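
A minimal usage sketch for the new ChinaMainlandAddressInspector, mirroring the added tests; the column name and the two sample addresses below are illustrative only, not taken from the test fixtures:

    import pandas as pd

    from sdgx.data_models.inspectors.personal import ChinaMainlandAddressInspector

    # two synthetic mainland-China addresses: 8-30 characters, matching the
    # address prefix pattern, and not ending in "公司"
    df = pd.DataFrame(
        {"chn_address": ["浙江省杭州市西湖区文三路100号", "广东省深圳市南山区科技南路15号"]}
    )

    inspector = ChinaMainlandAddressInspector()
    inspector.fit(df)
    # expected to report the column, e.g. {'china_mainland_address_columns': ['chn_address']}
    print(inspector.inspect())
    print(inspector.inspect_level, inspector.pii)  # 30 True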
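
With the synthesizer change above, constructing a Synthesizer without an explicit data_processors list now falls back to the manager's preset defaults (NonValueTransformer and IntValueFormatter) instead of an empty list. A small sketch of that lookup, assuming the plugin manager normalizes registered names to lowercase and loads the built-in processors on construction:

    from sdgx.data_processors.manager import DataProcessorManager

    manager = DataProcessorManager()

    # intersection of the registered processors and the lowercase preset list;
    # expected to contain 'nonvaluetransformer' and 'intvalueformatter'
    print(manager.registed_default_processor_list)

    # the same manager can also instantiate the default processors directly
    default_processors = manager.init_default_processors()
    print([type(p).__name__ for p in default_processors])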