Skip to content

Commit 10aab3c

Browse files
committed
make construct_array_type arg optional
1 parent 1e8061c commit 10aab3c

File tree

8 files changed

+28
-12
lines changed

8 files changed

+28
-12
lines changed

pandas/core/dtypes/base.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -162,18 +162,20 @@ def name(self):
162162
raise AbstractMethodError(self)
163163

164164
@classmethod
165-
def construct_array_type(cls, array):
165+
def construct_array_type(cls, array=None):
166166
"""Return the array type associated with this dtype
167167
168168
Parameters
169169
----------
170-
string : str
170+
array : array-like, optional
171171
172172
Returns
173173
-------
174174
type
175175
"""
176-
return type(array)
176+
if array is None:
177+
return cls
178+
raise NotImplementedError
177179

178180
@classmethod
179181
def construct_from_string(cls, string):

pandas/core/dtypes/dtypes.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -318,12 +318,12 @@ def _hash_categories(categories, ordered=True):
318318
return np.bitwise_xor.reduce(hashed)
319319

320320
@classmethod
321-
def construct_array_type(cls, array):
321+
def construct_array_type(cls, array=None):
322322
"""Return the array type associated with this dtype
323323
324324
Parameters
325325
----------
326-
array : value array
326+
array : array-like, optional
327327
328328
Returns
329329
-------

pandas/tests/extension/base/dtype.py

+8
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import pytest
12
import numpy as np
23
import pandas as pd
34

@@ -46,3 +47,10 @@ def test_eq_with_str(self, dtype):
4647

4748
def test_eq_with_numpy_object(self, dtype):
4849
assert dtype != np.dtype('object')
50+
51+
def test_array_type(self, data, dtype):
52+
assert dtype.construct_array_type() is type(data)
53+
54+
def test_array_type_with_arg(self, data, dtype):
55+
with pytest.raises(NotImplementedError):
56+
dtype.construct_array_type('foo')

pandas/tests/extension/category/test_categorical.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,9 @@ def data_for_grouping():
5252

5353

5454
class TestDtype(base.BaseDtypeTests):
55-
pass
55+
56+
def test_array_type_with_arg(self, data, dtype):
57+
assert dtype.construct_array_type('foo') is Categorical
5658

5759

5860
class TestInterface(base.BaseInterfaceTests):

pandas/tests/extension/decimal/array.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,12 @@ class DecimalDtype(ExtensionDtype):
1616
na_value = decimal.Decimal('NaN')
1717

1818
@classmethod
19-
def construct_array_type(cls, array):
19+
def construct_array_type(cls, array=None):
2020
"""Return the array type associated with this dtype
2121
2222
Parameters
2323
----------
24-
string : str
24+
array : array-like, optional
2525
2626
Returns
2727
-------

pandas/tests/extension/decimal/test_decimal.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,9 @@ def assert_frame_equal(self, left, right, *args, **kwargs):
9292

9393

9494
class TestDtype(BaseDecimal, base.BaseDtypeTests):
95-
pass
95+
96+
def test_array_type_with_arg(self, data, dtype):
97+
assert dtype.construct_array_type('foo') is DecimalArray
9698

9799

98100
class TestInterface(BaseDecimal, base.BaseInterfaceTests):

pandas/tests/extension/json/array.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,12 @@ class JSONDtype(ExtensionDtype):
3333
na_value = {}
3434

3535
@classmethod
36-
def construct_array_type(cls, array):
36+
def construct_array_type(cls, array=None):
3737
"""Return the array type associated with this dtype
3838
3939
Parameters
4040
----------
41-
string : str
41+
array : array-like, optional
4242
4343
Returns
4444
-------

pandas/tests/extension/json/test_json.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,9 @@ def assert_frame_equal(self, left, right, *args, **kwargs):
107107

108108

109109
class TestDtype(BaseJSON, base.BaseDtypeTests):
110-
pass
110+
111+
def test_array_type_with_arg(self, data, dtype):
112+
assert dtype.construct_array_type('foo') is JSONArray
111113

112114

113115
class TestInterface(BaseJSON, base.BaseInterfaceTests):

0 commit comments

Comments
 (0)