Skip to content

Commit

Permalink
PyArrow 13 CI fixes (#6175)
Browse files Browse the repository at this point in the history
  • Loading branch information
mariosasko authored Aug 25, 2023
1 parent 4566827 commit 6e84937
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 7 deletions.
8 changes: 4 additions & 4 deletions src/datasets/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -1916,7 +1916,7 @@ def _concat_arrays(arrays):
_concat_arrays([array.values for array in arrays]),
)
elif pa.types.is_fixed_size_list(array_type):
if config.PYARROW_VERSION.major < 13:
if config.PYARROW_VERSION.major < 14:
# PyArrow bug: https://github.com/apache/arrow/issues/35360
return pa.FixedSizeListArray.from_arrays(
_concat_arrays([array.values[array.offset * array.type.list_size :] for array in arrays]),
Expand Down Expand Up @@ -1993,7 +1993,7 @@ def array_cast(array: pa.Array, pa_type: pa.DataType, allow_number_to_str=True):
return pa.ListArray.from_arrays(array.offsets, _c(array.values, pa_type.value_type))
elif pa.types.is_fixed_size_list(array.type):
array_values = array.values
if config.PYARROW_VERSION.major < 13:
if config.PYARROW_VERSION.major < 14:
# PyArrow bug: https://github.com/apache/arrow/issues/35360
array_values = array.values[array.offset * array.type.list_size :]
if pa.types.is_fixed_size_list(pa_type):
Expand Down Expand Up @@ -2109,7 +2109,7 @@ def cast_array_to_feature(array: pa.Array, feature: "FeatureType", allow_number_
elif pa.types.is_fixed_size_list(array.type):
# feature must be either [subfeature] or Sequence(subfeature)
array_values = array.values
if config.PYARROW_VERSION.major < 13:
if config.PYARROW_VERSION.major < 14:
# PyArrow bug: https://github.com/apache/arrow/issues/35360
array_values = array.values[array.offset * array.type.list_size :]
if isinstance(feature, list):
Expand Down Expand Up @@ -2216,7 +2216,7 @@ def embed_array_storage(array: pa.Array, feature: "FeatureType"):
elif pa.types.is_fixed_size_list(array.type):
# feature must be either [subfeature] or Sequence(subfeature)
array_values = array.values
if config.PYARROW_VERSION.major < 13:
if config.PYARROW_VERSION.major < 14:
# PyArrow bug: https://github.com/apache/arrow/issues/35360
array_values = array.values[array.offset * array.type.list_size :]
if isinstance(feature, list):
Expand Down
31 changes: 28 additions & 3 deletions tests/test_formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def test_numpy_extractor(self):
np.testing.assert_equal(batch, {"a": np.array(_COL_A), "b": np.array(_COL_B)})

def test_numpy_extractor_nested(self):
pa_table = self._create_dummy_table().drop(["a", "b"])
pa_table = self._create_dummy_table().drop(["a", "b", "d"])
extractor = NumpyArrowExtractor()
row = extractor.extract_row(pa_table)
self.assertEqual(row["c"][0].dtype, np.float64)
Expand Down Expand Up @@ -109,14 +109,39 @@ def test_pandas_extractor(self):
self.assertIsInstance(row, pd.DataFrame)
pd.testing.assert_series_equal(row["a"], pd.Series(_COL_A, name="a")[:1])
pd.testing.assert_series_equal(row["b"], pd.Series(_COL_B, name="b")[:1])
pd.testing.assert_series_equal(row["d"], pd.Series(_COL_D, name="d")[:1])
col = extractor.extract_column(pa_table)
pd.testing.assert_series_equal(col, pd.Series(_COL_A, name="a"))
batch = extractor.extract_batch(pa_table)
self.assertIsInstance(batch, pd.DataFrame)
pd.testing.assert_series_equal(batch["a"], pd.Series(_COL_A, name="a"))
pd.testing.assert_series_equal(batch["b"], pd.Series(_COL_B, name="b"))
pd.testing.assert_series_equal(batch["d"], pd.Series(_COL_D, name="d"))

def test_pandas_extractor_nested(self):
pa_table = self._create_dummy_table().drop(["a", "b", "d"])
extractor = PandasArrowExtractor()
row = extractor.extract_row(pa_table)
self.assertEqual(row["c"][0][0].dtype, np.float64)
self.assertEqual(row["c"].dtype, object)
col = extractor.extract_column(pa_table)
self.assertEqual(col[0][0].dtype, np.float64)
self.assertEqual(col[0].dtype, object)
self.assertEqual(col.dtype, object)
batch = extractor.extract_batch(pa_table)
self.assertEqual(batch["c"][0][0].dtype, np.float64)
self.assertEqual(batch["c"][0].dtype, object)
self.assertEqual(batch["c"].dtype, object)

def test_pandas_extractor_temporal(self):
pa_table = self._create_dummy_table().drop(["a", "b", "c"])
extractor = PandasArrowExtractor()
row = extractor.extract_row(pa_table)
self.assertTrue(pd.api.types.is_datetime64_any_dtype(row["d"].dtype))
col = extractor.extract_column(pa_table)
self.assertTrue(isinstance(col[0], datetime.datetime))
self.assertTrue(pd.api.types.is_datetime64_any_dtype(col.dtype))
batch = extractor.extract_batch(pa_table)
self.assertTrue(isinstance(batch["d"][0], datetime.datetime))
self.assertTrue(pd.api.types.is_datetime64_any_dtype(batch["d"].dtype))


class LazyDictTest(TestCase):
Expand Down

0 comments on commit 6e84937

Please sign in to comment.