From 6a5ee0283cee29a965e393fd829ff3cd0b09cf4d Mon Sep 17 00:00:00 2001
From: Xinrong Meng
Date: Wed, 22 Sep 2021 12:52:56 -0700
Subject: [PATCH] [SPARK-36818][PYTHON] Fix filtering a Series by a boolean Series

### What changes were proposed in this pull request?

Fix filtering a Series (without a name) by a boolean Series.

### Why are the changes needed?

A bugfix. The issue is raised as https://github.com/databricks/koalas/issues/2199.

### Does this PR introduce _any_ user-facing change?

Yes.

#### From

```py
>>> psser = ps.Series([0, 1, 2, 3, 4])
>>> ps.set_option('compute.ops_on_diff_frames', True)
>>> psser.loc[ps.Series([True, True, True, False, False])]
Traceback (most recent call last):
...
KeyError: 'none key'
```

#### To

```py
>>> psser = ps.Series([0, 1, 2, 3, 4])
>>> ps.set_option('compute.ops_on_diff_frames', True)
>>> psser.loc[ps.Series([True, True, True, False, False])]
0    0
1    1
2    2
dtype: int64
```

### How was this patch tested?

Unit test.

Closes #34061 from xinrong-databricks/filter_series.

Authored-by: Xinrong Meng Signed-off-by: Takuya UESHIN --- python/pyspark/pandas/indexing.py | 6 ++++-- python/pyspark/pandas/tests/test_indexing.py | 9 +++++++++ python/pyspark/pandas/tests/test_ops_on_diff_frames.py | 6 ++++++ 3 files changed, 19 insertions(+), 2 deletions(-) diff --git a/python/pyspark/pandas/indexing.py b/python/pyspark/pandas/indexing.py index 3e0797595b50c..e7a07e763c341 100644 --- a/python/pyspark/pandas/indexing.py +++ b/python/pyspark/pandas/indexing.py @@ -33,6 +33,7 @@ from pyspark import pandas as ps # noqa: F401 from pyspark.pandas._typing import Label, Name, Scalar from pyspark.pandas.internal import ( + DEFAULT_SERIES_NAME, InternalField, InternalFrame, NATURAL_ORDER_COLUMN_NAME, @@ -435,11 +436,12 @@ def __getitem__(self, key: Any) -> Union["Series", "DataFrame"]: if self._is_series: if isinstance(key, Series) and not same_anchor(key, self._psdf_or_psser): - psdf = self._psdf_or_psser.to_frame() + name = self._psdf_or_psser.name or DEFAULT_SERIES_NAME + psdf = self._psdf_or_psser.to_frame(name) temp_col = verify_temp_column_name(psdf, "__temp_col__") psdf[temp_col] = key - return type(self)(psdf[self._psdf_or_psser.name])[psdf[temp_col]] + return type(self)(psdf[name].rename(self._psdf_or_psser.name))[psdf[temp_col]] cond, limit, remaining_index = self._select_rows(key) if cond is None and limit is None: diff --git a/python/pyspark/pandas/tests/test_indexing.py b/python/pyspark/pandas/tests/test_indexing.py index b74cf90d079f9..2b00b3f952bb6 100644 --- a/python/pyspark/pandas/tests/test_indexing.py +++ b/python/pyspark/pandas/tests/test_indexing.py @@ -417,6 +417,15 @@ def test_loc(self): self.assertRaises(KeyError, lambda: psdf.loc[0:30]) self.assertRaises(KeyError, lambda: psdf.loc[10:100]) + def test_loc_getitem_boolean_series(self): + pdf = pd.DataFrame( + {"A": [0, 1, 2, 3, 4], "B": [100, 200, 300, 400, 500]}, index=[20, 10, 30, 0, 50] + ) + psdf = ps.from_pandas(pdf) + self.assert_eq(pdf.A.loc[pdf.B > 200], psdf.A.loc[psdf.B > 
200]) + self.assert_eq(pdf.B.loc[pdf.B > 200], psdf.B.loc[psdf.B > 200]) + self.assert_eq(pdf.loc[pdf.B > 200], psdf.loc[psdf.B > 200]) + def test_loc_non_informative_index(self): pdf = pd.DataFrame({"x": [1, 2, 3, 4]}, index=[10, 20, 30, 40]) psdf = ps.from_pandas(pdf) diff --git a/python/pyspark/pandas/tests/test_ops_on_diff_frames.py b/python/pyspark/pandas/tests/test_ops_on_diff_frames.py index 6a4855c7525ca..1cc0ff51b8a32 100644 --- a/python/pyspark/pandas/tests/test_ops_on_diff_frames.py +++ b/python/pyspark/pandas/tests/test_ops_on_diff_frames.py @@ -503,6 +503,12 @@ def test_loc_getitem_boolean_series(self): (pdf1.A + 1).loc[pdf2.A > -3].sort_index(), (psdf1.A + 1).loc[psdf2.A > -3].sort_index() ) + pser = pd.Series([0, 1, 2, 3, 4], index=[20, 10, 30, 0, 50]) + psser = ps.from_pandas(pser) + self.assert_eq(pser.loc[pdf2.A > -3].sort_index(), psser.loc[psdf2.A > -3].sort_index()) + pser.name = psser.name = "B" + self.assert_eq(pser.loc[pdf2.A > -3].sort_index(), psser.loc[psdf2.A > -3].sort_index()) + def test_bitwise(self): pser1 = pd.Series([True, False, True, False, np.nan, np.nan, True, False, np.nan]) pser2 = pd.Series([True, False, False, True, True, False, np.nan, np.nan, np.nan])