Skip to content

Commit ba94e5b

Browse files
jrbourbeauj-bennet
authored andcommitted
Minor polish
1 parent 1152889 commit ba94e5b

File tree

2 files changed

+48
-35
lines changed

2 files changed

+48
-35
lines changed

dask/dataframe/core.py

+44-33
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
is_bool_dtype,
1616
is_datetime64_any_dtype,
1717
is_numeric_dtype,
18-
is_object_dtype,
18+
is_string_dtype,
1919
is_timedelta64_dtype,
2020
)
2121
from tlz import first, merge, partition_all, remove, unique
@@ -323,38 +323,68 @@ def _scalar_binary(op, self, other, inv=False):
323323
return Scalar(graph, name, meta)
324324

325325

326-
def _is_object_or_string_dtype(dtype):
327-
"""Determine if input dtype is `object` or `string[python]`"""
328-
if is_object_dtype(dtype) or (
329-
isinstance(dtype, pd.StringDtype) and dtype.storage == "python"
330-
):
326+
def _is_pyarrow_string(dtype):
327+
if not PANDAS_GT_130:
328+
return False
329+
330+
if PANDAS_GT_150:
331+
import pyarrow as pa
332+
333+
types = [pd.StringDtype("pyarrow"), pd.ArrowDtype(pa.string())]
334+
else:
335+
types = [pd.StringDtype("pyarrow")]
336+
if dtype in types:
331337
return True
332338
return False
333339

334340

341+
def _is_object_string_dtype(dtype):
342+
"""Determine if input is a non-pyarrow string dtype"""
343+
return is_string_dtype(dtype) and not _is_pyarrow_string(dtype)
344+
345+
346+
def _index_check(x):
347+
return (
348+
is_index_like(x)
349+
and _is_object_string_dtype(x.dtype)
350+
and not isinstance(
351+
x, pd.MultiIndex
352+
) # MultiIndex don't support non-object dtypes
353+
)
354+
355+
356+
def _series_check(x):
357+
return is_series_like(x) and (
358+
_is_object_string_dtype(x.dtype) or _index_check(x.index)
359+
)
360+
361+
362+
def _dataframe_check(x):
363+
return is_dataframe_like(x) and (
364+
any(_series_check(s) for _, s in x.items()) or _index_check(x.index)
365+
)
366+
367+
335368
def to_pyarrow_string(df):
336369
if not (is_dataframe_like(df) or is_series_like(df) or is_index_like(df)):
337370
return df
338371

372+
# Possibly convert DataFrame/Series/Index to `string[pyarrow]`
339373
dtypes = None
340374
if is_dataframe_like(df):
341375
dtypes = {
342376
col: pd.StringDtype("pyarrow")
343377
for col, s in df.items()
344-
if _is_object_or_string_dtype(s.dtype)
378+
if _is_object_string_dtype(s.dtype)
345379
}
346-
elif _is_object_or_string_dtype(df.dtype):
380+
elif _is_object_string_dtype(df.dtype):
347381
dtypes = pd.StringDtype("pyarrow")
348382

349383
if dtypes is not None:
350384
df = df.astype(dtypes)
351385

352-
# Convert DataFrame and Series index too
353-
if (
354-
(is_dataframe_like(df) or is_series_like(df))
355-
and not isinstance(df.index, pd.MultiIndex)
356-
and _is_object_or_string_dtype(df.index.dtype)
357-
):
386+
# Convert DataFrame/Series index too
387+
if (is_dataframe_like(df) or is_series_like(df)) and _index_check(df.index):
358388
df.index = df.index.astype(pd.StringDtype("pyarrow"))
359389
return df
360390

@@ -406,25 +436,6 @@ def __init__(self, dsk, name, meta, divisions):
406436
f"pandas={str(PANDAS_VERSION)} is currently using used."
407437
)
408438

409-
def _index_check(x):
410-
return (
411-
is_index_like(x)
412-
and _is_object_or_string_dtype(x)
413-
and not isinstance(
414-
x, pd.MultiIndex
415-
) # MultiIndex don't support non-object dtypes
416-
)
417-
418-
def _series_check(x):
419-
return is_series_like(x) and (
420-
_is_object_or_string_dtype(x) or _index_check(x.index)
421-
)
422-
423-
def _dataframe_check(x):
424-
return is_dataframe_like(x) and (
425-
any(_series_check(s) for _, s in x.items()) or _index_check(x.index)
426-
)
427-
428439
if _dataframe_check(meta) or _series_check(meta) or _index_check(meta):
429440
result = self.map_partitions(to_pyarrow_string)
430441
self.dask = result.dask

dask/dataframe/utils.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -547,8 +547,10 @@ def assert_eq(
547547
if dask.config.get("dataframe.object_as_pyarrow_string"):
548548
from dask.dataframe.core import to_pyarrow_string
549549

550-
a = to_pyarrow_string(a)
551-
b = to_pyarrow_string(b)
550+
if not is_dask_collection(a):
551+
a = to_pyarrow_string(a)
552+
if not is_dask_collection(b):
553+
b = to_pyarrow_string(b)
552554

553555
if check_divisions:
554556
assert_divisions(a, scheduler=scheduler)

0 commit comments

Comments
 (0)