Skip to content

Commit 8ad5c12

Browse files
ganevgvjreback
authored andcommitted
TST: add test for .unique() dtype preserving (#29515)
1 parent 65a4ee6 commit 8ad5c12

File tree

1 file changed

+37
-0
lines changed

1 file changed

+37
-0
lines changed

pandas/tests/test_algos.py

+37
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,13 @@
1010
from pandas.compat.numpy import np_array_datetime64_compat
1111
import pandas.util._test_decorators as td
1212

13+
from pandas.core.dtypes.common import (
14+
is_bool_dtype,
15+
is_complex_dtype,
16+
is_float_dtype,
17+
is_integer_dtype,
18+
is_object_dtype,
19+
)
1320
from pandas.core.dtypes.dtypes import CategoricalDtype as CDT
1421

1522
import pandas as pd
@@ -23,6 +30,7 @@
2330
Timestamp,
2431
compat,
2532
)
33+
from pandas.conftest import BYTES_DTYPES, STRING_DTYPES
2634
import pandas.core.algorithms as algos
2735
from pandas.core.arrays import DatetimeArray
2836
import pandas.core.common as com
@@ -352,6 +360,35 @@ def test_on_index_object(self):
352360

353361
tm.assert_almost_equal(result, expected)
354362

363+
def test_dtype_preservation(self, any_numpy_dtype):
364+
# GH 15442
365+
if any_numpy_dtype in (BYTES_DTYPES + STRING_DTYPES):
366+
pytest.skip("skip string dtype")
367+
elif is_integer_dtype(any_numpy_dtype):
368+
data = [1, 2, 2]
369+
uniques = [1, 2]
370+
elif is_float_dtype(any_numpy_dtype):
371+
data = [1, 2, 2]
372+
uniques = [1.0, 2.0]
373+
elif is_complex_dtype(any_numpy_dtype):
374+
data = [complex(1, 0), complex(2, 0), complex(2, 0)]
375+
uniques = [complex(1, 0), complex(2, 0)]
376+
elif is_bool_dtype(any_numpy_dtype):
377+
data = [True, True, False]
378+
uniques = [True, False]
379+
elif is_object_dtype(any_numpy_dtype):
380+
data = ["A", "B", "B"]
381+
uniques = ["A", "B"]
382+
else:
383+
# datetime64[ns]/M8[ns]/timedelta64[ns]/m8[ns] tested elsewhere
384+
data = [1, 2, 2]
385+
uniques = [1, 2]
386+
387+
result = Series(data, dtype=any_numpy_dtype).unique()
388+
expected = np.array(uniques, dtype=any_numpy_dtype)
389+
390+
tm.assert_numpy_array_equal(result, expected)
391+
355392
def test_datetime64_dtype_array_returned(self):
356393
# GH 9431
357394
expected = np_array_datetime64_compat(

0 commit comments

Comments
 (0)