@@ -417,7 +417,7 @@ def _str_isupper(self):
417
417
418
418
def _str_len(self):
    """Return the length of each string, as computed by ``pc.utf8_length``.

    The Arrow result is handed to ``self._convert_int_dtype`` so that
    subclasses can choose the container type of the returned integers.
    """
    return self._convert_int_dtype(pc.utf8_length(self._pa_array))
421
421
422
422
def _str_lower(self):
    """Return a new array of the same type holding the lower-cased strings."""
    lowered = pc.utf8_lower(self._pa_array)
    return type(self)(lowered)
@@ -446,6 +446,29 @@ def _str_rstrip(self, to_strip=None):
446
446
result = pc .utf8_rtrim (self ._pa_array , characters = to_strip )
447
447
return type (self )(result )
448
448
449
def _str_count(self, pat: str, flags: int = 0):
    """Count the matches of the regex ``pat`` in each string.

    When ``flags`` is non-zero the call is delegated to the parent
    implementation; otherwise ``pc.count_substring_regex`` is used and the
    result is passed through ``self._convert_int_dtype``.
    """
    if not flags:
        counts = pc.count_substring_regex(self._pa_array, pat)
        return self._convert_int_dtype(counts)
    # The Arrow kernel is called without any flags here, so flagged
    # searches are left to the parent implementation.
    return super()._str_count(pat, flags)
454
+
455
def _str_find(self, sub: str, start: int = 0, end: int | None = None):
    """Return the lowest index of ``sub`` in each string, or -1 if absent.

    Parameters
    ----------
    sub : str
        Substring to search for.
    start : int, default 0
        Left edge of the search window.
    end : int or None, default None
        Right edge of the search window.

    Returns
    -------
    The per-element indices, passed through ``self._convert_int_dtype`` on
    the Arrow paths, or whatever the parent ``_str_find`` returns on the
    fallback path.

    Notes
    -----
    Two cases are handled with Arrow kernels: the unbounded search
    (``start == 0`` and ``end is None``) and a search with a positive
    ``start`` and an explicit ``end`` for a non-empty ``sub``. Everything
    else is delegated to the parent implementation.
    """
    if start == 0 and end is None:
        result = pc.find_substring(self._pa_array, sub)
    elif start > 0 and end is not None and sub:
        # Search inside the [start:end) slice, then translate the hit back
        # into coordinates of the full string.
        slices = pc.utf8_slice_codeunits(self._pa_array, start, stop=end)
        result = pc.find_substring(slices, sub)
        not_found = pc.equal(result, -1)
        # The slice begins at ``start``, so that is the offset to add back
        # (not the slice length ``end - start``).
        offset_result = pc.add(result, start)
        result = pc.if_else(not_found, result, offset_result)
    else:
        # Fallback covers: only one of start/end given; a negative start
        # (the slice origin then depends on each string's length, so a
        # constant offset is wrong); and an empty ``sub`` with a window
        # (an empty slice would presumably still report a match at 0,
        # unlike ``str.find`` when ``start`` is past the end of the string).
        return super()._str_find(sub, start, end)
    return self._convert_int_dtype(result)
468
+
469
def _convert_int_dtype(self, result):
    """Convert an Arrow integer ``result`` via ``Int64Dtype.__from_arrow__``.

    Kept as a separate hook so subclasses can return a different container
    for integer-valued string operations.
    """
    int64_dtype = Int64Dtype()
    return int64_dtype.__from_arrow__(result)
471
+
449
472
450
473
class ArrowStringArrayNumpySemantics (ArrowStringArray ):
451
474
_storage = "pyarrow_numpy"
@@ -526,34 +549,11 @@ def _str_map(
526
549
return lib .map_infer_mask (arr , f , mask .view ("uint8" ))
527
550
528
551
def _convert_int_dtype(self, result):
    """Convert ``result`` to a NumPy array, widening int32 values to int64.

    ``result`` only needs to expose ``to_numpy()``; arrays of any dtype
    other than int32 are returned as produced by that call.
    """
    values = result.to_numpy()
    return values.astype(np.int64) if values.dtype == np.int32 else values
532
556
533
- def _str_count (self , pat : str , flags : int = 0 ):
534
- if flags :
535
- return super ()._str_count (pat , flags )
536
- result = pc .count_substring_regex (self ._pa_array , pat ).to_numpy ()
537
- return self ._convert_int_dtype (result )
538
-
539
- def _str_len (self ):
540
- result = pc .utf8_length (self ._pa_array ).to_numpy ()
541
- return self ._convert_int_dtype (result )
542
-
543
- def _str_find (self , sub : str , start : int = 0 , end : int | None = None ):
544
- if start != 0 and end is not None :
545
- slices = pc .utf8_slice_codeunits (self ._pa_array , start , stop = end )
546
- result = pc .find_substring (slices , sub )
547
- not_found = pc .equal (result , - 1 )
548
- offset_result = pc .add (result , end - start )
549
- result = pc .if_else (not_found , result , offset_result )
550
- elif start == 0 and end is None :
551
- slices = self ._pa_array
552
- result = pc .find_substring (slices , sub )
553
- else :
554
- return super ()._str_find (sub , start , end )
555
- return self ._convert_int_dtype (result .to_numpy ())
556
-
557
557
def _cmp_method(self, other, op):
    """Compare with ``other`` using ``op`` and return a NumPy boolean array.

    The parent comparison result is converted with
    ``to_numpy(np.bool_, na_value=False)``, so missing entries come back
    as ``False``.
    """
    return super()._cmp_method(other, op).to_numpy(np.bool_, na_value=False)
0 commit comments