Skip to content

Commit e5e597a

Browse files
stuartebergproost
authored andcommitted
COMPAT: unique() should preserve the dtype of the input (pandas-dev#27874)
1 parent db0976b commit e5e597a

File tree

3 files changed

+28
-8
lines changed

3 files changed

+28
-8
lines changed

doc/source/whatsnew/v1.0.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ Other API changes
151151

152152
- :meth:`pandas.api.types.infer_dtype` will now return "integer-na" for integer and ``np.nan`` mix (:issue:`27283`)
153153
- :meth:`MultiIndex.from_arrays` will no longer infer names from arrays if ``names=None`` is explicitly provided (:issue:`27292`)
154+
- The returned dtype of ::func:`pd.unique` now matches the input dtype. (:issue:`27874`)
154155
-
155156

156157
.. _whatsnew_1000.api.documentation:

pandas/core/algorithms.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -180,13 +180,13 @@ def _reconstruct_data(values, dtype, original):
180180
if is_extension_array_dtype(dtype):
181181
values = dtype.construct_array_type()._from_sequence(values)
182182
elif is_bool_dtype(dtype):
183-
values = values.astype(dtype)
183+
values = values.astype(dtype, copy=False)
184184

185185
# we only support object dtypes bool Index
186186
if isinstance(original, ABCIndexClass):
187-
values = values.astype(object)
187+
values = values.astype(object, copy=False)
188188
elif dtype is not None:
189-
values = values.astype(dtype)
189+
values = values.astype(dtype, copy=False)
190190

191191
return values
192192

@@ -396,7 +396,7 @@ def unique(values):
396396

397397
table = htable(len(values))
398398
uniques = table.unique(values)
399-
uniques = _reconstruct_data(uniques, dtype, original)
399+
uniques = _reconstruct_data(uniques, original.dtype, original)
400400
return uniques
401401

402402

pandas/tests/test_base.py

+23-4
Original file line numberDiff line numberDiff line change
@@ -159,8 +159,8 @@ def test_memory_usage(self):
159159
class Ops:
160160
def _allow_na_ops(self, obj):
161161
"""Whether to skip test cases including NaN"""
162-
if isinstance(obj, Index) and (obj.is_boolean() or not obj._can_hold_na):
163-
# don't test boolean / int64 index
162+
if (isinstance(obj, Index) and obj.is_boolean()) or not obj._can_hold_na:
163+
# don't test boolean / integer dtypes
164164
return False
165165
return True
166166

@@ -187,7 +187,24 @@ def setup_method(self, method):
187187
types = ["bool", "int", "float", "dt", "dt_tz", "period", "string", "unicode"]
188188
self.indexes = [getattr(self, "{}_index".format(t)) for t in types]
189189
self.series = [getattr(self, "{}_series".format(t)) for t in types]
190-
self.objs = self.indexes + self.series
190+
191+
# To test narrow dtypes, we use narrower *data* elements, not *index* elements
192+
index = self.int_index
193+
self.float32_series = Series(arr.astype(np.float32), index=index, name="a")
194+
195+
arr_int = np.random.choice(10, size=10, replace=False)
196+
self.int8_series = Series(arr_int.astype(np.int8), index=index, name="a")
197+
self.int16_series = Series(arr_int.astype(np.int16), index=index, name="a")
198+
self.int32_series = Series(arr_int.astype(np.int32), index=index, name="a")
199+
200+
self.uint8_series = Series(arr_int.astype(np.uint8), index=index, name="a")
201+
self.uint16_series = Series(arr_int.astype(np.uint16), index=index, name="a")
202+
self.uint32_series = Series(arr_int.astype(np.uint32), index=index, name="a")
203+
204+
nrw_types = ["float32", "int8", "int16", "int32", "uint8", "uint16", "uint32"]
205+
self.narrow_series = [getattr(self, "{}_series".format(t)) for t in nrw_types]
206+
207+
self.objs = self.indexes + self.series + self.narrow_series
191208

192209
def check_ops_properties(self, props, filter=None, ignore_failures=False):
193210
for op in props:
@@ -385,6 +402,7 @@ def test_value_counts_unique_nunique(self):
385402
if isinstance(o, Index):
386403
assert isinstance(result, o.__class__)
387404
tm.assert_index_equal(result, orig)
405+
assert result.dtype == orig.dtype
388406
elif is_datetime64tz_dtype(o):
389407
# datetimetz Series returns array of Timestamp
390408
assert result[0] == orig[0]
@@ -396,6 +414,7 @@ def test_value_counts_unique_nunique(self):
396414
)
397415
else:
398416
tm.assert_numpy_array_equal(result, orig.values)
417+
assert result.dtype == orig.dtype
399418

400419
assert o.nunique() == len(np.unique(o.values))
401420

@@ -904,7 +923,7 @@ def test_fillna(self):
904923

905924
expected = [fill_value] * 2 + list(values[2:])
906925

907-
expected = klass(expected)
926+
expected = klass(expected, dtype=orig.dtype)
908927
o = klass(values)
909928

910929
# check values has the same dtype as the original

0 commit comments

Comments
 (0)