Skip to content

Commit 0dea22d

Browse files
committed
use any_numpy_dtype to extract dtypes
1 parent 6af0391 commit 0dea22d

File tree

1 file changed

+35
-16
lines changed

1 file changed

+35
-16
lines changed

pandas/tests/test_algos.py

+35-16
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,15 @@
1212

1313
from pandas.core.dtypes.dtypes import CategoricalDtype as CDT
1414

15+
from pandas.core.dtypes.common import (
16+
is_integer_dtype,
17+
is_float_dtype,
18+
is_complex_dtype,
19+
is_bool_dtype,
20+
is_object_dtype,
21+
)
22+
from pandas.conftest import BYTES_DTYPES, STRING_DTYPES
23+
1524
import pandas as pd
1625
from pandas import (
1726
Categorical,
@@ -353,24 +362,34 @@ def test_on_index_object(self):
353362

354363
tm.assert_almost_equal(result, expected)
355364

356-
@pytest.mark.parametrize(
357-
"data, uniques, dtype_list",
358-
[
359-
([1, 2, 2], [1, 2], np.sctypes["int"]),
360-
([1, 2, 2], [1, 2], np.sctypes["uint"]),
361-
([1, 2, 2], [1.0, 2.0], np.sctypes["float"]),
362-
([1, 2, 2], [1.0, 2.0], np.sctypes["complex"]),
363-
([True, True, False], [True, False], np.sctypes["others"]), # bool, object
364-
],
365-
)
366-
def test_dtype_preservation(self, data, uniques, dtype_list):
365+
def test_dtype_preservation(self, any_numpy_dtype):
367366
# GH 15442
368-
for dtype in dtype_list:
369-
if dtype not in [bytes, str, np.void]:
370-
result = Series(data, dtype=dtype).unique()
371-
expected = np.array(uniques, dtype=dtype)
367+
if any_numpy_dtype in (BYTES_DTYPES + STRING_DTYPES):
368+
pytest.skip("skip string dtype")
369+
elif is_integer_dtype(any_numpy_dtype):
370+
data = [1, 2, 2]
371+
uniques = [1, 2]
372+
elif is_float_dtype(any_numpy_dtype):
373+
data = [1, 2, 2]
374+
uniques = [1.0, 2.0]
375+
elif is_complex_dtype(any_numpy_dtype):
376+
data = [complex(1, 0), complex(2, 0), complex(2, 0)]
377+
uniques = [complex(1, 0), complex(2, 0)]
378+
elif is_bool_dtype(any_numpy_dtype):
379+
data = [True, True, False]
380+
uniques = [True, False]
381+
elif is_object_dtype(any_numpy_dtype):
382+
data = ["A", "B", "B"]
383+
uniques = ["A", "B"]
384+
else:
385+
# datetime64[ns]/M8[ns]/timedelta64[ns]/m8[ns] tested elsewhere
386+
data = [1, 2, 2]
387+
uniques = [1, 2]
372388

373-
tm.assert_numpy_array_equal(result, expected)
389+
result = Series(data, dtype=any_numpy_dtype).unique()
390+
expected = np.array(uniques, dtype=any_numpy_dtype)
391+
392+
tm.assert_numpy_array_equal(result, expected)
374393

375394
def test_datetime64_dtype_array_returned(self):
376395
# GH 9431

0 commit comments

Comments
 (0)