From 6e84937af4f24194bf61f09244ebef6528fb7c4c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mario=20=C5=A0a=C5=A1ko?= Date: Fri, 25 Aug 2023 15:06:52 +0200 Subject: [PATCH] PyArrow 13 CI fixes (#6175) --- src/datasets/table.py | 8 ++++---- tests/test_formatting.py | 31 ++++++++++++++++++++++++++++--- 2 files changed, 32 insertions(+), 7 deletions(-) diff --git a/src/datasets/table.py b/src/datasets/table.py index d0982a34c3f..13fc26bd97e 100644 --- a/src/datasets/table.py +++ b/src/datasets/table.py @@ -1916,7 +1916,7 @@ def _concat_arrays(arrays): _concat_arrays([array.values for array in arrays]), ) elif pa.types.is_fixed_size_list(array_type): - if config.PYARROW_VERSION.major < 13: + if config.PYARROW_VERSION.major < 14: # PyArrow bug: https://github.com/apache/arrow/issues/35360 return pa.FixedSizeListArray.from_arrays( _concat_arrays([array.values[array.offset * array.type.list_size :] for array in arrays]), @@ -1993,7 +1993,7 @@ def array_cast(array: pa.Array, pa_type: pa.DataType, allow_number_to_str=True): return pa.ListArray.from_arrays(array.offsets, _c(array.values, pa_type.value_type)) elif pa.types.is_fixed_size_list(array.type): array_values = array.values - if config.PYARROW_VERSION.major < 13: + if config.PYARROW_VERSION.major < 14: # PyArrow bug: https://github.com/apache/arrow/issues/35360 array_values = array.values[array.offset * array.type.list_size :] if pa.types.is_fixed_size_list(pa_type): @@ -2109,7 +2109,7 @@ def cast_array_to_feature(array: pa.Array, feature: "FeatureType", allow_number_ elif pa.types.is_fixed_size_list(array.type): # feature must be either [subfeature] or Sequence(subfeature) array_values = array.values - if config.PYARROW_VERSION.major < 13: + if config.PYARROW_VERSION.major < 14: # PyArrow bug: https://github.com/apache/arrow/issues/35360 array_values = array.values[array.offset * array.type.list_size :] if isinstance(feature, list): @@ -2216,7 +2216,7 @@ def embed_array_storage(array: pa.Array, feature: "FeatureType"): elif pa.types.is_fixed_size_list(array.type): # feature must be either [subfeature] or Sequence(subfeature) array_values = array.values - if config.PYARROW_VERSION.major < 13: + if config.PYARROW_VERSION.major < 14: # PyArrow bug: https://github.com/apache/arrow/issues/35360 array_values = array.values[array.offset * array.type.list_size :] if isinstance(feature, list): diff --git a/tests/test_formatting.py b/tests/test_formatting.py index 7e4af7c525c..9ac7d2c2f58 100644 --- a/tests/test_formatting.py +++ b/tests/test_formatting.py @@ -76,7 +76,7 @@ def test_numpy_extractor(self): np.testing.assert_equal(batch, {"a": np.array(_COL_A), "b": np.array(_COL_B)}) def test_numpy_extractor_nested(self): - pa_table = self._create_dummy_table().drop(["a", "b"]) + pa_table = self._create_dummy_table().drop(["a", "b", "d"]) extractor = NumpyArrowExtractor() row = extractor.extract_row(pa_table) self.assertEqual(row["c"][0].dtype, np.float64) @@ -109,14 +109,39 @@ def test_pandas_extractor(self): self.assertIsInstance(row, pd.DataFrame) pd.testing.assert_series_equal(row["a"], pd.Series(_COL_A, name="a")[:1]) pd.testing.assert_series_equal(row["b"], pd.Series(_COL_B, name="b")[:1]) - pd.testing.assert_series_equal(row["d"], pd.Series(_COL_D, name="d")[:1]) col = extractor.extract_column(pa_table) pd.testing.assert_series_equal(col, pd.Series(_COL_A, name="a")) batch = extractor.extract_batch(pa_table) self.assertIsInstance(batch, pd.DataFrame) pd.testing.assert_series_equal(batch["a"], pd.Series(_COL_A, name="a")) pd.testing.assert_series_equal(batch["b"], pd.Series(_COL_B, name="b")) - pd.testing.assert_series_equal(batch["d"], pd.Series(_COL_D, name="d")) + + def test_pandas_extractor_nested(self): + pa_table = self._create_dummy_table().drop(["a", "b", "d"]) + extractor = PandasArrowExtractor() + row = extractor.extract_row(pa_table) + self.assertEqual(row["c"][0][0].dtype, np.float64) + self.assertEqual(row["c"].dtype, object) + col = extractor.extract_column(pa_table) + self.assertEqual(col[0][0].dtype, np.float64) + self.assertEqual(col[0].dtype, object) + self.assertEqual(col.dtype, object) + batch = extractor.extract_batch(pa_table) + self.assertEqual(batch["c"][0][0].dtype, np.float64) + self.assertEqual(batch["c"][0].dtype, object) + self.assertEqual(batch["c"].dtype, object) + + def test_pandas_extractor_temporal(self): + pa_table = self._create_dummy_table().drop(["a", "b", "c"]) + extractor = PandasArrowExtractor() + row = extractor.extract_row(pa_table) + self.assertTrue(pd.api.types.is_datetime64_any_dtype(row["d"].dtype)) + col = extractor.extract_column(pa_table) + self.assertTrue(isinstance(col[0], datetime.datetime)) + self.assertTrue(pd.api.types.is_datetime64_any_dtype(col.dtype)) + batch = extractor.extract_batch(pa_table) + self.assertTrue(isinstance(batch["d"][0], datetime.datetime)) + self.assertTrue(pd.api.types.is_datetime64_any_dtype(batch["d"].dtype)) class LazyDictTest(TestCase):