@@ -15,6 +15,17 @@ class DecimalDtype(ExtensionDtype):
15
15
name = 'decimal'
16
16
na_value = decimal .Decimal ('NaN' )
17
17
18
+ def __init__ (self , context = None ):
19
+ self .context = context or decimal .getcontext ()
20
+
21
+ def __eq__ (self , other ):
22
+ if isinstance (other , type (self )):
23
+ return self .context == other .context
24
+ return super (DecimalDtype , self ).__eq__ (other )
25
+
26
+ def __repr__ (self ):
27
+ return 'DecimalDtype(context={})' .format (self .context )
28
+
18
29
@classmethod
19
30
def construct_array_type (cls ):
20
31
"""Return the array type associated with this dtype
@@ -35,13 +46,12 @@ def construct_from_string(cls, string):
35
46
36
47
37
48
class DecimalArray (ExtensionArray , ExtensionScalarOpsMixin ):
38
- dtype = DecimalDtype ()
39
49
40
- def __init__ (self , values , dtype = None , copy = False ):
50
+ def __init__ (self , values , dtype = None , copy = False , context = None ):
41
51
for val in values :
42
- if not isinstance (val , self . dtype . type ):
52
+ if not isinstance (val , decimal . Decimal ):
43
53
raise TypeError ("All values must be of type " +
44
- str (self . dtype . type ))
54
+ str (decimal . Decimal ))
45
55
values = np .asarray (values , dtype = object )
46
56
47
57
self ._data = values
@@ -51,6 +61,11 @@ def __init__(self, values, dtype=None, copy=False):
51
61
# those aliases are currently not working due to assumptions
52
62
# in internal code (GH-20735)
53
63
# self._values = self.values = self.data
64
+ self ._dtype = DecimalDtype (context )
65
+
66
+ @property
67
+ def dtype (self ):
68
+ return self ._dtype
54
69
55
70
@classmethod
56
71
def _from_sequence (cls , scalars , dtype = None , copy = False ):
@@ -82,6 +97,11 @@ def copy(self, deep=False):
82
97
return type (self )(self ._data .copy ())
83
98
return type (self )(self )
84
99
100
+ def astype (self , dtype , copy = True ):
101
+ if isinstance (dtype , type (self .dtype )):
102
+ return type (self )(self ._data , context = dtype .context )
103
+ return super (DecimalArray , self ).astype (dtype , copy )
104
+
85
105
def __setitem__ (self , key , value ):
86
106
if pd .api .types .is_list_like (value ):
87
107
value = [decimal .Decimal (v ) for v in value ]
0 commit comments