# [SPARK-36818][PYTHON] Fix filtering a Series by a boolean Series
### What changes were proposed in this pull request?
Fix filtering a Series (without a name) by a boolean Series.

### Why are the changes needed?
This is a bug fix. The issue was reported in databricks/koalas#2199.

### Does this PR introduce _any_ user-facing change?
Yes.

#### From
```py
>>> psser = ps.Series([0, 1, 2, 3, 4])
>>> ps.set_option('compute.ops_on_diff_frames', True)
>>> psser.loc[ps.Series([True, True, True, False, False])]
Traceback (most recent call last):
...
KeyError: 'none key'

```

#### To
```py
>>> psser = ps.Series([0, 1, 2, 3, 4])
>>> ps.set_option('compute.ops_on_diff_frames', True)
>>> psser.loc[ps.Series([True, True, True, False, False])]
0    0
1    1
2    2
dtype: int64
```
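
For reference, this is also what plain pandas produces for the same input, so the fix brings pandas-on-Spark in line with pandas (a minimal comparison, assuming only pandas is available):

```py
>>> import pandas as pd
>>> pser = pd.Series([0, 1, 2, 3, 4])
>>> pser.loc[pd.Series([True, True, True, False, False])]
0    0
1    1
2    2
dtype: int64
```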

### How was this patch tested?
Unit test.

Closes #34061 from xinrong-databricks/filter_series.

Authored-by: Xinrong Meng <[email protected]>
Signed-off-by: Takuya UESHIN <[email protected]>
xinrong-meng authored and ueshin committed Sep 22, 2021
1 parent a7cbe69 commit 6a5ee02
Showing 3 changed files with 19 additions and 2 deletions.
6 changes: 4 additions & 2 deletions python/pyspark/pandas/indexing.py
```diff
@@ -33,6 +33,7 @@
 from pyspark import pandas as ps  # noqa: F401
 from pyspark.pandas._typing import Label, Name, Scalar
 from pyspark.pandas.internal import (
+    DEFAULT_SERIES_NAME,
     InternalField,
     InternalFrame,
     NATURAL_ORDER_COLUMN_NAME,
@@ -435,11 +436,12 @@ def __getitem__(self, key: Any) -> Union["Series", "DataFrame"]:

         if self._is_series:
             if isinstance(key, Series) and not same_anchor(key, self._psdf_or_psser):
-                psdf = self._psdf_or_psser.to_frame()
+                name = self._psdf_or_psser.name or DEFAULT_SERIES_NAME
+                psdf = self._psdf_or_psser.to_frame(name)
                 temp_col = verify_temp_column_name(psdf, "__temp_col__")

                 psdf[temp_col] = key
-                return type(self)(psdf[self._psdf_or_psser.name])[psdf[temp_col]]
+                return type(self)(psdf[name].rename(self._psdf_or_psser.name))[psdf[temp_col]]

         cond, limit, remaining_index = self._select_rows(key)
         if cond is None and limit is None:
```
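
The root cause is that an unnamed Series has `name=None`, so after the old `to_frame()` call the frame's single column could not be re-selected via `psdf[self._psdf_or_psser.name]`, which surfaced as `KeyError: 'none key'`. The fix substitutes `DEFAULT_SERIES_NAME` as a temporary column label and renames the result back to the original (possibly `None`) name. Below is a simplified plain-pandas sketch of that pattern; the function name and the `FALLBACK_NAME` value are illustrative only, not part of the change:

```py
import pandas as pd

FALLBACK_NAME = 0  # illustrative stand-in for pyspark.pandas.internal.DEFAULT_SERIES_NAME


def filter_series_by_mask(ser: pd.Series, mask: pd.Series) -> pd.Series:
    # An unnamed Series cannot be looked up by ser.name once it becomes a frame,
    # so use a fallback column label while operating on the frame.
    name = ser.name or FALLBACK_NAME
    frame = ser.to_frame(name)
    frame["__temp_col__"] = mask  # the boolean mask is aligned on the index, as in the fix
    filtered = frame[name][frame["__temp_col__"]]
    # Restore the original name (which may be None) on the way out.
    return filtered.rename(ser.name)


print(filter_series_by_mask(pd.Series([0, 1, 2, 3, 4]),
                            pd.Series([True, True, True, False, False])))
# 0    0
# 1    1
# 2    2
# dtype: int64
```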
9 changes: 9 additions & 0 deletions python/pyspark/pandas/tests/test_indexing.py
```diff
@@ -417,6 +417,15 @@ def test_loc(self):
         self.assertRaises(KeyError, lambda: psdf.loc[0:30])
         self.assertRaises(KeyError, lambda: psdf.loc[10:100])

+    def test_loc_getitem_boolean_series(self):
+        pdf = pd.DataFrame(
+            {"A": [0, 1, 2, 3, 4], "B": [100, 200, 300, 400, 500]}, index=[20, 10, 30, 0, 50]
+        )
+        psdf = ps.from_pandas(pdf)
+        self.assert_eq(pdf.A.loc[pdf.B > 200], psdf.A.loc[psdf.B > 200])
+        self.assert_eq(pdf.B.loc[pdf.B > 200], psdf.B.loc[psdf.B > 200])
+        self.assert_eq(pdf.loc[pdf.B > 200], psdf.loc[psdf.B > 200])
+
     def test_loc_non_informative_index(self):
         pdf = pd.DataFrame({"x": [1, 2, 3, 4]}, index=[10, 20, 30, 40])
         psdf = ps.from_pandas(pdf)
```
6 changes: 6 additions & 0 deletions python/pyspark/pandas/tests/test_ops_on_diff_frames.py
```diff
@@ -503,6 +503,12 @@ def test_loc_getitem_boolean_series(self):
             (pdf1.A + 1).loc[pdf2.A > -3].sort_index(), (psdf1.A + 1).loc[psdf2.A > -3].sort_index()
         )

+        pser = pd.Series([0, 1, 2, 3, 4], index=[20, 10, 30, 0, 50])
+        psser = ps.from_pandas(pser)
+        self.assert_eq(pser.loc[pdf2.A > -3].sort_index(), psser.loc[psdf2.A > -3].sort_index())
+        pser.name = psser.name = "B"
+        self.assert_eq(pser.loc[pdf2.A > -3].sort_index(), psser.loc[psdf2.A > -3].sort_index())
+
     def test_bitwise(self):
         pser1 = pd.Series([True, False, True, False, np.nan, np.nan, True, False, np.nan])
         pser2 = pd.Series([True, False, False, True, True, False, np.nan, np.nan, np.nan])
```
