
Commit 5181c24

GH-43683: [Python] Use pandas StringDtype when enabled (pandas 3+) (#44195)
### Rationale for this change

With pandas' [PDEP-14](https://pandas.pydata.org/pdeps/0014-string-dtype.html) proposal, pandas plans to introduce a dedicated string dtype (instead of the current object dtype). This will become the default in pandas 3.0, and it can be enabled with an option in the upcoming pandas 2.3 (`pd.options.future.infer_string = True`). To prepare for that, we should start using that string dtype in `to_pandas()` conversions when the option is enabled.

### What changes are included in this PR?

- If pandas >= 3.0 is used or the pandas option is enabled, ensure that `to_pandas()` calls use pandas' default string dtype for string-like columns (string, large_string, string_view).

### Are these changes tested?

These changes are tested in the pandas-nightly crossbow build. There is still one remaining failure, caused by a bug on the pandas side (pandas-dev/pandas#59879).

### Are there any user-facing changes?

**This PR includes breaking changes to public APIs.**

Depending on the version of pandas, `to_pandas()` will change to use pandas' string dtype instead of object dtype. This is a breaking user-facing change, but it essentially follows the equivalent change in default dtype on the pandas side.

* GitHub Issue: #43683

Lead-authored-by: Joris Van den Bossche <[email protected]>
Co-authored-by: Raúl Cumplido <[email protected]>
Signed-off-by: Joris Van den Bossche <[email protected]>
1 parent 3f0bd2f commit 5181c24
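
To make the user-facing change concrete, here is a minimal sketch (not part of the commit) of the expected behavior, assuming pandas 2.3 with the option enabled, or pandas >= 3.0 where it is the default, plus a pyarrow build that includes this commit:

```python
import pandas as pd
import pyarrow as pa

# On pandas >= 3.0 this option is the default and need not be set.
pd.options.future.infer_string = True

s = pa.array(["a", "b", None]).to_pandas()
# Before this commit: object dtype, with None for the null.
# After this commit: pandas' default (NaN-backed) string dtype.
print(s.dtype)
```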

File tree

8 files changed (+155, −35)


dev/tasks/tasks.yml (+6)

```diff
@@ -1426,6 +1426,12 @@ tasks:
       # ensure we have at least one build with parquet encryption disabled
       PARQUET_REQUIRE_ENCRYPTION: "OFF"
     {% endif %}
+    {% if pandas_version == "nightly" %}
+      # TODO: can be removed once this is enabled by default in pandas >= 3
+      # This enables pandas' future string dtype.
+      # See: https://github.com/pandas-dev/pandas/pull/58459
+      PANDAS_FUTURE_INFER_STRING: "1"
+    {% endif %}
     {% if not cache_leaf %}
     # use the latest pandas release, so prevent reusing any cached layers
     flags: --no-leaf-cache
```
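
For context, a sketch of how the new CI variable is expected to take effect, assuming pandas honors the `PANDAS_FUTURE_INFER_STRING` environment variable at import time (the behavior added in the pandas PR referenced above):

```python
import os

# Must be set before pandas is imported; pandas reads it once at startup.
os.environ["PANDAS_FUTURE_INFER_STRING"] = "1"

import pandas as pd

assert pd.options.future.infer_string is True
```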

docker-compose.yml (+1)

```diff
@@ -1375,6 +1375,7 @@ services:
       PYTEST_ARGS: # inherit
       HYPOTHESIS_PROFILE: # inherit
       PYARROW_TEST_HYPOTHESIS: # inherit
+      PANDAS_FUTURE_INFER_STRING: # inherit
     volumes: *conda-volumes
     command: *python-conda-command
```

python/pyarrow/array.pxi (+2)

```diff
@@ -117,6 +117,8 @@ def _handle_arrow_array_protocol(obj, type, mask, size):
                         "return a pyarrow Array or ChunkedArray.")
     if isinstance(res, ChunkedArray) and res.num_chunks == 1:
         res = res.chunk(0)
+    if type is not None and res.type != type:
+        res = res.cast(type)
     return res
```
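
To illustrate why the added cast is needed: an object's `__arrow_array__` implementation may ignore the requested type, so `pa.array(obj, type=...)` could previously return an array of a different type. A sketch with a hypothetical wrapper class (`StringHolder` is not from this PR):

```python
import pyarrow as pa

class StringHolder:
    """Hypothetical container implementing the __arrow_array__ protocol."""

    def __init__(self, values):
        self.values = values

    def __arrow_array__(self, type=None):
        # A naive implementation that ignores the requested type.
        return pa.array(self.values)

arr = pa.array(StringHolder(["a", "b"]), type=pa.large_string())
# With the cast added above, the requested type is now honored.
assert arr.type == pa.large_string()
```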

python/pyarrow/pandas-shim.pxi (+16, −1)

```diff
@@ -38,7 +38,7 @@ cdef class _PandasAPIShim(object):
         object _array_like_types, _is_extension_array_dtype, _lock
         bint has_sparse
         bint _pd024
-        bint _is_v1, _is_ge_v21, _is_ge_v3
+        bint _is_v1, _is_ge_v21, _is_ge_v3, _is_ge_v3_strict

     def __init__(self):
         self._lock = Lock()
@@ -80,6 +80,7 @@ cdef class _PandasAPIShim(object):
         self._is_v1 = self._loose_version < Version('2.0.0')
         self._is_ge_v21 = self._loose_version >= Version('2.1.0')
         self._is_ge_v3 = self._loose_version >= Version('3.0.0.dev0')
+        self._is_ge_v3_strict = self._loose_version >= Version('3.0.0')

         self._compat_module = pdcompat
         self._data_frame = pd.DataFrame
@@ -174,6 +175,20 @@ cdef class _PandasAPIShim(object):
         self._check_import()
         return self._is_ge_v3

+    def is_ge_v3_strict(self):
+        self._check_import()
+        return self._is_ge_v3_strict
+
+    def uses_string_dtype(self):
+        if self.is_ge_v3_strict():
+            return True
+        try:
+            if self.pd.options.future.infer_string:
+                return True
+        except:
+            pass
+        return False
+
     @property
     def categorical_type(self):
         self._check_import()
```
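
In plain Python, the new `uses_string_dtype()` check amounts to roughly the following (a readability paraphrase, not the actual Cython shim; it assumes `packaging` is available for version parsing):

```python
from packaging.version import Version

import pandas as pd

def uses_string_dtype() -> bool:
    # A stable pandas >= 3.0 always uses the new string dtype.
    if Version(pd.__version__) >= Version("3.0.0"):
        return True
    # Older (or dev) versions use it only when the future option is enabled.
    try:
        return bool(pd.options.future.infer_string)
    except Exception:
        return False
```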

python/pyarrow/pandas_compat.py (+54, −8)

```diff
@@ -174,7 +174,11 @@ def get_column_metadata(column, name, arrow_type, field_name):
     }
     string_dtype = 'object'

-    if name is not None and not isinstance(name, str):
+    if (
+        name is not None
+        and not (isinstance(name, float) and np.isnan(name))
+        and not isinstance(name, str)
+    ):
         raise TypeError(
             'Column name must be a string. Got column {} of type {}'.format(
                 name, type(name).__name__
@@ -340,8 +344,8 @@ def _column_name_to_strings(name):
         return str(tuple(map(_column_name_to_strings, name)))
     elif isinstance(name, Sequence):
         raise TypeError("Unsupported type for MultiIndex level")
-    elif name is None:
-        return None
+    elif name is None or (isinstance(name, float) and np.isnan(name)):
+        return name
     return str(name)
@@ -790,10 +794,12 @@ def table_to_dataframe(
         table, index = _reconstruct_index(table, index_descriptors,
                                           all_columns, types_mapper)
         ext_columns_dtypes = _get_extension_dtypes(
-            table, all_columns, types_mapper)
+            table, all_columns, types_mapper, options, categories)
     else:
         index = _pandas_api.pd.RangeIndex(table.num_rows)
-        ext_columns_dtypes = _get_extension_dtypes(table, [], types_mapper)
+        ext_columns_dtypes = _get_extension_dtypes(
+            table, [], types_mapper, options, categories
+        )

     _check_data_column_metadata_consistency(all_columns)
     columns = _deserialize_column_index(table, all_columns, column_indexes)
@@ -838,7 +844,7 @@
     }


-def _get_extension_dtypes(table, columns_metadata, types_mapper=None):
+def _get_extension_dtypes(table, columns_metadata, types_mapper, options, categories):
     """
     Based on the stored column pandas metadata and the extension types
     in the arrow schema, infer which columns should be converted to a
@@ -851,6 +857,9 @@ def _get_extension_dtypes(table, columns_metadata, types_mapper=None):
     and then we can check if this dtype supports conversion from arrow.

     """
+    strings_to_categorical = options["strings_to_categorical"]
+    categories = categories or []
+
     ext_columns = {}

     # older pandas version that does not yet support extension dtypes
@@ -889,9 +898,32 @@ def _get_extension_dtypes(table, columns_metadata, types_mapper=None):
             # that are certainly numpy dtypes
             pandas_dtype = _pandas_api.pandas_dtype(dtype)
             if isinstance(pandas_dtype, _pandas_api.extension_dtype):
+                if isinstance(pandas_dtype, _pandas_api.pd.StringDtype):
+                    # when the metadata indicate to use the string dtype,
+                    # ignore this in case:
+                    # - it is specified to convert strings / this column to categorical
+                    # - the column itself is dictionary encoded and would otherwise be
+                    #   converted to categorical
+                    if strings_to_categorical or name in categories:
+                        continue
+                    try:
+                        if pa.types.is_dictionary(table.schema.field(name).type):
+                            continue
+                    except KeyError:
+                        pass
                 if hasattr(pandas_dtype, "__from_arrow__"):
                     ext_columns[name] = pandas_dtype

+    # for pandas 3.0+, use pandas' new default string dtype
+    if _pandas_api.uses_string_dtype() and not strings_to_categorical:
+        for field in table.schema:
+            if field.name not in ext_columns and (
+                pa.types.is_string(field.type)
+                or pa.types.is_large_string(field.type)
+                or pa.types.is_string_view(field.type)
+            ) and field.name not in categories:
+                ext_columns[field.name] = _pandas_api.pd.StringDtype(na_value=np.nan)
+
     return ext_columns
@@ -1049,9 +1081,9 @@ def get_pandas_logical_type_map():
         'date': 'datetime64[D]',
         'datetime': 'datetime64[ns]',
         'datetimetz': 'datetime64[ns]',
-        'unicode': np.str_,
+        'unicode': 'str',
         'bytes': np.bytes_,
-        'string': np.str_,
+        'string': 'str',
         'integer': np.int64,
         'floating': np.float64,
         'decimal': np.object_,
@@ -1142,6 +1174,20 @@ def _reconstruct_columns_from_metadata(columns, column_indexes):
         # GH-41503: if the column index was decimal, restore to decimal
         elif pandas_dtype == "decimal":
             level = _pandas_api.pd.Index([decimal.Decimal(i) for i in level])
+        elif (
+            level.dtype == "str" and numpy_dtype == "object"
+            and ("mixed" in pandas_dtype or pandas_dtype in ["unicode", "string"])
+        ):
+            # the metadata indicate that the original dataframe used object dtype,
+            # but ignore this and keep string dtype if:
+            # - the original columns used mixed types -> we don't attempt to faithfully
+            #   roundtrip in this case, but keep the column names as strings
+            # - the original columns were inferred to be strings but stored in object
+            #   dtype -> we don't restore the object dtype because all metadata
+            #   generated using pandas < 3 will have this case by default, and
+            #   for pandas >= 3 we want to use the default string dtype for .columns
+            new_levels.append(level)
+            continue
         elif level.dtype != dtype:
             level = level.astype(dtype)
         # ARROW-9096: if original DataFrame was upcast we keep that
```
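
The net effect of the new `_get_extension_dtypes` branches, sketched as user-level calls (assuming a pandas version where `uses_string_dtype()` is true):

```python
import pyarrow as pa

table = pa.table({"s": pa.array(["x", "y", None])})

table.to_pandas()                             # "s" -> pandas' default string dtype
table.to_pandas(strings_to_categorical=True)  # opt-out: "s" -> categorical
table.to_pandas(categories=["s"])             # per-column opt-out: also categorical
```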

python/pyarrow/tests/test_compute.py (+10, −9)

```diff
@@ -1020,7 +1020,7 @@ def test_replace_slice():
     offsets = range(-3, 4)

     arr = pa.array([None, '', 'a', 'ab', 'abc', 'abcd', 'abcde'])
-    series = arr.to_pandas()
+    series = arr.to_pandas().astype(object).replace({np.nan: None})
     for start in offsets:
         for stop in offsets:
             expected = series.str.slice_replace(start, stop, 'XX')
@@ -1031,7 +1031,7 @@ def test_replace_slice():
             assert pc.binary_replace_slice(arr, start, stop, 'XX') == actual

     arr = pa.array([None, '', 'π', 'πb', 'πbθ', 'πbθd', 'πbθde'])
-    series = arr.to_pandas()
+    series = arr.to_pandas().astype(object).replace({np.nan: None})
     for start in offsets:
         for stop in offsets:
             expected = series.str.slice_replace(start, stop, 'XX')
@@ -2132,50 +2132,51 @@ def test_strftime():
     for fmt in formats:
         options = pc.StrftimeOptions(fmt)
         result = pc.strftime(tsa, options=options)
-        expected = pa.array(ts.strftime(fmt))
+        # cast to the same type as result to ignore string vs large_string
+        expected = pa.array(ts.strftime(fmt)).cast(result.type)
         assert result.equals(expected)

     fmt = "%Y-%m-%dT%H:%M:%S"

     # Default format
     tsa = pa.array(ts, type=pa.timestamp("s", timezone))
     result = pc.strftime(tsa, options=pc.StrftimeOptions())
-    expected = pa.array(ts.strftime(fmt))
+    expected = pa.array(ts.strftime(fmt)).cast(result.type)
     assert result.equals(expected)

     # Default format plus timezone
     tsa = pa.array(ts, type=pa.timestamp("s", timezone))
     result = pc.strftime(tsa, options=pc.StrftimeOptions(fmt + "%Z"))
-    expected = pa.array(ts.strftime(fmt + "%Z"))
+    expected = pa.array(ts.strftime(fmt + "%Z")).cast(result.type)
     assert result.equals(expected)

     # Pandas %S is equivalent to %S in arrow for unit="s"
     tsa = pa.array(ts, type=pa.timestamp("s", timezone))
     options = pc.StrftimeOptions("%S")
     result = pc.strftime(tsa, options=options)
-    expected = pa.array(ts.strftime("%S"))
+    expected = pa.array(ts.strftime("%S")).cast(result.type)
     assert result.equals(expected)

     # Pandas %S.%f is equivalent to %S in arrow for unit="us"
     tsa = pa.array(ts, type=pa.timestamp("us", timezone))
     options = pc.StrftimeOptions("%S")
     result = pc.strftime(tsa, options=options)
-    expected = pa.array(ts.strftime("%S.%f"))
+    expected = pa.array(ts.strftime("%S.%f")).cast(result.type)
     assert result.equals(expected)

     # Test setting locale
     tsa = pa.array(ts, type=pa.timestamp("s", timezone))
     options = pc.StrftimeOptions(fmt, locale="C")
     result = pc.strftime(tsa, options=options)
-    expected = pa.array(ts.strftime(fmt))
+    expected = pa.array(ts.strftime(fmt)).cast(result.type)
     assert result.equals(expected)

     # Test timestamps without timezone
     fmt = "%Y-%m-%dT%H:%M:%S"
     ts = pd.to_datetime(times)
     tsa = pa.array(ts, type=pa.timestamp("s"))
     result = pc.strftime(tsa, options=pc.StrftimeOptions(fmt))
-    expected = pa.array(ts.strftime(fmt))
+    expected = pa.array(ts.strftime(fmt)).cast(result.type)

     # Positional format
     assert pc.strftime(tsa, fmt) == result
```
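
The recurring `.astype(object).replace({np.nan: None})` pattern in these tests undoes the new default: with the string dtype, `to_pandas()` returns NaN as the missing-value sentinel, while the tests were written against object-dtype Series holding None. A sketch of what the pattern does:

```python
import numpy as np
import pyarrow as pa

series = pa.array([None, "a"]).to_pandas()
# Under the string dtype the null comes back as NaN; converting to object
# dtype and mapping NaN back to None restores the old comparison basis.
series = series.astype(object).replace({np.nan: None})
assert series[0] is None
```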

python/pyarrow/tests/test_feather.py (+5, −1)

```diff
@@ -426,7 +426,11 @@ def test_empty_strings(version):
 @pytest.mark.pandas
 def test_all_none(version):
     df = pd.DataFrame({'all_none': [None] * 10})
-    _check_pandas_roundtrip(df, version=version)
+    if version == 1 and pa.pandas_compat._pandas_api.uses_string_dtype():
+        expected = df.astype("str")
+    else:
+        expected = df
+    _check_pandas_roundtrip(df, version=version, expected=expected)


 @pytest.mark.pandas
```
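
A standalone sketch of the Feather V1 corner case exercised above (hypothetical file name; assuming the string dtype is enabled): an all-None object column carries no type information, is stored as a string column, and now reads back with pandas' string dtype rather than object.

```python
import pandas as pd
import pyarrow.feather as feather

df = pd.DataFrame({"all_none": [None] * 10})
feather.write_feather(df, "all_none.feather", version=1)

roundtripped = feather.read_feather("all_none.feather")
# With pandas' string dtype enabled this prints "str"; otherwise "object".
print(roundtripped["all_none"].dtype)
```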
