Skip to content

Commit

Permalink
Implemented GroupBy.tail (#1949)
Browse files Browse the repository at this point in the history
This PR proposes `GroupBy.tail()` for `DataFrameGroupBy` and `SeriesGroupBy`.

```python
>>> df = ks.DataFrame({'a': [1, 1, 1, 1, 2, 2, 2, 3, 3, 3],
...                    'b': [2, 3, 1, 4, 6, 9, 8, 10, 7, 5],
...                    'c': [3, 5, 2, 5, 1, 2, 6, 4, 3, 6]},
...                   columns=['a', 'b', 'c'],
...                   index=[7, 2, 4, 1, 3, 4, 9, 10, 5, 6])
>>> df
    a   b  c
7   1   2  3
2   1   3  5
4   1   1  2
1   1   4  5
3   2   6  1
4   2   9  2
9   2   8  6
10  3  10  4
5   3   7  3
6   3   5  6

>>> df.groupby('a').tail(2).sort_index()
   a  b  c
1  1  4  5
4  1  1  2
4  2  9  2
5  3  7  3
6  3  5  6
9  2  8  6

>>> df.groupby('a')['b'].tail(2).sort_index()
1    4
4    1
4    9
5    7
6    5
9    8
Name: b, dtype: int64
```
  • Loading branch information
itholic authored Dec 10, 2020
1 parent 341fc42 commit ba02fa7
Show file tree
Hide file tree
Showing 4 changed files with 217 additions and 25 deletions.
117 changes: 94 additions & 23 deletions databricks/koalas/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -1792,6 +1792,46 @@ def ffill(self, limit=None) -> Union[DataFrame, Series]:

pad = ffill

def _limit(self, n: int, asc: bool):
"""
Private function for tail and head.
"""
kdf = self._kdf

if self._agg_columns_selected:
agg_columns = self._agg_columns
else:
agg_columns = [
kdf._kser_for(label)
for label in kdf._internal.column_labels
if label not in self._column_labels_to_exlcude
]

kdf, groupkey_labels, _ = GroupBy._prepare_group_map_apply(
kdf, self._groupkeys, agg_columns,
)

groupkey_scols = [kdf._internal.spark_column_for(label) for label in groupkey_labels]

sdf = kdf._internal.spark_frame
tmp_col = verify_temp_column_name(sdf, "__row_number__")

# This part is handled differently depending on whether it is a tail or a head.
window = (
Window.partitionBy(groupkey_scols).orderBy(F.col(NATURAL_ORDER_COLUMN_NAME).asc())
if asc
else Window.partitionBy(groupkey_scols).orderBy(F.col(NATURAL_ORDER_COLUMN_NAME).desc())
)

sdf = (
sdf.withColumn(tmp_col, F.row_number().over(window))
.filter(F.col(tmp_col) <= n)
.drop(tmp_col)
)

internal = kdf._internal.with_new_sdf(sdf)
return DataFrame(internal).drop(groupkey_labels, axis=1)

def head(self, n=5) -> Union[DataFrame, Series]:
"""
Return first n rows of each group.
Expand Down Expand Up @@ -1838,34 +1878,60 @@ def head(self, n=5) -> Union[DataFrame, Series]:
10 10
Name: b, dtype: int64
"""
kdf = self._kdf
return self._limit(n, asc=True)

if self._agg_columns_selected:
agg_columns = self._agg_columns
else:
agg_columns = [
kdf._kser_for(label)
for label in kdf._internal.column_labels
if label not in self._column_labels_to_exlcude
]
def tail(self, n=5) -> Union[DataFrame, Series]:
"""
Return last n rows of each group.
kdf, groupkey_labels, _ = GroupBy._prepare_group_map_apply(
kdf, self._groupkeys, agg_columns,
)
Similar to `.apply(lambda x: x.tail(n))`, but it returns a subset of rows from
the original DataFrame with original index and order preserved (`as_index` flag is ignored).
groupkey_scols = [kdf._internal.spark_column_for(label) for label in groupkey_labels]
Does not work for negative values of n.
sdf = kdf._internal.spark_frame
tmp_col = verify_temp_column_name(sdf, "__row_number__")
window = Window.partitionBy(groupkey_scols).orderBy(NATURAL_ORDER_COLUMN_NAME)
sdf = (
sdf.withColumn(tmp_col, F.row_number().over(window))
.filter(F.col(tmp_col) <= n)
.drop(tmp_col)
)
Returns
-------
DataFrame or Series
internal = kdf._internal.with_new_sdf(sdf)
return DataFrame(internal).drop(groupkey_labels, axis=1)
Examples
--------
>>> df = ks.DataFrame({'a': [1, 1, 1, 1, 2, 2, 2, 3, 3, 3],
... 'b': [2, 3, 1, 4, 6, 9, 8, 10, 7, 5],
... 'c': [3, 5, 2, 5, 1, 2, 6, 4, 3, 6]},
... columns=['a', 'b', 'c'],
... index=[7, 2, 4, 1, 3, 4, 9, 10, 5, 6])
>>> df
a b c
7 1 2 3
2 1 3 5
4 1 1 2
1 1 4 5
3 2 6 1
4 2 9 2
9 2 8 6
10 3 10 4
5 3 7 3
6 3 5 6
>>> df.groupby('a').tail(2).sort_index()
a b c
1 1 4 5
4 2 9 2
4 1 1 2
5 3 7 3
6 3 5 6
9 2 8 6
>>> df.groupby('a')['b'].tail(2).sort_index()
1 4
4 9
4 1
5 7
6 5
9 8
Name: b, dtype: int64
"""
return self._limit(n, asc=False)

def shift(self, periods=1, fill_value=None) -> Union[DataFrame, Series]:
"""
Expand Down Expand Up @@ -2702,6 +2768,11 @@ def head(self, n=5) -> Series:

head.__doc__ = GroupBy.head.__doc__

def tail(self, n=5) -> Series:
return first_series(super().tail(n)).rename(self._kser.name)

tail.__doc__ = GroupBy.tail.__doc__

def size(self) -> Series:
return super().size().rename(self._kser.name)

Expand Down
2 changes: 0 additions & 2 deletions databricks/koalas/missing/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ class MissingPandasLikeDataFrameGroupBy(object):
prod = _unsupported_function("prod")
resample = _unsupported_function("resample")
sem = _unsupported_function("sem")
tail = _unsupported_function("tail")


class MissingPandasLikeSeriesGroupBy(object):
Expand Down Expand Up @@ -103,4 +102,3 @@ class MissingPandasLikeSeriesGroupBy(object):
prod = _unsupported_function("prod")
resample = _unsupported_function("resample")
sem = _unsupported_function("sem")
tail = _unsupported_function("tail")
122 changes: 122 additions & 0 deletions databricks/koalas/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -2608,3 +2608,125 @@ def test_get_group(self):
self.assertRaises(
ValueError, lambda: kdf.groupby([("B", "class"), ("A", "name")]).get_group("mammal")
)

def test_tail(self):
pdf = pd.DataFrame(
{
"a": [1, 1, 1, 1, 2, 2, 2, 3, 3, 3] * 3,
"b": [2, 3, 1, 4, 6, 9, 8, 10, 7, 5] * 3,
"c": [3, 5, 2, 5, 1, 2, 6, 4, 3, 6] * 3,
},
index=np.random.rand(10 * 3),
)
kdf = ks.from_pandas(pdf)

self.assert_eq(pdf.groupby("a").tail(2).sort_index(), kdf.groupby("a").tail(2).sort_index())
self.assert_eq(
pdf.groupby("a").tail(-2).sort_index(), kdf.groupby("a").tail(-2).sort_index()
)
self.assert_eq(
pdf.groupby("a").tail(100000).sort_index(), kdf.groupby("a").tail(100000).sort_index()
)

self.assert_eq(
pdf.groupby("a")["b"].tail(2).sort_index(), kdf.groupby("a")["b"].tail(2).sort_index()
)
self.assert_eq(
pdf.groupby("a")["b"].tail(-2).sort_index(), kdf.groupby("a")["b"].tail(-2).sort_index()
)
self.assert_eq(
pdf.groupby("a")["b"].tail(100000).sort_index(),
kdf.groupby("a")["b"].tail(100000).sort_index(),
)

self.assert_eq(
pdf.groupby("a")[["b"]].tail(2).sort_index(),
kdf.groupby("a")[["b"]].tail(2).sort_index(),
)
self.assert_eq(
pdf.groupby("a")[["b"]].tail(-2).sort_index(),
kdf.groupby("a")[["b"]].tail(-2).sort_index(),
)
self.assert_eq(
pdf.groupby("a")[["b"]].tail(100000).sort_index(),
kdf.groupby("a")[["b"]].tail(100000).sort_index(),
)

self.assert_eq(
pdf.groupby(pdf.a // 2).tail(2).sort_index(),
kdf.groupby(kdf.a // 2).tail(2).sort_index(),
)
self.assert_eq(
pdf.groupby(pdf.a // 2)["b"].tail(2).sort_index(),
kdf.groupby(kdf.a // 2)["b"].tail(2).sort_index(),
)
self.assert_eq(
pdf.groupby(pdf.a // 2)[["b"]].tail(2).sort_index(),
kdf.groupby(kdf.a // 2)[["b"]].tail(2).sort_index(),
)

self.assert_eq(
pdf.b.rename().groupby(pdf.a).tail(2).sort_index(),
kdf.b.rename().groupby(kdf.a).tail(2).sort_index(),
)
self.assert_eq(
pdf.b.groupby(pdf.a.rename()).tail(2).sort_index(),
kdf.b.groupby(kdf.a.rename()).tail(2).sort_index(),
)
self.assert_eq(
pdf.b.rename().groupby(pdf.a.rename()).tail(2).sort_index(),
kdf.b.rename().groupby(kdf.a.rename()).tail(2).sort_index(),
)

# multi-index
midx = pd.MultiIndex(
[["x", "y"], ["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"]],
[[0, 0, 0, 0, 0, 1, 1, 1, 1, 1], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]],
)
pdf = pd.DataFrame(
{
"a": [1, 1, 1, 1, 2, 2, 2, 3, 3, 3],
"b": [2, 3, 1, 4, 6, 9, 8, 10, 7, 5],
"c": [3, 5, 2, 5, 1, 2, 6, 4, 3, 6],
},
columns=["a", "b", "c"],
index=midx,
)
kdf = ks.from_pandas(pdf)

self.assert_eq(pdf.groupby("a").tail(2).sort_index(), kdf.groupby("a").tail(2).sort_index())
self.assert_eq(
pdf.groupby("a").tail(-2).sort_index(), kdf.groupby("a").tail(-2).sort_index()
)
self.assert_eq(
pdf.groupby("a").tail(100000).sort_index(), kdf.groupby("a").tail(100000).sort_index()
)

self.assert_eq(
pdf.groupby("a")["b"].tail(2).sort_index(), kdf.groupby("a")["b"].tail(2).sort_index()
)
self.assert_eq(
pdf.groupby("a")["b"].tail(-2).sort_index(), kdf.groupby("a")["b"].tail(-2).sort_index()
)
self.assert_eq(
pdf.groupby("a")["b"].tail(100000).sort_index(),
kdf.groupby("a")["b"].tail(100000).sort_index(),
)

# multi-index columns
columns = pd.MultiIndex.from_tuples([("x", "a"), ("x", "b"), ("y", "c")])
pdf.columns = columns
kdf.columns = columns

self.assert_eq(
pdf.groupby(("x", "a")).tail(2).sort_index(),
kdf.groupby(("x", "a")).tail(2).sort_index(),
)
self.assert_eq(
pdf.groupby(("x", "a")).tail(-2).sort_index(),
kdf.groupby(("x", "a")).tail(-2).sort_index(),
)
self.assert_eq(
pdf.groupby(("x", "a")).tail(100000).sort_index(),
kdf.groupby(("x", "a")).tail(100000).sort_index(),
)
1 change: 1 addition & 0 deletions docs/source/reference/groupby.rst
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ Computations / Descriptive Stats
GroupBy.head
GroupBy.backfill
GroupBy.shift
GroupBy.tail

The following methods are available only for `DataFrameGroupBy` objects.

Expand Down

0 comments on commit ba02fa7

Please sign in to comment.