@@ -458,7 +458,7 @@ def maybe_cast_pointwise_result(
458
458
"""
459
459
460
460
if isinstance (dtype , ExtensionDtype ):
461
- if not isinstance (dtype , (CategoricalDtype , DatetimeTZDtype )):
461
+ if not isinstance (dtype , (CategoricalDtype , DatetimeTZDtype , ArrowDtype )):
462
462
# TODO: avoid this special-casing
463
463
# We have to special case categorical so as not to upcast
464
464
# things like counts back to categorical
@@ -468,7 +468,17 @@ def maybe_cast_pointwise_result(
468
468
result = _maybe_cast_to_extension_array (cls , result , dtype = dtype )
469
469
else :
470
470
result = _maybe_cast_to_extension_array (cls , result )
471
-
471
+ elif isinstance (dtype , ArrowDtype ):
472
+ pyarrow_type = convert_dtypes (result , dtype_backend = "pyarrow" )
473
+ if isinstance (pyarrow_type , ExtensionDtype ):
474
+ cls = pyarrow_type .construct_array_type ()
475
+ result = _maybe_cast_to_extension_array (cls , result )
476
+ else :
477
+ cls = dtype .construct_array_type ()
478
+ if same_dtype :
479
+ result = _maybe_cast_to_extension_array (cls , result , dtype = dtype )
480
+ else :
481
+ result = _maybe_cast_to_extension_array (cls , result )
472
482
elif (numeric_only and dtype .kind in "iufcb" ) or not numeric_only :
473
483
result = maybe_downcast_to_dtype (result , dtype )
474
484
0 commit comments