Skip to content

Commit

Permalink
Fix DataFrame.apply to support additional dtypes. (#2125)
Browse files Browse the repository at this point in the history
Fix `DataFrame.apply` to support additional dtypes.

After this, additional dtypes can be specified in the return type annotation of the UDFs for `DataFrame.apply`.

```py
>>> kdf = ks.DataFrame(
...     {"a": ["a", "b", "c", "a", "b", "c"], "b": ["b", "a", "c", "c", "b", "a"]}
... )
>>> dtype = pd.CategoricalDtype(categories=["a", "b", "c"])
>>> def categorize(ser) -> ks.Series[dtype]:
...     return ser.astype(dtype)
...
>>> applied = kdf.apply(categorize)
>>> applied
   a  b
0  a  b
1  b  a
2  c  c
3  a  c
4  b  b
5  c  a
>>> applied.dtypes
a    category
b    category
dtype: object
```

For comparison, the output without this fix:

```py
>>> applied
   a  b
0  0  1
1  1  0
2  2  2
3  0  2
4  1  1
5  2  0
>>> applied.dtypes
a    int64
b    int64
dtype: object
```
  • Loading branch information
ueshin authored Mar 30, 2021
1 parent fdda825 commit fe9e594
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 2 deletions.
8 changes: 7 additions & 1 deletion databricks/koalas/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -2588,6 +2588,7 @@ def apply_func(pdf):
self_applied.columns, [return_schema] * len(self_applied.columns)
)
return_schema = StructType([StructField(c, t) for c, t in fields_types])
data_dtypes = [cast(SeriesType, return_type).dtype] * len(self_applied.columns)
elif require_column_axis:
if axis != 1:
raise TypeError(
Expand All @@ -2596,11 +2597,13 @@ def apply_func(pdf):
"was %s" % return_sig
)
return_schema = cast(DataFrameType, return_type).spark_type
data_dtypes = cast(DataFrameType, return_type).dtypes
else:
# any axis is fine.
should_return_series = True
return_schema = cast(ScalarType, return_type).spark_type
return_schema = StructType([StructField(SPARK_DEFAULT_SERIES_NAME, return_schema)])
data_dtypes = [cast(ScalarType, return_type).dtype]
column_labels = [None]

if should_use_map_in_pandas:
Expand All @@ -2621,7 +2624,10 @@ def apply_func(pdf):

# Otherwise, it loses index.
internal = InternalFrame(
spark_frame=sdf, index_spark_columns=None, column_labels=column_labels
spark_frame=sdf,
index_spark_columns=None,
column_labels=column_labels,
data_dtypes=data_dtypes,
)

result = DataFrame(internal) # type: "DataFrame"
Expand Down
29 changes: 28 additions & 1 deletion databricks/koalas/tests/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def kdf(self):

@property
def df_pair(self):
return (self.pdf, self.kdf)
return self.pdf, self.kdf

def test_categorical_frame(self):
pdf, kdf = self.df_pair
Expand Down Expand Up @@ -106,6 +106,33 @@ def test_factorize(self):
self.assert_eq(kcodes.tolist(), pcodes.tolist())
self.assert_eq(kuniques, puniques)

def test_frame_apply(self):
    """Identity ``DataFrame.apply`` on a categorical frame matches pandas on both axes."""
    pdf, kdf = self.df_pair

    identity = lambda x: x
    # axis=0 applies per column, axis=1 per row; results must agree either way.
    for axis in (0, 1):
        self.assert_eq(
            kdf.apply(identity, axis=axis).sort_index(),
            pdf.apply(identity, axis=axis).sort_index(),
        )

def test_frame_apply_without_shortcut(self):
    """Re-run the apply test with the inference shortcut disabled, then verify
    that a ``CategoricalDtype`` return-type annotation is honored by
    ``DataFrame.apply``."""
    # Setting the shortcut limit to 0 forces the non-sampled code path.
    with ks.option_context("compute.shortcut_limit", 0):
        self.test_frame_apply()

    data = {"a": ["a", "b", "c", "a", "b", "c"], "b": ["b", "a", "c", "c", "b", "a"]}
    pdf = pd.DataFrame(data)
    kdf = ks.from_pandas(pdf)

    dtype = CategoricalDtype(categories=["a", "b", "c"])

    def categorize(ser) -> ks.Series[dtype]:
        return ser.astype(dtype)

    def normalized(frame):
        # Sort and reset the index so Koalas and pandas results are comparable.
        return frame.apply(categorize).sort_values(["a", "b"]).reset_index(drop=True)

    self.assert_eq(normalized(kdf), normalized(pdf))

def test_groupby_apply(self):
pdf, kdf = self.df_pair

Expand Down

0 comments on commit fe9e594

Please sign in to comment.