Skip to content

Commit 1cb49d9

Browse files
jrbourbeauj-bennet
authored andcommitted
Move into dedicated module
1 parent ba94e5b commit 1cb49d9

File tree

3 files changed

+92
-69
lines changed

3 files changed

+92
-69
lines changed

dask/dataframe/_pyarrow_utils.py

+79
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
import pandas as pd
2+
3+
from dask.dataframe._compat import PANDAS_GT_130, PANDAS_GT_150
4+
from dask.dataframe.utils import is_dataframe_like, is_index_like, is_series_like
5+
6+
try:
7+
import pyarrow as pa
8+
except ImportError:
9+
pa = None
10+
11+
12+
PYARROW_STRINGS_AVAILABLE: bool = pa is not None and PANDAS_GT_130
13+
14+
15+
def is_pyarrow_string_dtype(dtype):
16+
if not PYARROW_STRINGS_AVAILABLE:
17+
return False
18+
19+
if PANDAS_GT_150:
20+
types = [pd.StringDtype("pyarrow"), pd.ArrowDtype(pa.string())]
21+
else:
22+
types = [pd.StringDtype("pyarrow")]
23+
if dtype in types:
24+
return True
25+
return False
26+
27+
28+
def is_object_string_dtype(dtype):
29+
"""Determine if input is a non-pyarrow string dtype"""
30+
return pd.api.types.is_string_dtype(dtype) and not is_pyarrow_string_dtype(dtype)
31+
32+
33+
def is_object_string_index(x):
34+
return (
35+
is_index_like(x)
36+
and is_object_string_dtype(x.dtype)
37+
and not isinstance(
38+
x, pd.MultiIndex
39+
) # MultiIndex don't support non-object dtypes
40+
)
41+
42+
43+
def is_object_string_series(x):
44+
return is_series_like(x) and (
45+
is_object_string_dtype(x.dtype) or is_object_string_index(x.index)
46+
)
47+
48+
49+
def is_object_string_dataframe(x):
50+
return is_dataframe_like(x) and (
51+
any(is_object_string_series(s) for _, s in x.items())
52+
or is_object_string_index(x.index)
53+
)
54+
55+
56+
def to_pyarrow_string(df):
57+
if not (is_dataframe_like(df) or is_series_like(df) or is_index_like(df)):
58+
return df
59+
60+
# Possibly convert DataFrame/Series/Index to `string[pyarrow]`
61+
dtypes = None
62+
if is_dataframe_like(df):
63+
dtypes = {
64+
col: pd.StringDtype("pyarrow")
65+
for col, s in df.items()
66+
if is_object_string_dtype(s.dtype)
67+
}
68+
elif is_object_string_dtype(df.dtype):
69+
dtypes = pd.StringDtype("pyarrow")
70+
71+
if dtypes is not None:
72+
df = df.astype(dtypes)
73+
74+
# Convert DataFrame/Series index too
75+
if (is_dataframe_like(df) or is_series_like(df)) and is_object_string_index(
76+
df.index
77+
):
78+
df.index = df.index.astype(pd.StringDtype("pyarrow"))
79+
return df

dask/dataframe/core.py

+12-68
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
is_bool_dtype,
1616
is_datetime64_any_dtype,
1717
is_numeric_dtype,
18-
is_string_dtype,
1918
is_timedelta64_dtype,
2019
)
2120
from tlz import first, merge, partition_all, remove, unique
@@ -323,72 +322,6 @@ def _scalar_binary(op, self, other, inv=False):
323322
return Scalar(graph, name, meta)
324323

325324

326-
def _is_pyarrow_string(dtype):
327-
if not PANDAS_GT_130:
328-
return False
329-
330-
if PANDAS_GT_150:
331-
import pyarrow as pa
332-
333-
types = [pd.StringDtype("pyarrow"), pd.ArrowDtype(pa.string())]
334-
else:
335-
types = [pd.StringDtype("pyarrow")]
336-
if dtype in types:
337-
return True
338-
return False
339-
340-
341-
def _is_object_string_dtype(dtype):
342-
"""Determine if input is a non-pyarrow string dtype"""
343-
return is_string_dtype(dtype) and not _is_pyarrow_string(dtype)
344-
345-
346-
def _index_check(x):
347-
return (
348-
is_index_like(x)
349-
and _is_object_string_dtype(x.dtype)
350-
and not isinstance(
351-
x, pd.MultiIndex
352-
) # MultiIndex don't support non-object dtypes
353-
)
354-
355-
356-
def _series_check(x):
357-
return is_series_like(x) and (
358-
_is_object_string_dtype(x.dtype) or _index_check(x.index)
359-
)
360-
361-
362-
def _dataframe_check(x):
363-
return is_dataframe_like(x) and (
364-
any(_series_check(s) for _, s in x.items()) or _index_check(x.index)
365-
)
366-
367-
368-
def to_pyarrow_string(df):
369-
if not (is_dataframe_like(df) or is_series_like(df) or is_index_like(df)):
370-
return df
371-
372-
# Possibly convert DataFrame/Series/Index to `string[pyarrow]`
373-
dtypes = None
374-
if is_dataframe_like(df):
375-
dtypes = {
376-
col: pd.StringDtype("pyarrow")
377-
for col, s in df.items()
378-
if _is_object_string_dtype(s.dtype)
379-
}
380-
elif _is_object_string_dtype(df.dtype):
381-
dtypes = pd.StringDtype("pyarrow")
382-
383-
if dtypes is not None:
384-
df = df.astype(dtypes)
385-
386-
# Convert DataFrame/Series index too
387-
if (is_dataframe_like(df) or is_series_like(df)) and _index_check(df.index):
388-
df.index = df.index.astype(pd.StringDtype("pyarrow"))
389-
return df
390-
391-
392325
class _Frame(DaskMethodsMixin, OperatorMethodMixin):
393326
"""Superclass for DataFrame and Series
394327
@@ -436,7 +369,18 @@ def __init__(self, dsk, name, meta, divisions):
436369
f"pandas={str(PANDAS_VERSION)} is currently using used."
437370
)
438371

439-
if _dataframe_check(meta) or _series_check(meta) or _index_check(meta):
372+
from dask.dataframe._pyarrow_utils import (
373+
is_object_string_dataframe,
374+
is_object_string_index,
375+
is_object_string_series,
376+
to_pyarrow_string,
377+
)
378+
379+
if (
380+
is_object_string_dataframe(meta)
381+
or is_object_string_series(meta)
382+
or is_object_string_index(meta)
383+
):
440384
result = self.map_partitions(to_pyarrow_string)
441385
self.dask = result.dask
442386
self._name = result._name

dask/dataframe/utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -545,7 +545,7 @@ def assert_eq(
545545
import dask
546546

547547
if dask.config.get("dataframe.object_as_pyarrow_string"):
548-
from dask.dataframe.core import to_pyarrow_string
548+
from dask.dataframe._pyarrow_utils import to_pyarrow_string
549549

550550
if not is_dask_collection(a):
551551
a = to_pyarrow_string(a)

0 commit comments

Comments
 (0)