16
16
time ,
17
17
timedelta ,
18
18
)
19
+ from decimal import Decimal
19
20
from io import (
20
21
BytesIO ,
21
22
StringIO ,
@@ -79,6 +80,14 @@ def data(dtype):
79
80
data = [1 , 0 ] * 4 + [None ] + [- 2 , - 1 ] * 44 + [None ] + [1 , 99 ]
80
81
elif pa .types .is_unsigned_integer (pa_dtype ):
81
82
data = [1 , 0 ] * 4 + [None ] + [2 , 1 ] * 44 + [None ] + [1 , 99 ]
83
+ elif pa .types .is_decimal (pa_dtype ):
84
+ data = (
85
+ [Decimal ("1" ), Decimal ("0.0" )] * 4
86
+ + [None ]
87
+ + [Decimal ("-2.0" ), Decimal ("-1.0" )] * 44
88
+ + [None ]
89
+ + [Decimal ("0.5" ), Decimal ("33.123" )]
90
+ )
82
91
elif pa .types .is_date (pa_dtype ):
83
92
data = (
84
93
[date (2022 , 1 , 1 ), date (1999 , 12 , 31 )] * 4
@@ -188,6 +197,10 @@ def data_for_grouping(dtype):
188
197
A = b"a"
189
198
B = b"b"
190
199
C = b"c"
200
+ elif pa .types .is_decimal (pa_dtype ):
201
+ A = Decimal ("-1.1" )
202
+ B = Decimal ("0.0" )
203
+ C = Decimal ("1.1" )
191
204
else :
192
205
raise NotImplementedError
193
206
return pd .array ([B , B , None , None , A , A , B , C ], dtype = dtype )
@@ -250,17 +263,20 @@ def test_astype_str(self, data, request):
250
263
class TestConstructors (base .BaseConstructorsTests ):
251
264
def test_from_dtype (self , data , request ):
252
265
pa_dtype = data .dtype .pyarrow_dtype
266
+ if pa .types .is_string (pa_dtype ) or pa .types .is_decimal (pa_dtype ):
267
+ if pa .types .is_string (pa_dtype ):
268
+ reason = "ArrowDtype(pa.string()) != StringDtype('pyarrow')"
269
+ else :
270
+ reason = f"pyarrow.type_for_alias cannot infer { pa_dtype } "
253
271
254
- if pa .types .is_string (pa_dtype ):
255
- reason = "ArrowDtype(pa.string()) != StringDtype('pyarrow')"
256
272
request .node .add_marker (
257
273
pytest .mark .xfail (
258
274
reason = reason ,
259
275
)
260
276
)
261
277
super ().test_from_dtype (data )
262
278
263
- def test_from_sequence_pa_array (self , data , request ):
279
+ def test_from_sequence_pa_array (self , data ):
264
280
# https://github.com/pandas-dev/pandas/pull/47034#discussion_r955500784
265
281
# data._data = pa.ChunkedArray
266
282
result = type (data )._from_sequence (data ._data )
@@ -285,7 +301,9 @@ def test_from_sequence_of_strings_pa_array(self, data, request):
285
301
reason = "Nanosecond time parsing not supported." ,
286
302
)
287
303
)
288
- elif pa_version_under11p0 and pa .types .is_duration (pa_dtype ):
304
+ elif pa_version_under11p0 and (
305
+ pa .types .is_duration (pa_dtype ) or pa .types .is_decimal (pa_dtype )
306
+ ):
289
307
request .node .add_marker (
290
308
pytest .mark .xfail (
291
309
raises = pa .ArrowNotImplementedError ,
@@ -384,7 +402,9 @@ def test_accumulate_series(self, data, all_numeric_accumulations, skipna, reques
384
402
# renders the exception messages even when not showing them
385
403
pytest .skip (f"{ all_numeric_accumulations } not implemented for pyarrow < 9" )
386
404
387
- elif all_numeric_accumulations == "cumsum" and pa .types .is_boolean (pa_type ):
405
+ elif all_numeric_accumulations == "cumsum" and (
406
+ pa .types .is_boolean (pa_type ) or pa .types .is_decimal (pa_type )
407
+ ):
388
408
request .node .add_marker (
389
409
pytest .mark .xfail (
390
410
reason = f"{ all_numeric_accumulations } not implemented for { pa_type } " ,
@@ -468,6 +488,12 @@ def test_reduce_series(self, data, all_numeric_reductions, skipna, request):
468
488
)
469
489
if all_numeric_reductions in {"skew" , "kurt" }:
470
490
request .node .add_marker (xfail_mark )
491
+ elif (
492
+ all_numeric_reductions in {"var" , "std" , "median" }
493
+ and pa_version_under7p0
494
+ and pa .types .is_decimal (pa_dtype )
495
+ ):
496
+ request .node .add_marker (xfail_mark )
471
497
elif all_numeric_reductions == "sem" and pa_version_under8p0 :
472
498
request .node .add_marker (xfail_mark )
473
499
@@ -590,8 +616,26 @@ def test_in_numeric_groupby(self, data_for_grouping):
590
616
591
617
592
618
class TestBaseDtype (base .BaseDtypeTests ):
619
+ def test_check_dtype (self , data , request ):
620
+ pa_dtype = data .dtype .pyarrow_dtype
621
+ if pa .types .is_decimal (pa_dtype ) and pa_version_under8p0 :
622
+ request .node .add_marker (
623
+ pytest .mark .xfail (
624
+ raises = ValueError ,
625
+ reason = "decimal string repr affects numpy comparison" ,
626
+ )
627
+ )
628
+ super ().test_check_dtype (data )
629
+
593
630
def test_construct_from_string_own_name (self , dtype , request ):
594
631
pa_dtype = dtype .pyarrow_dtype
632
+ if pa .types .is_decimal (pa_dtype ):
633
+ request .node .add_marker (
634
+ pytest .mark .xfail (
635
+ raises = NotImplementedError ,
636
+ reason = f"pyarrow.type_for_alias cannot infer { pa_dtype } " ,
637
+ )
638
+ )
595
639
596
640
if pa .types .is_string (pa_dtype ):
597
641
# We still support StringDtype('pyarrow') over ArrowDtype(pa.string())
@@ -609,6 +653,13 @@ def test_is_dtype_from_name(self, dtype, request):
609
653
# We still support StringDtype('pyarrow') over ArrowDtype(pa.string())
610
654
assert not type (dtype ).is_dtype (dtype .name )
611
655
else :
656
+ if pa .types .is_decimal (pa_dtype ):
657
+ request .node .add_marker (
658
+ pytest .mark .xfail (
659
+ raises = NotImplementedError ,
660
+ reason = f"pyarrow.type_for_alias cannot infer { pa_dtype } " ,
661
+ )
662
+ )
612
663
super ().test_is_dtype_from_name (dtype )
613
664
614
665
def test_construct_from_string_another_type_raises (self , dtype ):
@@ -627,6 +678,7 @@ def test_get_common_dtype(self, dtype, request):
627
678
)
628
679
or (pa .types .is_duration (pa_dtype ) and pa_dtype .unit != "ns" )
629
680
or pa .types .is_binary (pa_dtype )
681
+ or pa .types .is_decimal (pa_dtype )
630
682
):
631
683
request .node .add_marker (
632
684
pytest .mark .xfail (
@@ -700,6 +752,13 @@ def test_EA_types(self, engine, data, request):
700
752
request .node .add_marker (
701
753
pytest .mark .xfail (raises = TypeError , reason = "GH 47534" )
702
754
)
755
+ elif pa .types .is_decimal (pa_dtype ):
756
+ request .node .add_marker (
757
+ pytest .mark .xfail (
758
+ raises = NotImplementedError ,
759
+ reason = f"Parameterized types { pa_dtype } not supported." ,
760
+ )
761
+ )
703
762
elif pa .types .is_timestamp (pa_dtype ) and pa_dtype .unit in ("us" , "ns" ):
704
763
request .node .add_marker (
705
764
pytest .mark .xfail (
@@ -782,6 +841,13 @@ def test_argmin_argmax(
782
841
reason = f"{ pa_dtype } only has 2 unique possible values" ,
783
842
)
784
843
)
844
+ elif pa .types .is_decimal (pa_dtype ) and pa_version_under7p0 :
845
+ request .node .add_marker (
846
+ pytest .mark .xfail (
847
+ reason = f"No pyarrow kernel for { pa_dtype } " ,
848
+ raises = pa .ArrowNotImplementedError ,
849
+ )
850
+ )
785
851
super ().test_argmin_argmax (data_for_sorting , data_missing_for_sorting , na_value )
786
852
787
853
@pytest .mark .parametrize (
@@ -800,6 +866,14 @@ def test_argmin_argmax(
800
866
def test_argreduce_series (
801
867
self , data_missing_for_sorting , op_name , skipna , expected , request
802
868
):
869
+ pa_dtype = data_missing_for_sorting .dtype .pyarrow_dtype
870
+ if pa .types .is_decimal (pa_dtype ) and pa_version_under7p0 and skipna :
871
+ request .node .add_marker (
872
+ pytest .mark .xfail (
873
+ reason = f"No pyarrow kernel for { pa_dtype } " ,
874
+ raises = pa .ArrowNotImplementedError ,
875
+ )
876
+ )
803
877
super ().test_argreduce_series (
804
878
data_missing_for_sorting , op_name , skipna , expected
805
879
)
@@ -888,6 +962,21 @@ def test_basic_equals(self, data):
888
962
class TestBaseArithmeticOps (base .BaseArithmeticOpsTests ):
889
963
divmod_exc = NotImplementedError
890
964
965
+ @classmethod
966
+ def assert_equal (cls , left , right , ** kwargs ):
967
+ if isinstance (left , pd .DataFrame ):
968
+ left_pa_type = left .iloc [:, 0 ].dtype .pyarrow_dtype
969
+ right_pa_type = right .iloc [:, 0 ].dtype .pyarrow_dtype
970
+ else :
971
+ left_pa_type = left .dtype .pyarrow_dtype
972
+ right_pa_type = right .dtype .pyarrow_dtype
973
+ if pa .types .is_decimal (left_pa_type ) or pa .types .is_decimal (right_pa_type ):
974
+ # decimal precision can resize in the result type depending on data
975
+ # just compare the float values
976
+ left = left .astype ("float[pyarrow]" )
977
+ right = right .astype ("float[pyarrow]" )
978
+ tm .assert_equal (left , right , ** kwargs )
979
+
891
980
def get_op_from_name (self , op_name ):
892
981
short_opname = op_name .strip ("_" )
893
982
if short_opname == "rtruediv" :
@@ -967,7 +1056,11 @@ def _get_scalar_exception(self, opname, pa_dtype):
967
1056
pa .types .is_string (pa_dtype ) or pa .types .is_binary (pa_dtype )
968
1057
):
969
1058
exc = None
970
- elif not (pa .types .is_floating (pa_dtype ) or pa .types .is_integer (pa_dtype )):
1059
+ elif not (
1060
+ pa .types .is_floating (pa_dtype )
1061
+ or pa .types .is_integer (pa_dtype )
1062
+ or pa .types .is_decimal (pa_dtype )
1063
+ ):
971
1064
exc = pa .ArrowNotImplementedError
972
1065
else :
973
1066
exc = None
@@ -980,7 +1073,11 @@ def _get_arith_xfail_marker(self, opname, pa_dtype):
980
1073
981
1074
if (
982
1075
opname == "__rpow__"
983
- and (pa .types .is_floating (pa_dtype ) or pa .types .is_integer (pa_dtype ))
1076
+ and (
1077
+ pa .types .is_floating (pa_dtype )
1078
+ or pa .types .is_integer (pa_dtype )
1079
+ or pa .types .is_decimal (pa_dtype )
1080
+ )
984
1081
and not pa_version_under7p0
985
1082
):
986
1083
mark = pytest .mark .xfail (
@@ -998,14 +1095,32 @@ def _get_arith_xfail_marker(self, opname, pa_dtype):
998
1095
),
999
1096
)
1000
1097
elif (
1001
- opname in { "__rfloordiv__" }
1002
- and pa .types .is_integer (pa_dtype )
1098
+ opname == "__rfloordiv__"
1099
+ and ( pa .types .is_integer (pa_dtype ) or pa . types . is_decimal ( pa_dtype ) )
1003
1100
and not pa_version_under7p0
1004
1101
):
1005
1102
mark = pytest .mark .xfail (
1006
1103
raises = pa .ArrowInvalid ,
1007
1104
reason = "divide by 0" ,
1008
1105
)
1106
+ elif (
1107
+ opname == "__rtruediv__"
1108
+ and pa .types .is_decimal (pa_dtype )
1109
+ and not pa_version_under7p0
1110
+ ):
1111
+ mark = pytest .mark .xfail (
1112
+ raises = pa .ArrowInvalid ,
1113
+ reason = "divide by 0" ,
1114
+ )
1115
+ elif (
1116
+ opname == "__pow__"
1117
+ and pa .types .is_decimal (pa_dtype )
1118
+ and pa_version_under7p0
1119
+ ):
1120
+ mark = pytest .mark .xfail (
1121
+ raises = pa .ArrowInvalid ,
1122
+ reason = "Invalid decimal function: power_checked" ,
1123
+ )
1009
1124
1010
1125
return mark
1011
1126
@@ -1226,6 +1341,9 @@ def test_arrowdtype_construct_from_string_type_with_unsupported_parameters():
1226
1341
expected = ArrowDtype (pa .timestamp ("s" , "UTC" ))
1227
1342
assert dtype == expected
1228
1343
1344
+ with pytest .raises (NotImplementedError , match = "Passing pyarrow type" ):
1345
+ ArrowDtype .construct_from_string ("decimal(7, 2)[pyarrow]" )
1346
+
1229
1347
1230
1348
@pytest .mark .parametrize (
1231
1349
"interpolation" , ["linear" , "lower" , "higher" , "nearest" , "midpoint" ]
@@ -1252,7 +1370,11 @@ def test_quantile(data, interpolation, quantile, request):
1252
1370
ser .quantile (q = quantile , interpolation = interpolation )
1253
1371
return
1254
1372
1255
- if pa .types .is_integer (pa_dtype ) or pa .types .is_floating (pa_dtype ):
1373
+ if (
1374
+ pa .types .is_integer (pa_dtype )
1375
+ or pa .types .is_floating (pa_dtype )
1376
+ or (pa .types .is_decimal (pa_dtype ) and not pa_version_under7p0 )
1377
+ ):
1256
1378
pass
1257
1379
elif pa .types .is_temporal (data ._data .type ):
1258
1380
pass
@@ -1293,7 +1415,11 @@ def test_quantile(data, interpolation, quantile, request):
1293
1415
else :
1294
1416
# Just check the values
1295
1417
expected = pd .Series (data .take ([0 , 0 ]), index = [0.5 , 0.5 ])
1296
- if pa .types .is_integer (pa_dtype ) or pa .types .is_floating (pa_dtype ):
1418
+ if (
1419
+ pa .types .is_integer (pa_dtype )
1420
+ or pa .types .is_floating (pa_dtype )
1421
+ or pa .types .is_decimal (pa_dtype )
1422
+ ):
1297
1423
expected = expected .astype ("float64[pyarrow]" )
1298
1424
result = result .astype ("float64[pyarrow]" )
1299
1425
tm .assert_series_equal (result , expected )
0 commit comments