Skip to content

Commit 1152889

Browse files
jrbourbeau authored and j-bennet committed
Add simple tests
1 parent 5ed196b commit 1152889

File tree

5 files changed

+114 -67 lines changed

conftest.py

+12 -13
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,18 @@
2626

2727
try:
2828
import pandas # noqa: F401
29+
30+
# Temporary changes to look for pyarrow string failures
31+
import dask
32+
from dask.dataframe._compat import PANDAS_GT_130
33+
34+
try:
35+
import pyarrow
36+
except ImportError:
37+
pyarrow = False
38+
39+
if PANDAS_GT_130 and pyarrow:
40+
dask.config.set({"dataframe.object_as_pyarrow_string": True})
2941
except ImportError:
3042
collect_ignore_glob.append("dask/dataframe/*")
3143

@@ -68,16 +80,3 @@ def pytest_runtest_setup(item):
6880
def shuffle_method(request):
6981
with dask.config.set(shuffle=request.param):
7082
yield request.param
71-
72-
73-
# Temporary changes to look for pyarrow string failures
74-
import dask
75-
from dask.dataframe._compat import PANDAS_GT_130
76-
77-
try:
78-
import pyarrow
79-
except ImportError:
80-
pyarrow = False
81-
82-
if PANDAS_GT_130 and pyarrow:
83-
dask.config.set({"dataframe.object_as_pyarrow_string": True})

dask/dataframe/core.py

+42 -32
Original file line number | Diff line number | Diff line change
@@ -323,11 +323,40 @@ def _scalar_binary(op, self, other, inv=False):
323323
return Scalar(graph, name, meta)
324324

325325

326-
def _maybe_convert_dtype(dtype):
327-
if is_object_dtype(dtype):
328-
return pd.StringDtype("pyarrow")
329-
else:
330-
return dtype
326+
def _is_object_or_string_dtype(dtype):
327+
"""Determine if input dtype is `object` or `string[python]`"""
328+
if is_object_dtype(dtype) or (
329+
isinstance(dtype, pd.StringDtype) and dtype.storage == "python"
330+
):
331+
return True
332+
return False
333+
334+
335+
def to_pyarrow_string(df):
336+
if not (is_dataframe_like(df) or is_series_like(df) or is_index_like(df)):
337+
return df
338+
339+
dtypes = None
340+
if is_dataframe_like(df):
341+
dtypes = {
342+
col: pd.StringDtype("pyarrow")
343+
for col, s in df.items()
344+
if _is_object_or_string_dtype(s.dtype)
345+
}
346+
elif _is_object_or_string_dtype(df.dtype):
347+
dtypes = pd.StringDtype("pyarrow")
348+
349+
if dtypes is not None:
350+
df = df.astype(dtypes)
351+
352+
# Convert DataFrame and Series index too
353+
if (
354+
(is_dataframe_like(df) or is_series_like(df))
355+
and not isinstance(df.index, pd.MultiIndex)
356+
and _is_object_or_string_dtype(df.index.dtype)
357+
):
358+
df.index = df.index.astype(pd.StringDtype("pyarrow"))
359+
return df
331360

332361

333362
class _Frame(DaskMethodsMixin, OperatorMethodMixin):
@@ -378,45 +407,26 @@ def __init__(self, dsk, name, meta, divisions):
378407
)
379408

380409
def _index_check(x):
381-
# MultiIndex don't support non-object dtypes
382410
return (
383411
is_index_like(x)
384-
and is_object_dtype(x)
385-
and not isinstance(x, pd.MultiIndex)
412+
and _is_object_or_string_dtype(x)
413+
and not isinstance(
414+
x, pd.MultiIndex
415+
) # MultiIndex don't support non-object dtypes
386416
)
387417

388418
def _series_check(x):
389419
return is_series_like(x) and (
390-
is_object_dtype(x) or _index_check(x.index)
420+
_is_object_or_string_dtype(x) or _index_check(x.index)
391421
)
392422

393-
def _df_check(x):
423+
def _dataframe_check(x):
394424
return is_dataframe_like(x) and (
395425
any(_series_check(s) for _, s in x.items()) or _index_check(x.index)
396426
)
397427

398-
if _df_check(meta) or _series_check(meta) or _index_check(meta):
399-
400-
def _object_to_pyarrow_string(df):
401-
if not (
402-
is_dataframe_like(df) or is_series_like(df) or is_index_like(df)
403-
):
404-
return df
405-
if is_dataframe_like(df):
406-
dtypes = {
407-
col: _maybe_convert_dtype(df[col].dtype) for col in df
408-
}
409-
else:
410-
dtypes = _maybe_convert_dtype(df.dtype)
411-
df = df.astype(dtypes)
412-
413-
if (is_dataframe_like(df) or is_series_like(df)) and not isinstance(
414-
df.index, pd.MultiIndex
415-
):
416-
df.index = df.index.astype(_maybe_convert_dtype(df.index.dtype))
417-
return df
418-
419-
result = self.map_partitions(_object_to_pyarrow_string)
428+
if _dataframe_check(meta) or _series_check(meta) or _index_check(meta):
429+
result = self.map_partitions(to_pyarrow_string)
420430
self.dask = result.dask
421431
self._name = result._name
422432
self._meta = result._meta

dask/dataframe/io/tests/test_csv.py

