Skip to content

Commit 7ecfa8e

Browse files
jorisvandenbosschejreback
authored andcommitted
TST: test custom _formatter for ExtensionArray + revert ExtensionArrayFormatter removal (#26845)
* TST: test custom _formatter for ExtensionArray * Revert "REF: remove ExtensionArrayFormatter (#26833)" This reverts commit a00659a.
1 parent 137a886 commit 7ecfa8e

File tree

4 files changed

+43
-16
lines changed

4 files changed

+43
-16
lines changed

pandas/core/indexes/interval.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -1061,9 +1061,11 @@ def _format_with_header(self, header, **kwargs):
10611061

10621062
def _format_native_types(self, na_rep='NaN', quoting=None, **kwargs):
10631063
""" actually format my specific types """
1064-
from pandas.io.formats.format import format_array
1065-
return format_array(values=self, na_rep=na_rep, justify='all',
1066-
leading_space=False)
1064+
from pandas.io.formats.format import ExtensionArrayFormatter
1065+
return ExtensionArrayFormatter(values=self,
1066+
na_rep=na_rep,
1067+
justify='all',
1068+
leading_space=False).get_result()
10671069

10681070
def _format_data(self, name=None):
10691071

pandas/io/formats/format.py

+26-12
Original file line numberDiff line numberDiff line change
@@ -849,7 +849,7 @@ def _get_column_name_list(self):
849849
# Array formatters
850850

851851

852-
def format_array(values, formatter=None, float_format=None, na_rep='NaN',
852+
def format_array(values, formatter, float_format=None, na_rep='NaN',
853853
digits=None, space=None, justify='right', decimal='.',
854854
leading_space=None):
855855
"""
@@ -879,23 +879,14 @@ def format_array(values, formatter=None, float_format=None, na_rep='NaN',
879879
List[str]
880880
"""
881881

882-
if is_extension_array_dtype(values.dtype):
883-
if isinstance(values, (ABCIndexClass, ABCSeries)):
884-
values = values._values
885-
886-
if is_categorical_dtype(values.dtype):
887-
# Categorical is special for now, so that we can preserve tzinfo
888-
values = values.get_values()
889-
890-
if not is_datetime64tz_dtype(values.dtype):
891-
values = np.asarray(values)
892-
893882
if is_datetime64_dtype(values.dtype):
894883
fmt_klass = Datetime64Formatter
895884
elif is_datetime64tz_dtype(values):
896885
fmt_klass = Datetime64TZFormatter
897886
elif is_timedelta64_dtype(values.dtype):
898887
fmt_klass = Timedelta64Formatter
888+
elif is_extension_array_dtype(values.dtype):
889+
fmt_klass = ExtensionArrayFormatter
899890
elif is_float_dtype(values.dtype) or is_complex_dtype(values.dtype):
900891
fmt_klass = FloatArrayFormatter
901892
elif is_integer_dtype(values.dtype):
@@ -1190,6 +1181,29 @@ def _format_strings(self):
11901181
return fmt_values.tolist()
11911182

11921183

1184+
class ExtensionArrayFormatter(GenericArrayFormatter):
1185+
def _format_strings(self):
1186+
values = self.values
1187+
if isinstance(values, (ABCIndexClass, ABCSeries)):
1188+
values = values._values
1189+
1190+
formatter = values._formatter(boxed=True)
1191+
1192+
if is_categorical_dtype(values.dtype):
1193+
# Categorical is special for now, so that we can preserve tzinfo
1194+
array = values.get_values()
1195+
else:
1196+
array = np.asarray(values)
1197+
1198+
fmt_values = format_array(array,
1199+
formatter,
1200+
float_format=self.float_format,
1201+
na_rep=self.na_rep, digits=self.digits,
1202+
space=self.space, justify=self.justify,
1203+
leading_space=self.leading_space)
1204+
return fmt_values
1205+
1206+
11931207
def format_percentiles(percentiles):
11941208
"""
11951209
Outputs rounded and formatted percentiles.

pandas/tests/extension/decimal/array.py

+5
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,11 @@ def isna(self):
137137
def _na_value(self):
138138
return decimal.Decimal('NaN')
139139

140+
def _formatter(self, boxed=False):
141+
if boxed:
142+
return "Decimal: {0}".format
143+
return repr
144+
140145
@classmethod
141146
def _concat_same_type(cls, to_concat):
142147
return cls(np.concatenate([x._data for x in to_concat]))

pandas/tests/extension/decimal/test_decimal.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,13 @@ class TestSetitem(BaseDecimal, base.BaseSetitemTests):
200200

201201

202202
class TestPrinting(BaseDecimal, base.BasePrintingTests):
203-
pass
203+
204+
def test_series_repr(self, data):
205+
# Overriding this base test to explicitly test that
206+
# the custom _formatter is used
207+
ser = pd.Series(data)
208+
assert data.dtype.name in repr(ser)
209+
assert "Decimal: " in repr(ser)
204210

205211

206212
# TODO(extension)

0 commit comments

Comments
 (0)