@@ -330,17 +330,47 @@ def _str_contains(
330
330
result [isna (result )] = bool (na )
331
331
return result
332
332
333
- def _str_startswith (self , pat : str , na = None ):
334
- result = pc .starts_with (self ._pa_array , pattern = pat )
333
+ def _str_startswith (self , pat : str | tuple [str , ...], na : Scalar | None = None ):
334
+ if isinstance (pat , str ):
335
+ result = pc .starts_with (self ._pa_array , pattern = pat )
336
+ elif isinstance (pat , tuple ) and all (isinstance (x , str ) for x in pat ):
337
+ if len (pat ) == 0 :
338
+ # mimic existing behaviour of string extension array
339
+ # and python string method
340
+ result = pa .array (
341
+ np .full (len (self ._pa_array ), False ), mask = isna (self ._pa_array )
342
+ )
343
+ else :
344
+ result = pc .starts_with (self ._pa_array , pattern = pat [0 ])
345
+
346
+ for p in pat [1 :]:
347
+ result = pc .or_ (result , pc .starts_with (self ._pa_array , pattern = p ))
348
+ else :
349
+ raise TypeError ("pat must be str or tuple[str, ...]" )
335
350
if not isna (na ):
336
351
result = result .fill_null (na )
337
352
result = self ._result_converter (result )
338
353
if not isna (na ):
339
354
result [isna (result )] = bool (na )
340
355
return result
341
356
342
- def _str_endswith (self , pat : str , na = None ):
343
- result = pc .ends_with (self ._pa_array , pattern = pat )
357
+ def _str_endswith (self , pat : str | tuple [str , ...], na : Scalar | None = None ):
358
+ if isinstance (pat , str ):
359
+ result = pc .ends_with (self ._pa_array , pattern = pat )
360
+ elif isinstance (pat , tuple ) and all (isinstance (x , str ) for x in pat ):
361
+ if len (pat ) == 0 :
362
+ # mimic existing behaviour of string extension array
363
+ # and python string method
364
+ result = pa .array (
365
+ np .full (len (self ._pa_array ), False ), mask = isna (self ._pa_array )
366
+ )
367
+ else :
368
+ result = pc .ends_with (self ._pa_array , pattern = pat [0 ])
369
+
370
+ for p in pat [1 :]:
371
+ result = pc .or_ (result , pc .ends_with (self ._pa_array , pattern = p ))
372
+ else :
373
+ raise TypeError ("pat must be of type str or tuple[str, ...]" )
344
374
if not isna (na ):
345
375
result = result .fill_null (na )
346
376
result = self ._result_converter (result )
0 commit comments