1
1
from __future__ import annotations
2
2
3
+ from functools import partial
3
4
import re
4
5
from typing import (
5
6
TYPE_CHECKING ,
27
28
)
28
29
from pandas .core .dtypes .missing import isna
29
30
31
+ from pandas .core .arrays ._arrow_string_mixins import ArrowStringArrayMixin
30
32
from pandas .core .arrays .arrow import ArrowExtensionArray
31
33
from pandas .core .arrays .boolean import BooleanDtype
32
34
from pandas .core .arrays .integer import Int64Dtype
@@ -113,10 +115,11 @@ class ArrowStringArray(ObjectStringArrayMixin, ArrowExtensionArray, BaseStringAr
113
115
# error: Incompatible types in assignment (expression has type "StringDtype",
114
116
# base class "ArrowExtensionArray" defined the type as "ArrowDtype")
115
117
_dtype : StringDtype # type: ignore[assignment]
118
+ _storage = "pyarrow"
116
119
117
120
def __init__ (self , values ) -> None :
118
121
super ().__init__ (values )
119
- self ._dtype = StringDtype (storage = "pyarrow" )
122
+ self ._dtype = StringDtype (storage = self . _storage )
120
123
121
124
if not pa .types .is_string (self ._pa_array .type ) and not (
122
125
pa .types .is_dictionary (self ._pa_array .type )
@@ -144,7 +147,10 @@ def _from_sequence(cls, scalars, dtype: Dtype | None = None, copy: bool = False)
144
147
145
148
if dtype and not (isinstance (dtype , str ) and dtype == "string" ):
146
149
dtype = pandas_dtype (dtype )
147
- assert isinstance (dtype , StringDtype ) and dtype .storage == "pyarrow"
150
+ assert isinstance (dtype , StringDtype ) and dtype .storage in (
151
+ "pyarrow" ,
152
+ "pyarrow_numpy" ,
153
+ )
148
154
149
155
if isinstance (scalars , BaseMaskedArray ):
150
156
# avoid costly conversion to object dtype in ensure_string_array and
@@ -178,6 +184,10 @@ def insert(self, loc: int, item) -> ArrowStringArray:
178
184
raise TypeError ("Scalar must be NA or str" )
179
185
return super ().insert (loc , item )
180
186
187
+ @classmethod
188
+ def _result_converter (cls , values , na = None ):
189
+ return BooleanDtype ().__from_arrow__ (values )
190
+
181
191
def _maybe_convert_setitem_value (self , value ):
182
192
"""Maybe convert value to be pyarrow compatible."""
183
193
if is_scalar (value ):
@@ -313,7 +323,7 @@ def _str_contains(
313
323
result = pc .match_substring_regex (self ._pa_array , pat , ignore_case = not case )
314
324
else :
315
325
result = pc .match_substring (self ._pa_array , pat , ignore_case = not case )
316
- result = BooleanDtype (). __from_arrow__ (result )
326
+ result = self . _result_converter (result , na = na )
317
327
if not isna (na ):
318
328
result [isna (result )] = bool (na )
319
329
return result
@@ -322,7 +332,7 @@ def _str_startswith(self, pat: str, na=None):
322
332
result = pc .starts_with (self ._pa_array , pattern = pat )
323
333
if not isna (na ):
324
334
result = result .fill_null (na )
325
- result = BooleanDtype (). __from_arrow__ (result )
335
+ result = self . _result_converter (result )
326
336
if not isna (na ):
327
337
result [isna (result )] = bool (na )
328
338
return result
@@ -331,7 +341,7 @@ def _str_endswith(self, pat: str, na=None):
331
341
result = pc .ends_with (self ._pa_array , pattern = pat )
332
342
if not isna (na ):
333
343
result = result .fill_null (na )
334
- result = BooleanDtype (). __from_arrow__ (result )
344
+ result = self . _result_converter (result )
335
345
if not isna (na ):
336
346
result [isna (result )] = bool (na )
337
347
return result
@@ -369,39 +379,39 @@ def _str_fullmatch(
369
379
370
380
def _str_isalnum (self ):
371
381
result = pc .utf8_is_alnum (self ._pa_array )
372
- return BooleanDtype (). __from_arrow__ (result )
382
+ return self . _result_converter (result )
373
383
374
384
def _str_isalpha (self ):
375
385
result = pc .utf8_is_alpha (self ._pa_array )
376
- return BooleanDtype (). __from_arrow__ (result )
386
+ return self . _result_converter (result )
377
387
378
388
def _str_isdecimal (self ):
379
389
result = pc .utf8_is_decimal (self ._pa_array )
380
- return BooleanDtype (). __from_arrow__ (result )
390
+ return self . _result_converter (result )
381
391
382
392
def _str_isdigit (self ):
383
393
result = pc .utf8_is_digit (self ._pa_array )
384
- return BooleanDtype (). __from_arrow__ (result )
394
+ return self . _result_converter (result )
385
395
386
396
def _str_islower (self ):
387
397
result = pc .utf8_is_lower (self ._pa_array )
388
- return BooleanDtype (). __from_arrow__ (result )
398
+ return self . _result_converter (result )
389
399
390
400
def _str_isnumeric (self ):
391
401
result = pc .utf8_is_numeric (self ._pa_array )
392
- return BooleanDtype (). __from_arrow__ (result )
402
+ return self . _result_converter (result )
393
403
394
404
def _str_isspace (self ):
395
405
result = pc .utf8_is_space (self ._pa_array )
396
- return BooleanDtype (). __from_arrow__ (result )
406
+ return self . _result_converter (result )
397
407
398
408
def _str_istitle (self ):
399
409
result = pc .utf8_is_title (self ._pa_array )
400
- return BooleanDtype (). __from_arrow__ (result )
410
+ return self . _result_converter (result )
401
411
402
412
def _str_isupper (self ):
403
413
result = pc .utf8_is_upper (self ._pa_array )
404
- return BooleanDtype (). __from_arrow__ (result )
414
+ return self . _result_converter (result )
405
415
406
416
def _str_len (self ):
407
417
result = pc .utf8_length (self ._pa_array )
@@ -433,3 +443,114 @@ def _str_rstrip(self, to_strip=None):
433
443
else :
434
444
result = pc .utf8_rtrim (self ._pa_array , characters = to_strip )
435
445
return type (self )(result )
446
+
447
+
448
+ class ArrowStringArrayNumpySemantics (ArrowStringArray ):
449
+ _storage = "pyarrow_numpy"
450
+
451
+ @classmethod
452
+ def _result_converter (cls , values , na = None ):
453
+ if not isna (na ):
454
+ values = values .fill_null (bool (na ))
455
+ return ArrowExtensionArray (values ).to_numpy (na_value = np .nan )
456
+
457
+ def __getattribute__ (self , item ):
458
+ # ArrowStringArray and we both inherit from ArrowExtensionArray, which
459
+ # creates inheritance problems (Diamond inheritance)
460
+ if item in ArrowStringArrayMixin .__dict__ and item != "_pa_array" :
461
+ return partial (getattr (ArrowStringArrayMixin , item ), self )
462
+ return super ().__getattribute__ (item )
463
+
464
+ def _str_map (
465
+ self , f , na_value = None , dtype : Dtype | None = None , convert : bool = True
466
+ ):
467
+ if dtype is None :
468
+ dtype = self .dtype
469
+ if na_value is None :
470
+ na_value = self .dtype .na_value
471
+
472
+ mask = isna (self )
473
+ arr = np .asarray (self )
474
+
475
+ if is_integer_dtype (dtype ) or is_bool_dtype (dtype ):
476
+ if is_integer_dtype (dtype ):
477
+ na_value = np .nan
478
+ else :
479
+ na_value = False
480
+ try :
481
+ result = lib .map_infer_mask (
482
+ arr ,
483
+ f ,
484
+ mask .view ("uint8" ),
485
+ convert = False ,
486
+ na_value = na_value ,
487
+ dtype = np .dtype (dtype ), # type: ignore[arg-type]
488
+ )
489
+ return result
490
+
491
+ except ValueError :
492
+ result = lib .map_infer_mask (
493
+ arr ,
494
+ f ,
495
+ mask .view ("uint8" ),
496
+ convert = False ,
497
+ na_value = na_value ,
498
+ )
499
+ if convert and result .dtype == object :
500
+ result = lib .maybe_convert_objects (result )
501
+ return result
502
+
503
+ elif is_string_dtype (dtype ) and not is_object_dtype (dtype ):
504
+ # i.e. StringDtype
505
+ result = lib .map_infer_mask (
506
+ arr , f , mask .view ("uint8" ), convert = False , na_value = na_value
507
+ )
508
+ result = pa .array (result , mask = mask , type = pa .string (), from_pandas = True )
509
+ return type (self )(result )
510
+ else :
511
+ # This is when the result type is object. We reach this when
512
+ # -> We know the result type is truly object (e.g. .encode returns bytes
513
+ # or .findall returns a list).
514
+ # -> We don't know the result type. E.g. `.get` can return anything.
515
+ return lib .map_infer_mask (arr , f , mask .view ("uint8" ))
516
+
517
+ def _convert_int_dtype (self , result ):
518
+ if result .dtype == np .int32 :
519
+ result = result .astype (np .int64 )
520
+ return result
521
+
522
+ def _str_count (self , pat : str , flags : int = 0 ):
523
+ if flags :
524
+ return super ()._str_count (pat , flags )
525
+ result = pc .count_substring_regex (self ._pa_array , pat ).to_numpy ()
526
+ return self ._convert_int_dtype (result )
527
+
528
+ def _str_len (self ):
529
+ result = pc .utf8_length (self ._pa_array ).to_numpy ()
530
+ return self ._convert_int_dtype (result )
531
+
532
+ def _str_find (self , sub : str , start : int = 0 , end : int | None = None ):
533
+ if start != 0 and end is not None :
534
+ slices = pc .utf8_slice_codeunits (self ._pa_array , start , stop = end )
535
+ result = pc .find_substring (slices , sub )
536
+ not_found = pc .equal (result , - 1 )
537
+ offset_result = pc .add (result , end - start )
538
+ result = pc .if_else (not_found , result , offset_result )
539
+ elif start == 0 and end is None :
540
+ slices = self ._pa_array
541
+ result = pc .find_substring (slices , sub )
542
+ else :
543
+ return super ()._str_find (sub , start , end )
544
+ return self ._convert_int_dtype (result .to_numpy ())
545
+
546
+ def _cmp_method (self , other , op ):
547
+ result = super ()._cmp_method (other , op )
548
+ return result .to_numpy (np .bool_ , na_value = False )
549
+
550
+ def value_counts (self , dropna : bool = True ):
551
+ from pandas import Series
552
+
553
+ result = super ().value_counts (dropna )
554
+ return Series (
555
+ result ._values .to_numpy (), index = result .index , name = result .name , copy = False
556
+ )
0 commit comments