|
15 | 15 | is_bool_dtype,
|
16 | 16 | is_datetime64_any_dtype,
|
17 | 17 | is_numeric_dtype,
|
18 |
| - is_object_dtype, |
| 18 | + is_string_dtype, |
19 | 19 | is_timedelta64_dtype,
|
20 | 20 | )
|
21 | 21 | from tlz import first, merge, partition_all, remove, unique
|
@@ -323,38 +323,68 @@ def _scalar_binary(op, self, other, inv=False):
|
323 | 323 | return Scalar(graph, name, meta)
|
324 | 324 |
|
325 | 325 |
|
326 |
| -def _is_object_or_string_dtype(dtype): |
327 |
| - """Determine if input dtype is `object` or `string[python]`""" |
328 |
| - if is_object_dtype(dtype) or ( |
329 |
| - isinstance(dtype, pd.StringDtype) and dtype.storage == "python" |
330 |
| - ): |
| 326 | +def _is_pyarrow_string(dtype): |
| 327 | + if not PANDAS_GT_130: |
| 328 | + return False |
| 329 | + |
| 330 | + if PANDAS_GT_150: |
| 331 | + import pyarrow as pa |
| 332 | + |
| 333 | + types = [pd.StringDtype("pyarrow"), pd.ArrowDtype(pa.string())] |
| 334 | + else: |
| 335 | + types = [pd.StringDtype("pyarrow")] |
| 336 | + if dtype in types: |
331 | 337 | return True
|
332 | 338 | return False
|
333 | 339 |
|
334 | 340 |
|
| 341 | +def _is_object_string_dtype(dtype): |
| 342 | + """Determine if input is a non-pyarrow string dtype""" |
| 343 | + return is_string_dtype(dtype) and not _is_pyarrow_string(dtype) |
| 344 | + |
| 345 | + |
| 346 | +def _index_check(x): |
| 347 | + return ( |
| 348 | + is_index_like(x) |
| 349 | + and _is_object_string_dtype(x.dtype) |
| 350 | + and not isinstance( |
| 351 | + x, pd.MultiIndex |
| 352 | + ) # MultiIndex don't support non-object dtypes |
| 353 | + ) |
| 354 | + |
| 355 | + |
| 356 | +def _series_check(x): |
| 357 | + return is_series_like(x) and ( |
| 358 | + _is_object_string_dtype(x.dtype) or _index_check(x.index) |
| 359 | + ) |
| 360 | + |
| 361 | + |
| 362 | +def _dataframe_check(x): |
| 363 | + return is_dataframe_like(x) and ( |
| 364 | + any(_series_check(s) for _, s in x.items()) or _index_check(x.index) |
| 365 | + ) |
| 366 | + |
| 367 | + |
335 | 368 | def to_pyarrow_string(df):
|
336 | 369 | if not (is_dataframe_like(df) or is_series_like(df) or is_index_like(df)):
|
337 | 370 | return df
|
338 | 371 |
|
| 372 | + # Possibly convert DataFrame/Series/Index to `string[pyarrow]` |
339 | 373 | dtypes = None
|
340 | 374 | if is_dataframe_like(df):
|
341 | 375 | dtypes = {
|
342 | 376 | col: pd.StringDtype("pyarrow")
|
343 | 377 | for col, s in df.items()
|
344 |
| - if _is_object_or_string_dtype(s.dtype) |
| 378 | + if _is_object_string_dtype(s.dtype) |
345 | 379 | }
|
346 |
| - elif _is_object_or_string_dtype(df.dtype): |
| 380 | + elif _is_object_string_dtype(df.dtype): |
347 | 381 | dtypes = pd.StringDtype("pyarrow")
|
348 | 382 |
|
349 | 383 | if dtypes is not None:
|
350 | 384 | df = df.astype(dtypes)
|
351 | 385 |
|
352 |
| - # Convert DataFrame and Series index too |
353 |
| - if ( |
354 |
| - (is_dataframe_like(df) or is_series_like(df)) |
355 |
| - and not isinstance(df.index, pd.MultiIndex) |
356 |
| - and _is_object_or_string_dtype(df.index.dtype) |
357 |
| - ): |
| 386 | + # Convert DataFrame/Series index too |
| 387 | + if (is_dataframe_like(df) or is_series_like(df)) and _index_check(df.index): |
358 | 388 | df.index = df.index.astype(pd.StringDtype("pyarrow"))
|
359 | 389 | return df
|
360 | 390 |
|
@@ -406,25 +436,6 @@ def __init__(self, dsk, name, meta, divisions):
|
406 | 436 | f"pandas={str(PANDAS_VERSION)} is currently using used."
|
407 | 437 | )
|
408 | 438 |
|
409 |
| - def _index_check(x): |
410 |
| - return ( |
411 |
| - is_index_like(x) |
412 |
| - and _is_object_or_string_dtype(x) |
413 |
| - and not isinstance( |
414 |
| - x, pd.MultiIndex |
415 |
| - ) # MultiIndex don't support non-object dtypes |
416 |
| - ) |
417 |
| - |
418 |
| - def _series_check(x): |
419 |
| - return is_series_like(x) and ( |
420 |
| - _is_object_or_string_dtype(x) or _index_check(x.index) |
421 |
| - ) |
422 |
| - |
423 |
| - def _dataframe_check(x): |
424 |
| - return is_dataframe_like(x) and ( |
425 |
| - any(_series_check(s) for _, s in x.items()) or _index_check(x.index) |
426 |
| - ) |
427 |
| - |
428 | 439 | if _dataframe_check(meta) or _series_check(meta) or _index_check(meta):
|
429 | 440 | result = self.map_partitions(to_pyarrow_string)
|
430 | 441 | self.dask = result.dask
|
|
0 commit comments