Skip to content

Commit

Permalink
One-hot-encode categorical column (#22)
Browse files Browse the repository at this point in the history
One-hot-encode categorical column
  • Loading branch information
michcio1234 authored Nov 9, 2017
2 parents 0e0ec76 + 5c3ceb3 commit 05867f6
Show file tree
Hide file tree
Showing 4 changed files with 169 additions and 9 deletions.
25 changes: 21 additions & 4 deletions sparsity/dask/reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@


def one_hot_encode(ddf, column=None, categories=None, index_col=None,
order=None, prefixes=False):
order=None, prefixes=False,
ignore_cat_order_mismatch=False):
"""
Sparse one hot encoding of dask.DataFrame.
Expand All @@ -21,8 +22,11 @@ def one_hot_encode(ddf, column=None, categories=None, index_col=None,
ddf: dask.DataFrame
e.g. the clickstream
categories: dict
Maps column name -> iterable of possible category values.
See description of `order`.
Maps ``column name`` -> ``iterable of possible category values``.
Can be also ``column name`` -> ``None`` if this column is already
of categorical dtype.
This argument decides which column(s) will be encoded.
See description of `order` and `ignore_cat_order_mismatch`.
index_col: str | iterable
which columns to use as index
order: iterable
Expand All @@ -46,6 +50,16 @@ def one_hot_encode(ddf, column=None, categories=None, index_col=None,
[col1_cat11, col1_cat12, col2_cat21, col2_cat22, ...].
column: DEPRECATED
Kept only for backward compatibility.
ignore_cat_order_mismatch: bool
If a column being one-hot encoded is of categorical dtype, it has
its categories already predefined, so we don't need to explicitly pass
them in `categories` argument (see this argument's description).
However, if we pass them, they may be different than ones defined in
column.cat.categories. In such a situation, a ValueError will be
raised. However, if only orders of categories are different (but sets
of elements are same), you may specify ignore_cat_order_mismatch=True
to suppress this error. In such a situation, column's predefined
categories will be used.
Returns
-------
Expand All @@ -71,14 +85,17 @@ def one_hot_encode(ddf, column=None, categories=None, index_col=None,
columns = sparse_one_hot(ddf._meta,
categories=categories,
index_col=index_col,
prefixes=prefixes).columns
prefixes=prefixes,
ignore_cat_order_mismatch=ignore_cat_order_mismatch
).columns
meta = sp.SparseFrame(np.array([]), columns=columns,
index=idx_meta)

dsf = ddf.map_partitions(sparse_one_hot,
categories=categories,
index_col=index_col,
prefixes=prefixes,
ignore_cat_order_mismatch=ignore_cat_order_mismatch,
meta=object)

return SparseFrame(dsf.dask, dsf._name, meta, dsf.divisions)
41 changes: 37 additions & 4 deletions sparsity/sparse_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,7 +647,8 @@ def _create_group_matrix(group_idx, dtype='f8'):


def sparse_one_hot(df, column=None, categories=None, dtype='f8',
index_col=None, order=None, prefixes=False):
index_col=None, order=None, prefixes=False,
ignore_cat_order_mismatch=False):
"""
One-hot encode specified columns of a pandas.DataFrame.
Returns a SparseFrame.
Expand All @@ -673,7 +674,10 @@ def sparse_one_hot(df, column=None, categories=None, dtype='f8',
for column, column_cat in categories.items():
if isinstance(column_cat, str):
column_cat = _just_read_array(column_cat)
cols, csr = _one_hot_series_csr(column_cat, dtype, df[column])
cols, csr = _one_hot_series_csr(
column_cat, dtype, df[column],
ignore_cat_order_mismatch=ignore_cat_order_mismatch
)
if prefixes:
cols = list(map(lambda x: '{}_{}'.format(column, x), cols))
new_cols.extend(cols)
Expand All @@ -692,9 +696,13 @@ def sparse_one_hot(df, column=None, categories=None, dtype='f8',
return SparseFrame(new_data, index=new_index, columns=new_cols)


def _one_hot_series_csr(categories, dtype, oh_col):
def _one_hot_series_csr(categories, dtype, oh_col,
ignore_cat_order_mismatch=False):
if types.is_categorical_dtype(oh_col):
cat = oh_col
cat = oh_col.cat
_check_categories_order(cat.categories, categories, oh_col.name,
ignore_cat_order_mismatch)

else:
s = oh_col
cat = pd.Categorical(s, np.asarray(categories))
Expand All @@ -712,3 +720,28 @@ def _one_hot_series_csr(categories, dtype, oh_col):
shape=(n_samples, n_features),
dtype=dtype).tocsr()
return cat.categories.values, data


def _check_categories_order(categories1, categories2, categorical_column_name,
ignore_cat_order_mismatch):
"""Check if two lists of categories differ. If they have different
elements, raise an exception. If they differ only by order of elements,
raise an issue unless ignore_cat_order_mismatch is set."""

if categories2 is None or list(categories2) == list(categories1):
return

if set(categories2) == set(categories1):
mismatch_type = 'order'
else:
mismatch_type = 'set'

if mismatch_type == 'set' or not ignore_cat_order_mismatch:
raise ValueError(
"Got categorical column {column_name} whose categories "
"{mismatch_type} doesn't match categories {mismatch_type} "
"given as argument to this function.".format(
column_name=categorical_column_name,
mismatch_type=mismatch_type
)
)
16 changes: 16 additions & 0 deletions sparsity/test/test_dask_sparse_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,22 @@ def test_one_hot_no_order(clickstream):
assert sorted(sf.columns) == list('ABCDEFGHIJ')


def test_one_hot_no_order_categorical(clickstream):
clickstream['other_categorical'] = clickstream['other_categorical'] \
.astype('category')
ddf = dd.from_pandas(clickstream, npartitions=10)
dsf = one_hot_encode(ddf,
categories={'page_id': list('ABCDE'),
'other_categorical': list('FGHIJ')},
index_col=['index', 'id'])
assert dsf._meta.empty
assert sorted(dsf.columns) == list('ABCDEFGHIJ')
sf = dsf.compute()
assert sf.shape == (100, 10)
assert isinstance(sf.index, pd.MultiIndex)
assert sorted(sf.columns) == list('ABCDEFGHIJ')


def test_one_hot_prefixes(clickstream):
ddf = dd.from_pandas(clickstream, npartitions=10)
dsf = one_hot_encode(ddf,
Expand Down
96 changes: 95 additions & 1 deletion sparsity/test/test_sparse_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,18 @@ def mock_s3_fs(bucket, data=None):
# 2017 starts with a sunday
@pytest.fixture()
def sampledata():
def gendata(n):
def gendata(n, categorical=False):
sample_data = pd.DataFrame(
dict(date=pd.date_range("2017-01-01", periods=n)))
sample_data["weekday"] = sample_data.date.dt.weekday_name
sample_data["weekday_abbr"] = sample_data.weekday.apply(
lambda x: x[:3])

if categorical:
sample_data['weekday'] = sample_data['weekday'].astype('category')
sample_data['weekday_abbr'] = sample_data['weekday_abbr'] \
.astype('category')

sample_data["id"] = np.tile(np.arange(7), len(sample_data) // 7 + 1)[
:len(sample_data)]
return sample_data
Expand Down Expand Up @@ -468,6 +474,94 @@ def test_csr_one_hot_series(sampledata, weekdays, weekdays_abbr):
assert all(sparse_frame.columns == (weekdays + weekdays_abbr))


def test_csr_one_hot_series_categorical_same_order(sampledata, weekdays,
weekdays_abbr):
correct = np.hstack((np.identity(7) * 7,
np.identity(7) * 7))

data = sampledata(49, categorical=True)

categories = {'weekday': data['weekday'].cat.categories.tolist(),
'weekday_abbr': data['weekday_abbr'].cat.categories.tolist()}

sparse_frame = sparse_one_hot(data,
categories=categories,
order=['weekday', 'weekday_abbr'],
ignore_cat_order_mismatch=False)

res = sparse_frame.groupby_sum(np.tile(np.arange(7), 7)) \
.todense()[weekdays + weekdays_abbr].values
assert np.all(res == correct)
assert set(sparse_frame.columns) == set(weekdays + weekdays_abbr)


def test_csr_one_hot_series_categorical_different_order(sampledata, weekdays,
weekdays_abbr):
correct = np.hstack((np.identity(7) * 7,
np.identity(7) * 7))

data = sampledata(49, categorical=True)

categories = {
'weekday': data['weekday'].cat.categories.tolist()[::-1],
'weekday_abbr': data['weekday_abbr'].cat.categories.tolist()[::-1]
}

with pytest.raises(ValueError):
sparse_frame = sparse_one_hot(data,
categories=categories,
order=['weekday', 'weekday_abbr'],
ignore_cat_order_mismatch=False)


def test_csr_one_hot_series_categorical_different_order_ignore(
sampledata, weekdays, weekdays_abbr):

correct = np.hstack((np.identity(7) * 7,
np.identity(7) * 7))

data = sampledata(49, categorical=True)

categories = {
'weekday': data['weekday'].cat.categories.tolist()[::-1],
'weekday_abbr': data['weekday_abbr'].cat.categories.tolist()[::-1]
}

sparse_frame = sparse_one_hot(data,
categories=categories,
order=['weekday', 'weekday_abbr'],
ignore_cat_order_mismatch=True)

res = sparse_frame.groupby_sum(np.tile(np.arange(7), 7)) \
.todense()[weekdays + weekdays_abbr].values
assert np.all(res == correct)
assert set(sparse_frame.columns) == set(weekdays + weekdays_abbr)


def test_csr_one_hot_series_categorical_no_categories(
sampledata, weekdays, weekdays_abbr):

correct = np.hstack((np.identity(7) * 7,
np.identity(7) * 7))

data = sampledata(49, categorical=True)

categories = {
'weekday': None,
'weekday_abbr': None
}

sparse_frame = sparse_one_hot(data,
categories=categories,
order=['weekday', 'weekday_abbr'],
ignore_cat_order_mismatch=True)

res = sparse_frame.groupby_sum(np.tile(np.arange(7), 7)) \
.todense()[weekdays + weekdays_abbr].values
assert np.all(res == correct)
assert set(sparse_frame.columns) == set(weekdays + weekdays_abbr)


def test_csr_one_hot_series_other_order(sampledata, weekdays, weekdays_abbr):

categories = {'weekday': weekdays,
Expand Down

0 comments on commit 05867f6

Please sign in to comment.