Skip to content

Commit 9582ccc

Browse files
authored
feat(warnings): convert static strings into enums (#56)
* created category enum with all the categories extracted
* created test enum with all the test strings extracted
* convert priority to enum value
* created StringEnum with `_missing_` validation
1 parent 9e1de00 commit 9582ccc

File tree

11 files changed

+182
-52
lines changed

11 files changed

+182
-52
lines changed

src/ydata_quality/bias_fairness/engine.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from pandas import DataFrame, Series
88
from dython.nominal import compute_associations
99

10+
from ..core.warnings import Priority
11+
1012
from ..core import QualityEngine, QualityWarning
1113
from ..utils.correlations import filter_associations
1214
from ..utils.modelling import (baseline_performance,
@@ -58,7 +60,8 @@ def proxy_identification(self, th=0.5):
5860
if len(corrs) > 0:
5961
self.store_warning(
6062
QualityWarning(
61-
test='Proxy Identification', category='Bias&Fairness', priority=2, data=corrs,
63+
test=QualityWarning.Test.PROXY_IDENTIFICATION,
64+
category=QualityWarning.Category.BIAS_FAIRNESS, priority=Priority.P2, data=corrs,
6265
description=f"Found {len(corrs)} feature pairs of correlation "
6366
f"to sensitive attributes with values higher than defined threshold ({th})."
6467
))
@@ -80,7 +83,9 @@ def sensitive_predictability(self, th=0.5, adjusted_metric=True):
8083
if len(high_perfs) > 0:
8184
self.store_warning(
8285
QualityWarning(
83-
test='Sensitive Attribute Predictability', category='Bias&Fairness', priority=3, data=high_perfs,
86+
test=QualityWarning.Test.SENSITIVE_ATTRIBUTE_PREDICTABILITY,
87+
category=QualityWarning.Category.BIAS_FAIRNESS,
88+
priority=Priority.P3, data=high_perfs,
8489
description=f"Found {len(high_perfs)} sensitive attribute(s) with high predictability performance"
8590
f" (greater than {th})."
8691
)
@@ -124,8 +129,9 @@ def sensitive_representativity(self, min_pct: float = 0.01):
124129
if len(low_dist) > 0:
125130
self.store_warning(
126131
QualityWarning(
127-
test='Sensitive Attribute Representativity', category='Bias&Fairness', priority=2,
128-
data=low_dist, description=f"Found {len(low_dist)} values of '{cat}' \
132+
test=QualityWarning.Test.SENSITIVE_ATTRIBUTE_REPRESENTATIVITY,
133+
category=QualityWarning.Category.BIAS_FAIRNESS, priority=Priority.P2, data=low_dist,
134+
description=f"Found {len(low_dist)} values of '{cat}' \
129135
sensitive attribute with low representativity in the dataset (below {min_pct*100:.2f}%)."
130136
)
131137
)

src/ydata_quality/core/data_quality.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,10 +135,14 @@ def _clean_warnings(self):
135135
self._warnings = sorted(list(set(self._warnings))) # Sort unique warnings by priority
136136

137137
def get_warnings(self,
138-
category: Optional[str] = None,
139-
test: Optional[str] = None,
138+
category: Optional[Union[QualityWarning.Category, str]] = None,
139+
test: Optional[Union[QualityWarning.Test, str]] = None,
140140
priority: Optional[Priority] = None) -> List[QualityWarning]:
141141
"Retrieves warnings filtered by their properties."
142+
143+
category = QualityWarning.Category(category) if category is not None else None
144+
test = QualityWarning.Test(test) if test is not None else None
145+
142146
self._store_warnings()
143147
self._clean_warnings()
144148
filtered = [w for w in self._warnings if w.category == category] if category else self._warnings

src/ydata_quality/core/warnings.py

Lines changed: 58 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from pydantic import BaseModel
88

9-
from ..utils.enum import OrderedEnum
9+
from ..utils.enum import OrderedEnum, StringEnum
1010

1111

1212
# pylint: disable=too-few-public-methods
@@ -67,8 +67,62 @@ class QualityWarning(BaseModel):
6767
data: sample data
6868
"""
6969

70-
category: str
71-
test: str
70+
class Category(StringEnum):
71+
BIAS_FAIRNESS = "BIAS&FAIRNESS"
72+
DATA_EXPECTATIONS = "DATA EXPECTATIONS"
73+
DATA_RELATIONS = "DATA RELATIONS"
74+
DUPLICATES = "DUPLICATES"
75+
ERRONEOUS_DATA = "ERRONEOUS DATA"
76+
LABELS = "LABELS"
77+
MISSINGS = "MISSINGS"
78+
SAMPLING = "SAMPLING"
79+
80+
class Test(StringEnum):
81+
# BIAS&FAIRNESS
82+
PROXY_IDENTIFICATION = "PROXY IDENTIFICATION"
83+
SENSITIVE_ATTRIBUTE_PREDICTABILITY = "SENSITIVE ATTRIBUTE PREDICTABILITY"
84+
SENSITIVE_ATTRIBUTE_REPRESENTATIVITY = "SENSITIVE ATTRIBUTE REPRESENTATIVITY"
85+
86+
# DATA EXPECTATIONS
87+
COVERAGE_FRACTION = "COVERAGE FRACTION"
88+
EXPECTATION_ASSESSMENT_VALUE_BETWEEN = "EXPECTATION ASSESSMENT - VALUE BETWEEN"
89+
OVERALL_ASSESSMENT = "OVERALL ASSESSMENT"
90+
91+
# DATA RELATIONS
92+
COLLIDER_CORRELATIONS = "COLLIDER CORRELATIONS"
93+
CONFOUNDED_CORRELATIONS = "CONFOUNDED CORRELATIONS"
94+
HIGH_COLLINEARITY_CATEGORICAL = "HIGH COLLINEARITY - CATEGORICAL"
95+
HIGH_COLLINEARITY_NUMERICAL = "HIGH COLLINEARITY - NUMERICAL"
96+
97+
# DUPLICATES
98+
DUPLICATE_COLUMNS = "DUPLICATE COLUMNS"
99+
ENTITY_DUPLICATES = "ENTITY DUPLICATES"
100+
EXACT_DUPLICATES = "EXACT DUPLICATES"
101+
102+
# ERRONEOUS DATA
103+
FLATLINES = "FLATLINES"
104+
PREDEFINED_ERRONEOUS_DATA = "PREDEFINED ERRONEOUS DATA"
105+
106+
# LABELS
107+
FEW_LABELS = "FEW LABELS"
108+
MISSING_LABELS = "MISSING LABELS"
109+
ONE_REST_PERFORMANCE = "ONE VS REST PERFORMANCE"
110+
OUTLIER_DETECTION = "OUTLIER DETECTION"
111+
TEST_NORMALITY = "TEST NORMALITY"
112+
UNBALANCED_CLASSES = "UNBALANCED CLASSES"
113+
114+
# MISSINGS
115+
HIGH_MISSINGS = "HIGH MISSINGS"
116+
HIGH_MISSING_CORRELATIONS = "HIGH MISSING CORRELATIONS"
117+
MISSINGNESS_PREDICTION = "MISSINGNESS PREDICTION"
118+
119+
# SAMPLING
120+
CONCEPT_DRIFT = "CONCEPT DRIFT"
121+
SAMPLE_COVARIATE_DRIFT = "SAMPLE COVARIATE DRIFT"
122+
SAMPLE_LABEL_DRIFT = "SAMPLE LABEL DRIFT"
123+
124+
category: Category
125+
test: Test
72126
description: str
73127
priority: Priority
74128
data: Any = None
@@ -78,7 +132,7 @@ class QualityWarning(BaseModel):
78132
#########################
79133
def __str__(self):
80134
return f"{WarningStyling.PRIORITIES[self.priority.value]}*{WarningStyling.ENDC} {WarningStyling.BOLD}\
81-
[{self.category.upper()}{WarningStyling.ENDC} - {WarningStyling.UNDERLINE}{self.test.upper()}]{WarningStyling.ENDC} \
135+
[{self.category.value}{WarningStyling.ENDC} - {WarningStyling.UNDERLINE}{self.test.value}]{WarningStyling.ENDC} \
82136
{self.description}"
83137

84138
########################

src/ydata_quality/data_expectations/engine.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from pandas import DataFrame
77
from numpy import argmin
88

9+
from ..core.warnings import Priority
10+
911
from ..core import QualityEngine, QualityWarning
1012
from ..utils.auxiliary import test_load_json_path
1113
from ..utils.logger import NAME, get_logger
@@ -55,7 +57,8 @@ def __between_value_error(self, expectation_summary: dict) -> tuple:
5557
bound of the expected range."
5658
self.store_warning(
5759
QualityWarning(
58-
test='Expectation assessment - Value Between', category='Data Expectations', priority=3,
60+
test=QualityWarning.Test.EXPECTATION_ASSESSMENT_VALUE_BETWEEN,
61+
category=QualityWarning.Category.DATA_EXPECTATIONS, priority=Priority.P3,
5962
data=(range_deviations, bound_deviations),
6063
description=f"Column {column_name} - The observed value is outside of the expected range."
6164
+ (range_deviation_string if range_deviations else "")
@@ -122,7 +125,8 @@ def _coverage_fraction(self, results_json_path: str, df: DataFrame, minimum_cove
122125
if coverage_fraction < minimum_coverage:
123126
self.store_warning(
124127
QualityWarning(
125-
test='Coverage Fraction', category='Data Expectations', priority=2,
128+
test=QualityWarning.Test.COVERAGE_FRACTION,
129+
category=QualityWarning.Category.DATA_EXPECTATIONS, priority=Priority.P2,
126130
data={'Columns not covered': df_column_set.difference(column_coverage)},
127131
description=f"The provided DataFrame has a total expectation coverage of {coverage_fraction:.0%} \
128132
of its columns, which is below the expected coverage of {minimum_coverage:.0%}."
@@ -147,7 +151,8 @@ def _overall_assessment(self, results_json_path: str, error_tol: int = 0,
147151
if results_summary['OVERALL']['expectation_count'] - results_summary['OVERALL']['total_successes'] > error_tol:
148152
self.store_warning(
149153
QualityWarning(
150-
test='Overall Assessment', category='Data Expectations', priority=2,
154+
test=QualityWarning.Test.OVERALL_ASSESSMENT,
155+
category=QualityWarning.Category.DATA_EXPECTATIONS, priority=Priority.P2,
151156
data={'Failed expectation indexes': failed_expectation_ids},
152157
description=f"{len(failed_expectation_ids)} expectations have failed, which is more than the \
153158
implied absolute threshold of {int(error_tol)} failed expectations."

src/ydata_quality/data_relations/engine.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from pandas import DataFrame
77
from numpy import ones, tril, argwhere
88

9+
from ..core.warnings import Priority
10+
911
from ..core import QualityEngine, QualityWarning
1012
from ..utils.auxiliary import infer_dtypes, standard_normalize
1113
from ..utils.correlations import (chi2_collinearity, correlation_matrix,
@@ -117,9 +119,11 @@ def _confounder_detection(self, corr_mat: DataFrame, par_corr_mat: DataFrame,
117119
mask[par_corr_mat.abs() > corr_th] = False # Drop pairs with correlation after controling all other covariates
118120
confounded_pairs = [(corr_mat.index[i], corr_mat.columns[j]) for i, j in argwhere(mask)]
119121
if len(confounded_pairs) > 0:
120-
self.store_warning(QualityWarning(
121-
test='Confounded correlations', category='Data Relations', priority=2, data=confounded_pairs,
122-
description=f"""
122+
self.store_warning(
123+
QualityWarning(
124+
test=QualityWarning.Test.CONFOUNDED_CORRELATIONS, category=QualityWarning.Category.DATA_RELATIONS,
125+
priority=Priority.P2, data=confounded_pairs,
126+
description=f"""
123127
Found {len(confounded_pairs)} independently correlated variable pairs that disappeared after controling\
124128
for the remaining variables. This is an indicator of potential confounder effects in the dataset."""))
125129
return confounded_pairs
@@ -138,9 +142,11 @@ def _collider_detection(self, corr_mat: DataFrame, par_corr_mat: DataFrame,
138142
mask[par_corr_mat.abs() <= corr_th] = False # Drop pairs with correlation after controling all other covariates
139143
colliding_pairs = [(corr_mat.index[i], corr_mat.columns[j]) for i, j in argwhere(mask)]
140144
if len(colliding_pairs) > 0:
141-
self.store_warning(QualityWarning(
142-
test='Collider correlations', category='Data Relations', priority=2, data=colliding_pairs,
143-
description=f"Found {len(colliding_pairs)} independently uncorrelated variable pairs that showed \
145+
self.store_warning(
146+
QualityWarning(
147+
test=QualityWarning.Test.COLLIDER_CORRELATIONS, category=QualityWarning.Category.DATA_RELATIONS,
148+
priority=Priority.P2, data=colliding_pairs,
149+
description=f"Found {len(colliding_pairs)} independently uncorrelated variable pairs that showed \
144150
correlation after controling for the remaining variables. \
145151
This is an indicator of potential colliding bias with other covariates."))
146152
return colliding_pairs
@@ -192,18 +198,22 @@ def _high_collinearity_detection(self, df: DataFrame, dtypes: dict, label: str =
192198
['Adjusted Chi2'].mean()) for c in unique_cats]
193199
cat_coll_scores = [c[0] for c in sorted(cat_coll_scores, key=lambda x: x[1], reverse=True)]
194200
if len(inflated) > 0:
195-
self.store_warning(QualityWarning(
196-
test='High Collinearity - Numerical', category='Data Relations', priority=2, data=inflated,
197-
description=f"""Found {len(inflated)} numerical variables with high Variance Inflation Factor \
201+
self.store_warning(
202+
QualityWarning(
203+
test=QualityWarning.Test.HIGH_COLLINEARITY_NUMERICAL,
204+
category=QualityWarning.Category.DATA_RELATIONS, priority=Priority.P2, data=inflated,
205+
description=f"""Found {len(inflated)} numerical variables with high Variance Inflation Factor \
198206
(VIF>{vif_th:.1f}). The variables listed in results are highly collinear with other variables in the dataset. \
199207
These will make model explainability harder and potentially give way to issues like overfitting.\
200208
Depending on your end goal you might want to remove the highest VIF variables."""))
201209
if len(cat_coll_scores) > 0:
202210
# TODO: Merge warning messages (make one warning for the whole test,
203211
# summarizing findings from the numerical and categorical vars)
204-
self.store_warning(QualityWarning(
205-
test='High Collinearity - Categorical', category='Data Relations', priority=2, data=chi2_tests,
206-
description=f"""Found {len(cat_coll_scores)} categorical variables with significant collinearity \
212+
self.store_warning(
213+
QualityWarning(
214+
test=QualityWarning.Test.HIGH_COLLINEARITY_CATEGORICAL,
215+
category=QualityWarning.Category.DATA_RELATIONS, priority=Priority.P2, data=chi2_tests,
216+
description=f"""Found {len(cat_coll_scores)} categorical variables with significant collinearity \
207217
(p-value < {p_th}). The variables listed in results are highly collinear with other variables \
208218
in the dataset and sorted descending according to propensity. These will make model explainability \
209219
harder and potentially give way to issues like overfitting.Depending on your end goal you might want \

src/ydata_quality/drift/engine.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -288,14 +288,16 @@ def sample_covariate_drift(self, p_thresh: float = 0.05) -> DataFrame:
288288
if n_drifted_feats > 0:
289289
self.store_warning(
290290
QualityWarning(
291-
test='Sample covariate drift', category='Sampling', priority=2, data=test_summary,
291+
test=QualityWarning.Test.SAMPLE_COVARIATE_DRIFT, category=QualityWarning.Category.SAMPLING,
292+
priority=2, data=test_summary,
292293
description=f"""{n_drifted_feats} features accused drift in the sample test. The covariates \
293294
of the test sample do not appear to be representative of the reference sample."""
294295
))
295296
elif n_invalid_tests > 0:
296297
self.store_warning(
297298
QualityWarning(
298-
test='Sample covariate drift', category='Sampling', priority=3, data=test_summary,
299+
test=QualityWarning.Test.SAMPLE_COVARIATE_DRIFT, category=QualityWarning.Category.SAMPLING,
300+
priority=3, data=test_summary,
299301
description=f"""There were {n_invalid_tests} invalid tests found. This is likely due to a small \
300302
test sample size. The data summary should be analyzed before considering the test conclusive."""
301303
))
@@ -323,14 +325,16 @@ def sample_label_drift(self, p_thresh: float = 0.05) -> Series:
323325
if test_summary['Verdict'] == 'Drift':
324326
self.store_warning(
325327
QualityWarning(
326-
test='Sample label drift', category='Sampling', priority=2, data=test_summary,
328+
test=QualityWarning.Test.SAMPLE_LABEL_DRIFT, category=QualityWarning.Category.SAMPLING,
329+
priority=2, data=test_summary,
327330
description=f"The label accused drift in the sample test with a p-test of {p_val:.4f}, which is \
328331
under the threshold {p_thresh:.2f}. The test sample labels do not appear to be representative of the reference sample."
329332
))
330333
elif test_summary['Verdict'] == 'Invalid test':
331334
self.store_warning(
332335
QualityWarning(
333-
test='Sample label drift', category='Sampling', priority=3, data=test_summary,
336+
test=QualityWarning.Test.SAMPLE_LABEL_DRIFT, category=QualityWarning.Category.SAMPLING,
337+
priority=3, data=test_summary,
334338
description="The test was invalid. This is likely due to a small test sample size."
335339
))
336340
else:
@@ -363,15 +367,17 @@ def sample_concept_drift(self, p_thresh: float = 0.05) -> Series:
363367
if test_summary['Verdict'] == 'Drift':
364368
self.store_warning(
365369
QualityWarning(
366-
test='Concept drift', category='Sampling', priority=2, data=test_summary,
370+
test=QualityWarning.Test.CONCEPT_DRIFT, category=QualityWarning.Category.SAMPLING,
371+
priority=2, data=test_summary,
367372
description=f"There was concept drift detected with a p-test of {p_val:.4f}, which is under the \
368373
threshold {p_thresh:.2f}. The model's predicted labels for the test sample do not appear to be representative of the \
369374
distribution of labels predicted for the reference sample."
370375
))
371376
elif test_summary['Verdict'] == 'Invalid test':
372377
self.store_warning(
373378
QualityWarning(
374-
test='Concept drift', category='Sampling', priority=3, data=test_summary,
379+
test=QualityWarning.Test.CONCEPT_DRIFT, category=QualityWarning.Category.SAMPLING,
380+
priority=3, data=test_summary,
375381
description="The test was invalid. This is likely due to a small test sample size."
376382
))
377383
else:

src/ydata_quality/duplicates/engine.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
from pandas import DataFrame
88

9+
from ..core.warnings import Priority
10+
911
from ..core import QualityEngine, QualityWarning
1012
from ..utils.auxiliary import find_duplicate_columns
1113

@@ -69,7 +71,8 @@ def exact_duplicates(self):
6971
if len(dups) > 0:
7072
self.store_warning(
7173
QualityWarning(
72-
test='Exact Duplicates', category='Duplicates', priority=2, data=dups,
74+
test=QualityWarning.Test.EXACT_DUPLICATES, category=QualityWarning.Category.DUPLICATES,
75+
priority=Priority.P2, data=dups,
7376
description=f"Found {len(dups)} instances with exact duplicate feature values."
7477
))
7578
else:
@@ -84,7 +87,7 @@ def __provided_entity_dups(self, entity: Optional[Union[str, List[str]]] = None)
8487
if len(dups) > 0: # if we have any duplicates
8588
self.store_warning(
8689
QualityWarning(
87-
test='Entity Duplicates', category='Duplicates', priority=2, data=dups,
90+
test=QualityWarning.Test.ENTITY_DUPLICATES, category=QualityWarning.Category.DUPLICATES,
priority=Priority.P2, data=dups,
8891
description=f"Found {len(dups)} duplicates after grouping by entities."
8992
))
9093
if isinstance(entity, str):
@@ -124,7 +127,8 @@ def duplicate_columns(self):
124127
if cols_with_dups > 0:
125128
self.store_warning(
126129
QualityWarning(
127-
test='Duplicate Columns', category='Duplicates', priority=1, data=dups,
130+
test=QualityWarning.Test.DUPLICATE_COLUMNS, category=QualityWarning.Category.DUPLICATES,
131+
priority=Priority.P1, data=dups,
128132
description=f"Found {cols_with_dups} columns with exactly the same feature values as other columns."
129133
)
130134
)

0 commit comments

Comments
 (0)