|
2 | 2 | import warnings |
3 | 3 | from importlib import import_module |
4 | 4 | from pathlib import Path |
5 | | -from typing import Any, Union |
| 5 | +from typing import Union |
6 | 6 |
|
7 | 7 | import pytest |
8 | 8 |
|
|
19 | 19 | pandas = None |
20 | 20 | pyarrow_dtype = None |
21 | 21 |
|
22 | | -# Check if pandas has arrow dtypes enabled |
23 | | -try: |
24 | | - from pandas.compat import pa_version_under7p0 |
25 | 22 |
|
26 | | - pyarrow_dtypes_enabled = not pa_version_under7p0 |
27 | | -except ImportError: |
28 | | - pyarrow_dtypes_enabled = False |
| 23 | +# Version-aware helpers for Pandas 2.x vs 3.0 compatibility |
| 24 | +def _get_pandas_ge_3(): |
| 25 | + if pandas is None: |
| 26 | + return False |
| 27 | + from packaging.version import Version |
| 28 | + |
| 29 | + return Version(pandas.__version__) >= Version("3.0.0") |
| 30 | + |
| 31 | + |
| 32 | +PANDAS_GE_3 = _get_pandas_ge_3() |
| 33 | + |
| 34 | + |
| 35 | +def is_string_dtype(dtype): |
| 36 | + """Check if a dtype is a string dtype (works across Pandas 2.x and 3.0). |
| 37 | +
|
| 38 | + Uses pd.api.types.is_string_dtype() which handles: |
| 39 | + - Pandas 2.x: object dtype for strings |
| 40 | + - Pandas 3.0+: str (StringDtype) for strings |
| 41 | + """ |
| 42 | + return pandas.api.types.is_string_dtype(dtype) |
29 | 43 |
|
30 | 44 |
|
31 | 45 | def import_pandas(): |
@@ -113,78 +127,6 @@ def pandas_supports_arrow_backend(): |
113 | 127 | return pandas_2_or_higher() |
114 | 128 |
|
115 | 129 |
|
116 | | -def numpy_pandas_df(*args, **kwargs): |
117 | | - return import_pandas().DataFrame(*args, **kwargs) |
118 | | - |
119 | | - |
120 | | -def arrow_pandas_df(*args, **kwargs): |
121 | | - df = numpy_pandas_df(*args, **kwargs) |
122 | | - return df.convert_dtypes(dtype_backend="pyarrow") |
123 | | - |
124 | | - |
125 | | -class NumpyPandas: |
126 | | - def __init__(self) -> None: |
127 | | - self.backend = "numpy_nullable" |
128 | | - self.DataFrame = numpy_pandas_df |
129 | | - self.pandas = import_pandas() |
130 | | - |
131 | | - def __getattr__(self, name: str) -> Any: # noqa: ANN401 |
132 | | - return getattr(self.pandas, name) |
133 | | - |
134 | | - |
135 | | -def convert_arrow_to_numpy_backend(df): |
136 | | - names = df.columns |
137 | | - df_content = {} |
138 | | - for name in names: |
139 | | - df_content[name] = df[name].array.__arrow_array__() |
140 | | - # This should convert the pyarrow chunked arrays into numpy arrays |
141 | | - return import_pandas().DataFrame(df_content) |
142 | | - |
143 | | - |
144 | | -def convert_to_numpy(df): |
145 | | - if ( |
146 | | - pyarrow_dtypes_enabled |
147 | | - and pyarrow_dtype is not None |
148 | | - and any(True for x in df.dtypes if isinstance(x, pyarrow_dtype)) |
149 | | - ): |
150 | | - return convert_arrow_to_numpy_backend(df) |
151 | | - return df |
152 | | - |
153 | | - |
154 | | -def convert_and_equal(df1, df2, **kwargs): |
155 | | - df1 = convert_to_numpy(df1) |
156 | | - df2 = convert_to_numpy(df2) |
157 | | - import_pandas().testing.assert_frame_equal(df1, df2, **kwargs) |
158 | | - |
159 | | - |
160 | | -class ArrowMockTesting: |
161 | | - def __init__(self) -> None: |
162 | | - self.testing = import_pandas().testing |
163 | | - self.assert_frame_equal = convert_and_equal |
164 | | - |
165 | | - def __getattr__(self, name: str) -> Any: # noqa: ANN401 |
166 | | - return getattr(self.testing, name) |
167 | | - |
168 | | - |
169 | | -# This converts dataframes constructed with 'DataFrame(...)' to pyarrow backed dataframes |
170 | | -# Assert equal does the opposite, turning all pyarrow backed dataframes into numpy backed ones |
171 | | -# this is done because we don't produce pyarrow backed dataframes yet |
172 | | -class ArrowPandas: |
173 | | - def __init__(self) -> None: |
174 | | - self.pandas = import_pandas() |
175 | | - if pandas_2_or_higher() and pyarrow_dtypes_enabled: |
176 | | - self.backend = "pyarrow" |
177 | | - self.DataFrame = arrow_pandas_df |
178 | | - else: |
179 | | - # For backwards compatible reasons, just mock regular pandas |
180 | | - self.backend = "numpy_nullable" |
181 | | - self.DataFrame = self.pandas.DataFrame |
182 | | - self.testing = ArrowMockTesting() |
183 | | - |
184 | | - def __getattr__(self, name: str) -> Any: # noqa: ANN401 |
185 | | - return getattr(self.pandas, name) |
186 | | - |
187 | | - |
188 | 130 | @pytest.fixture |
189 | 131 | def require(): |
190 | 132 | def _require(extension_name, db_name="") -> Union[duckdb.DuckDBPyConnection, None]: |
|
0 commit comments