Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create Config class and fix printing. #456

Merged
merged 6 commits into from
Mar 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 20 additions & 8 deletions rdt/hyper_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,18 @@
get_default_transformer, get_transformer_instance, get_transformers_by_type)


class Config(dict):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now that we have this class, I'm wondering if it makes sense to move field_transformers, _provided_field_transformers, field_sdtypes and _provided_field_sdtypes into it. Then we could make the config an attribute of the HyperTransformer and add methods to this class for updating the config. For example, if the user sets something, it would update both dictionaries appropriately: private update_field_transformers(self, user_provided=False) and update_field_sdtypes(self, user_provided=False) methods that update either only the main dictionaries or both, as necessary.

It's up to you whether you want to do that extra work. I think the benefit would be that it would be less likely for anyone to forget to update or clear a provided dictionary when updating the main one, but this isn't in the scope of the issue.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I opened an issue to address this in another PR: #457

"""Config dict for ``HyperTransformer`` with a better representation."""

def __repr__(self):
    """Return the config as an indented JSON string.

    The ``sdtypes`` mapping is emitted as stored. Each value of the
    ``transformers`` mapping is replaced by its ``repr()`` string before
    dumping, so the transformers section contains only strings.

    A missing ``sdtypes`` or ``transformers`` key is rendered as an empty
    mapping instead of raising ``KeyError``, so ``repr()`` also works on an
    empty or partially filled config.

    Returns:
        str:
            JSON text produced by ``json.dumps`` with ``indent=4``.
    """
    # Fall back to empty mappings so a partial config still prints.
    sdtypes = self.get('sdtypes', {})
    transformers = self.get('transformers', {})

    printable = {
        'sdtypes': sdtypes,
        'transformers': {
            name: repr(transformer) for name, transformer in transformers.items()
        },
    }
    return json.dumps(printable, indent=4)


class HyperTransformer:
"""HyperTransformer class.

Expand Down Expand Up @@ -53,8 +65,8 @@ class HyperTransformer:
# pylint: disable=too-many-instance-attributes

_DTYPES_TO_SDTYPES = {
'i': 'integer',
'f': 'float',
'i': 'numerical',
'f': 'numerical',
'O': 'categorical',
'b': 'boolean',
'M': 'datetime',
Expand Down Expand Up @@ -176,10 +188,10 @@ def get_config(self):
- sdtypes: A dictionary mapping column names to their ``sdtypes``.
- transformers: A dictionary mapping column names to their transformer instances.
"""
return {
return Config({
'sdtypes': self.field_sdtypes,
'transformers': self.field_transformers
}
})

def set_config(self, config):
"""Set the ``HyperTransformer`` configuration.
Expand Down Expand Up @@ -437,13 +449,13 @@ def detect_initial_config(self, data):
self._user_message('Detecting a new config from the data ... SUCCESS')
self._user_message('Setting the new config ... SUCCESS')

config = {
config = Config({
'sdtypes': self.field_sdtypes,
'transformers': {k: repr(v) for k, v in self.field_transformers.items()}
}
'transformers': self.field_transformers
})

self._user_message('Config:')
self._user_message(json.dumps(config, indent=4))
self._user_message(config)

def _get_next_transformer(self, output_field, output_sdtype, next_transformers):
next_transformer = None
Expand Down
4 changes: 1 addition & 3 deletions rdt/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,6 @@ def get_transformer_name(transformer):

DEFAULT_TRANSFORMERS = {
'numerical': FloatFormatter(missing_value_replacement='mean'),
'integer': FloatFormatter(missing_value_replacement='mean'),
'float': FloatFormatter(missing_value_replacement='mean'),
'categorical': FrequencyEncoder(),
'boolean': BinaryEncoder(missing_value_replacement='mode'),
'datetime': UnixTimestampEncoder(missing_value_replacement='mean'),
Expand Down Expand Up @@ -162,7 +160,7 @@ def get_default_transformers():
defaults = deepcopy(DEFAULT_TRANSFORMERS)
for (sdtype, transformers) in transformers_by_type.items():
if sdtype not in defaults:
defaults[sdtype] = transformers[0]
defaults[sdtype] = transformers[0]()

return defaults

Expand Down
2 changes: 1 addition & 1 deletion tests/integration/test_hyper_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ def test_multiple_fits_with_set_config():
# Run
ht.detect_initial_config(data)
ht.set_config(config={
'sdtypes': {'integer': 'float'},
'sdtypes': {'integer': 'numerical'},
'transformers': {'bool': FrequencyEncoder}
})
ht.fit(data)
Expand Down
8 changes: 4 additions & 4 deletions tests/unit/test_hyper_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,11 +442,11 @@ def test_detect_initial_config(self):
assert ht._provided_field_sdtypes == {}
assert ht._provided_field_transformers == {}
assert ht.field_sdtypes == {
'col1': 'float',
'col1': 'numerical',
'col2': 'categorical',
'col3': 'boolean',
'col4': 'datetime',
'col5': 'integer'
'col5': 'numerical'
}

field_transformers = {k: repr(v) for (k, v) in ht.field_transformers.items()}
Expand All @@ -464,11 +464,11 @@ def test_detect_initial_config(self):
'Config:',
'{',
' "sdtypes": {',
' "col1": "float",',
' "col1": "numerical",',
' "col2": "categorical",',
' "col3": "boolean",',
' "col4": "datetime",',
' "col5": "integer"',
' "col5": "numerical"',
' },',
' "transformers": {',
' "col1": "FloatFormatter(missing_value_replacement=\'mean\')",',
Expand Down