+15
Original file line number | Diff line number | Diff line change
@@ -360,6 +360,21 @@ def test_read_csv(dd_read, pd_read, text, sep):
360360
assert_eq(result, pd_read(fn, sep=sep))
361361

362362

363+
def test_read_csv_object_as_pyarrow_string_config():
364+
pytest.importorskip(
365+
"pandas",
366+
minversion="1.3.0",
367+
reason="Requires support for pyarrow strings",
368+
)
369+
pytest.importorskip("pyarrow", reason="Requires pyarrow")
370+
with filetext(csv_text) as fn:
371+
df = pd.read_csv(fn)
372+
with dask.config.set({"dataframe.object_as_pyarrow_string": True}):
373+
ddf = dd.read_csv(fn)
374+
df_pyarrow = df.astype({"name": "string[pyarrow]"})
375+
assert_eq(df_pyarrow, ddf, check_index=False)
376+
377+
363378
@pytest.mark.parametrize(
364379
"dd_read,pd_read,text,skip",
365380
[

dask/dataframe/io/tests/test_io.py

+40
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
import pandas as pd
55
import pytest
66

7+
import dask
78
import dask.array as da
89
import dask.dataframe as dd
910
from dask import config
@@ -274,6 +275,45 @@ def test_from_pandas_npartitions_duplicates(index):
274275
assert ddf.divisions == ("A", "B", "C", "C")
275276

276277

278+
def test_from_pandas_object_as_pyarrow_string_config():
279+
pytest.importorskip(
280+
"pandas",
281+
minversion="1.3.0",
282+
reason="Requires support for pyarrow strings",
283+
)
284+
pytest.importorskip("pyarrow", reason="Requires pyarrow")
285+
286+
# `dataframe.object_as_pyarrow_string` defaults to `False`
287+
s = pd.Series(["foo", "bar", "ricky", "bobby"], index=["a", "b", "c", "d"])
288+
df = pd.DataFrame(
289+
{
290+
"x": [1, 2, 3, 4],
291+
"y": [5.0, 6.0, 7.0, 8.0],
292+
"z": ["foo", "bar", "ricky", "bobby"],
293+
},
294+
index=["a", "b", "c", "d"],
295+
)
296+
297+
ds = dd.from_pandas(s, npartitions=2)
298+
ddf = dd.from_pandas(df, npartitions=2)
299+
300+
assert_eq(s, ds)
301+
assert_eq(df, ddf)
302+
303+
# When `dataframe.object_as_pyarrow_string = True`, dask should automatically
304+
# cast `object`s to pyarrow strings
305+
with dask.config.set({"dataframe.object_as_pyarrow_string": True}):
306+
ds = dd.from_pandas(s, npartitions=2)
307+
ddf = dd.from_pandas(df, npartitions=2)
308+
309+
s_pyarrow = s.astype("string[pyarrow]")
310+
s_pyarrow.index = s_pyarrow.index.astype("string[pyarrow]")
311+
df_pyarrow = df.astype({"z": "string[pyarrow]"})
312+
df_pyarrow.index = df_pyarrow.index.astype("string[pyarrow]")
313+
assert_eq(s_pyarrow, ds)
314+
assert_eq(df_pyarrow, ddf)
315+
316+
277317
@pytest.mark.gpu
278318
def test_gpu_from_pandas_npartitions_duplicates():
279319
cudf = pytest.importorskip("cudf")

dask/dataframe/utils.py

+5 -22
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,7 @@
1212

1313
import numpy as np
1414
import pandas as pd
15-
from pandas.api.types import is_categorical_dtype, is_dtype_equal, is_object_dtype
15+
from pandas.api.types import is_categorical_dtype, is_dtype_equal
1616

1717
from dask.base import get_scheduler, is_dask_collection
1818
from dask.core import get_deps
@@ -530,25 +530,6 @@ def _maybe_sort(a, check_index: bool):
530530
return a.sort_index() if check_index else a
531531

532532

533-
def _maybe_convert_to_pyarrow(a):
534-
from dask.dataframe.core import _maybe_convert_dtype
535-
536-
if isinstance(a, pd.DataFrame):
537-
dtypes = {col: _maybe_convert_dtype(a[col].dtype) for col in a}
538-
a = a.astype(dtypes)
539-
elif isinstance(a, pd.Series):
540-
a = a.astype(_maybe_convert_dtype(a.dtype))
541-
elif isinstance(a, pd.Index) and not isinstance(a, pd.MultiIndex):
542-
a = a.astype(_maybe_convert_dtype(a.dtype))
543-
if (
544-
isinstance(a, (pd.DataFrame, pd.Series))
545-
and not isinstance(a.index, pd.MultiIndex)
546-
and is_object_dtype(a.index)
547-
):
548-
a.index = a.index.astype(_maybe_convert_dtype(a.index))
549-
return a
550-
551-
552533
def assert_eq(
553534
a,
554535
b,
@@ -564,8 +545,10 @@ def assert_eq(
564545
import dask
565546

566547
if dask.config.get("dataframe.object_as_pyarrow_string"):
567-
a = _maybe_convert_to_pyarrow(a)
568-
b = _maybe_convert_to_pyarrow(b)
548+
from dask.dataframe.core import to_pyarrow_string
549+
550+
a = to_pyarrow_string(a)
551+
b = to_pyarrow_string(b)
569552

570553
if check_divisions:
571554
assert_divisions(a, scheduler=scheduler)

0 commit comments

Comments
 (0)