16
16
time ,
17
17
timedelta ,
18
18
)
19
+ from decimal import Decimal
19
20
from io import (
20
21
BytesIO ,
21
22
StringIO ,
@@ -79,6 +80,14 @@ def data(dtype):
79
80
data = [1 , 0 ] * 4 + [None ] + [- 2 , - 1 ] * 44 + [None ] + [1 , 99 ]
80
81
elif pa .types .is_unsigned_integer (pa_dtype ):
81
82
data = [1 , 0 ] * 4 + [None ] + [2 , 1 ] * 44 + [None ] + [1 , 99 ]
83
+ elif pa .types .is_decimal (pa_dtype ):
84
+ data = (
85
+ [Decimal ("1" ), Decimal ("0.0" )] * 4
86
+ + [None ]
87
+ + [Decimal ("-2.0" ), Decimal ("-1.0" )] * 44
88
+ + [None ]
89
+ + [Decimal ("0.5" ), Decimal ("33.123" )]
90
+ )
82
91
elif pa .types .is_date (pa_dtype ):
83
92
data = (
84
93
[date (2022 , 1 , 1 ), date (1999 , 12 , 31 )] * 4
@@ -188,6 +197,10 @@ def data_for_grouping(dtype):
188
197
A = b"a"
189
198
B = b"b"
190
199
C = b"c"
200
+ elif pa .types .is_decimal (pa_dtype ):
201
+ A = Decimal ("-1.1" )
202
+ B = Decimal ("0.0" )
203
+ C = Decimal ("1.1" )
191
204
else :
192
205
raise NotImplementedError
193
206
return pd .array ([B , B , None , None , A , A , B , C ], dtype = dtype )
@@ -250,17 +263,20 @@ def test_astype_str(self, data, request):
250
263
class TestConstructors (base .BaseConstructorsTests ):
251
264
def test_from_dtype(self, data, request):
    """
    Run the base ``test_from_dtype``, marking string and decimal params xfail.

    For pyarrow string and decimal types an xfail marker is attached to the
    requesting test node before delegating to the base implementation; all
    other types go straight to the base test.
    """
    pa_dtype = data.dtype.pyarrow_dtype
    string_typed = pa.types.is_string(pa_dtype)
    if string_typed or pa.types.is_decimal(pa_dtype):
        # Choose the explanation that matches which of the two types we hit.
        reason = (
            "ArrowDtype(pa.string()) != StringDtype('pyarrow')"
            if string_typed
            else f"pyarrow.type_for_alias cannot infer { pa_dtype } "
        )
        expected_failure = pytest.mark.xfail(reason=reason)
        request.node.add_marker(expected_failure)
    super().test_from_dtype(data)
262
278
263
- def test_from_sequence_pa_array (self , data , request ):
279
+ def test_from_sequence_pa_array (self , data ):
264
280
# https://github.com/pandas-dev/pandas/pull/47034#discussion_r955500784
265
281
# data._data = pa.ChunkedArray
266
282
result = type (data )._from_sequence (data ._data )
@@ -285,7 +301,9 @@ def test_from_sequence_of_strings_pa_array(self, data, request):
285
301
reason = "Nanosecond time parsing not supported." ,
286
302
)
287
303
)
288
- elif pa_version_under11p0 and pa .types .is_duration (pa_dtype ):
304
+ elif pa_version_under11p0 and (
305
+ pa .types .is_duration (pa_dtype ) or pa .types .is_decimal (pa_dtype )
306
+ ):
289
307
request .node .add_marker (
290
308
pytest .mark .xfail (
291
309
raises = pa .ArrowNotImplementedError ,
@@ -392,7 +410,9 @@ def test_accumulate_series(self, data, all_numeric_accumulations, skipna, reques
392
410
raises = NotImplementedError ,
393
411
)
394
412
)
395
- elif all_numeric_accumulations == "cumsum" and (pa .types .is_boolean (pa_type )):
413
+ elif all_numeric_accumulations == "cumsum" and (
414
+ pa .types .is_boolean (pa_type ) or pa .types .is_decimal (pa_type )
415
+ ):
396
416
request .node .add_marker (
397
417
pytest .mark .xfail (
398
418
reason = f"{ all_numeric_accumulations } not implemented for { pa_type } " ,
@@ -476,6 +496,12 @@ def test_reduce_series(self, data, all_numeric_reductions, skipna, request):
476
496
)
477
497
if all_numeric_reductions in {"skew" , "kurt" }:
478
498
request .node .add_marker (xfail_mark )
499
+ elif (
500
+ all_numeric_reductions in {"var" , "std" , "median" }
501
+ and pa_version_under7p0
502
+ and pa .types .is_decimal (pa_dtype )
503
+ ):
504
+ request .node .add_marker (xfail_mark )
479
505
elif all_numeric_reductions == "sem" and pa_version_under8p0 :
480
506
request .node .add_marker (xfail_mark )
481
507
@@ -598,8 +624,26 @@ def test_in_numeric_groupby(self, data_for_grouping):
598
624
599
625
600
626
class TestBaseDtype (base .BaseDtypeTests ):
627
def test_check_dtype(self, data, request):
    """
    Run the base ``test_check_dtype``, xfailing decimals on older pyarrow.

    When the array's pyarrow type is a decimal and ``pa_version_under8p0``
    is set, the test node is marked as an expected ``ValueError`` failure
    before the base test is invoked.
    """
    arrow_type = data.dtype.pyarrow_dtype
    if pa.types.is_decimal(arrow_type) and pa_version_under8p0:
        expected_failure = pytest.mark.xfail(
            raises=ValueError,
            reason="decimal string repr affects numpy comparison",
        )
        request.node.add_marker(expected_failure)
    super().test_check_dtype(data)
637
+
601
638
def test_construct_from_string_own_name (self , dtype , request ):
602
639
pa_dtype = dtype .pyarrow_dtype
640
+ if pa .types .is_decimal (pa_dtype ):
641
+ request .node .add_marker (
642
+ pytest .mark .xfail (
643
+ raises = NotImplementedError ,
644
+ reason = f"pyarrow.type_for_alias cannot infer { pa_dtype } " ,
645
+ )
646
+ )
603
647
604
648
if pa .types .is_string (pa_dtype ):
605
649
# We still support StringDtype('pyarrow') over ArrowDtype(pa.string())
@@ -617,6 +661,13 @@ def test_is_dtype_from_name(self, dtype, request):
617
661
# We still support StringDtype('pyarrow') over ArrowDtype(pa.string())
618
662
assert not type (dtype ).is_dtype (dtype .name )
619
663
else :
664
+ if pa .types .is_decimal (pa_dtype ):
665
+ request .node .add_marker (
666
+ pytest .mark .xfail (
667
+ raises = NotImplementedError ,
668
+ reason = f"pyarrow.type_for_alias cannot infer { pa_dtype } " ,
669
+ )
670
+ )
620
671
super ().test_is_dtype_from_name (dtype )
621
672
622
673
def test_construct_from_string_another_type_raises (self , dtype ):
@@ -635,6 +686,7 @@ def test_get_common_dtype(self, dtype, request):
635
686
)
636
687
or (pa .types .is_duration (pa_dtype ) and pa_dtype .unit != "ns" )
637
688
or pa .types .is_binary (pa_dtype )
689
+ or pa .types .is_decimal (pa_dtype )
638
690
):
639
691
request .node .add_marker (
640
692
pytest .mark .xfail (
@@ -708,6 +760,13 @@ def test_EA_types(self, engine, data, request):
708
760
request .node .add_marker (
709
761
pytest .mark .xfail (raises = TypeError , reason = "GH 47534" )
710
762
)
763
+ elif pa .types .is_decimal (pa_dtype ):
764
+ request .node .add_marker (
765
+ pytest .mark .xfail (
766
+ raises = NotImplementedError ,
767
+ reason = f"Parameterized types { pa_dtype } not supported." ,
768
+ )
769
+ )
711
770
elif pa .types .is_timestamp (pa_dtype ) and pa_dtype .unit in ("us" , "ns" ):
712
771
request .node .add_marker (
713
772
pytest .mark .xfail (
@@ -790,6 +849,13 @@ def test_argmin_argmax(
790
849
reason = f"{ pa_dtype } only has 2 unique possible values" ,
791
850
)
792
851
)
852
+ elif pa .types .is_decimal (pa_dtype ) and pa_version_under7p0 :
853
+ request .node .add_marker (
854
+ pytest .mark .xfail (
855
+ reason = f"No pyarrow kernel for { pa_dtype } " ,
856
+ raises = pa .ArrowNotImplementedError ,
857
+ )
858
+ )
793
859
super ().test_argmin_argmax (data_for_sorting , data_missing_for_sorting , na_value )
794
860
795
861
@pytest .mark .parametrize (
@@ -808,6 +874,14 @@ def test_argmin_argmax(
808
874
def test_argreduce_series (
809
875
self , data_missing_for_sorting , op_name , skipna , expected , request
810
876
):
877
+ pa_dtype = data_missing_for_sorting .dtype .pyarrow_dtype
878
+ if pa .types .is_decimal (pa_dtype ) and pa_version_under7p0 and skipna :
879
+ request .node .add_marker (
880
+ pytest .mark .xfail (
881
+ reason = f"No pyarrow kernel for { pa_dtype } " ,
882
+ raises = pa .ArrowNotImplementedError ,
883
+ )
884
+ )
811
885
super ().test_argreduce_series (
812
886
data_missing_for_sorting , op_name , skipna , expected
813
887
)
@@ -906,6 +980,21 @@ def test_basic_equals(self, data):
906
980
class TestBaseArithmeticOps (base .BaseArithmeticOpsTests ):
907
981
divmod_exc = NotImplementedError
908
982
983
@classmethod
def assert_equal(cls, left, right, **kwargs):
    """
    Assert that ``left`` and ``right`` are equal, relaxing decimal results.

    If either operand has a pyarrow decimal type, both are cast to
    ``"float[pyarrow]"`` before being handed to ``tm.assert_equal``; any
    extra keyword arguments are forwarded unchanged.
    """
    # Whether to probe the first column is decided by ``left`` alone and
    # applied to both operands.
    frame_input = isinstance(left, pd.DataFrame)
    arrow_types = [
        (operand.iloc[:, 0] if frame_input else operand).dtype.pyarrow_dtype
        for operand in (left, right)
    ]
    if any(pa.types.is_decimal(arrow_type) for arrow_type in arrow_types):
        # The decimal precision of a result type may be resized depending
        # on the data, so only the float values are compared.
        left, right = (
            left.astype("float[pyarrow]"),
            right.astype("float[pyarrow]"),
        )
    tm.assert_equal(left, right, **kwargs)
997
+
909
998
def get_op_from_name (self , op_name ):
910
999
short_opname = op_name .strip ("_" )
911
1000
if short_opname == "rtruediv" :
@@ -975,7 +1064,11 @@ def _get_scalar_exception(self, opname, pa_dtype):
975
1064
pa .types .is_string (pa_dtype ) or pa .types .is_binary (pa_dtype )
976
1065
):
977
1066
exc = None
978
- elif not (pa .types .is_floating (pa_dtype ) or pa .types .is_integer (pa_dtype )):
1067
+ elif not (
1068
+ pa .types .is_floating (pa_dtype )
1069
+ or pa .types .is_integer (pa_dtype )
1070
+ or pa .types .is_decimal (pa_dtype )
1071
+ ):
979
1072
exc = pa .ArrowNotImplementedError
980
1073
else :
981
1074
exc = None
@@ -988,7 +1081,11 @@ def _get_arith_xfail_marker(self, opname, pa_dtype):
988
1081
989
1082
if (
990
1083
opname == "__rpow__"
991
- and (pa .types .is_floating (pa_dtype ) or pa .types .is_integer (pa_dtype ))
1084
+ and (
1085
+ pa .types .is_floating (pa_dtype )
1086
+ or pa .types .is_integer (pa_dtype )
1087
+ or pa .types .is_decimal (pa_dtype )
1088
+ )
992
1089
and not pa_version_under7p0
993
1090
):
994
1091
mark = pytest .mark .xfail (
@@ -1006,14 +1103,32 @@ def _get_arith_xfail_marker(self, opname, pa_dtype):
1006
1103
),
1007
1104
)
1008
1105
elif (
1009
- opname in { "__rfloordiv__" }
1010
- and pa .types .is_integer (pa_dtype )
1106
+ opname == "__rfloordiv__"
1107
+ and ( pa .types .is_integer (pa_dtype ) or pa . types . is_decimal ( pa_dtype ) )
1011
1108
and not pa_version_under7p0
1012
1109
):
1013
1110
mark = pytest .mark .xfail (
1014
1111
raises = pa .ArrowInvalid ,
1015
1112
reason = "divide by 0" ,
1016
1113
)
1114
+ elif (
1115
+ opname == "__rtruediv__"
1116
+ and pa .types .is_decimal (pa_dtype )
1117
+ and not pa_version_under7p0
1118
+ ):
1119
+ mark = pytest .mark .xfail (
1120
+ raises = pa .ArrowInvalid ,
1121
+ reason = "divide by 0" ,
1122
+ )
1123
+ elif (
1124
+ opname == "__pow__"
1125
+ and pa .types .is_decimal (pa_dtype )
1126
+ and pa_version_under7p0
1127
+ ):
1128
+ mark = pytest .mark .xfail (
1129
+ raises = pa .ArrowInvalid ,
1130
+ reason = "Invalid decimal function: power_checked" ,
1131
+ )
1017
1132
1018
1133
return mark
1019
1134
@@ -1231,6 +1346,9 @@ def test_arrowdtype_construct_from_string_type_with_unsupported_parameters():
1231
1346
expected = ArrowDtype (pa .timestamp ("s" , "UTC" ))
1232
1347
assert dtype == expected
1233
1348
1349
+ with pytest .raises (NotImplementedError , match = "Passing pyarrow type" ):
1350
+ ArrowDtype .construct_from_string ("decimal(7, 2)[pyarrow]" )
1351
+
1234
1352
1235
1353
@pytest .mark .parametrize (
1236
1354
"interpolation" , ["linear" , "lower" , "higher" , "nearest" , "midpoint" ]
@@ -1257,7 +1375,11 @@ def test_quantile(data, interpolation, quantile, request):
1257
1375
ser .quantile (q = quantile , interpolation = interpolation )
1258
1376
return
1259
1377
1260
- if pa .types .is_integer (pa_dtype ) or pa .types .is_floating (pa_dtype ):
1378
+ if (
1379
+ pa .types .is_integer (pa_dtype )
1380
+ or pa .types .is_floating (pa_dtype )
1381
+ or (pa .types .is_decimal (pa_dtype ) and not pa_version_under7p0 )
1382
+ ):
1261
1383
pass
1262
1384
elif pa .types .is_temporal (data ._data .type ):
1263
1385
pass
@@ -1298,7 +1420,11 @@ def test_quantile(data, interpolation, quantile, request):
1298
1420
else :
1299
1421
# Just check the values
1300
1422
expected = pd .Series (data .take ([0 , 0 ]), index = [0.5 , 0.5 ])
1301
- if pa .types .is_integer (pa_dtype ) or pa .types .is_floating (pa_dtype ):
1423
+ if (
1424
+ pa .types .is_integer (pa_dtype )
1425
+ or pa .types .is_floating (pa_dtype )
1426
+ or pa .types .is_decimal (pa_dtype )
1427
+ ):
1302
1428
expected = expected .astype ("float64[pyarrow]" )
1303
1429
result = result .astype ("float64[pyarrow]" )
1304
1430
tm .assert_series_equal (result , expected )
0 commit comments