Skip to content

Commit 73ab162

Browse files
gfyoungjreback
authored andcommitted
TST: Use fixtures in dtypes/test_cast.py (#21661)
1 parent a620e72 commit 73ab162

File tree

2 files changed

+138
-123
lines changed

2 files changed

+138
-123
lines changed

pandas/conftest.py

+12
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,18 @@ def float_dtype(request):
245245
return request.param
246246

247247

248+
@pytest.fixture(params=["complex64", "complex128"])
249+
def complex_dtype(request):
250+
"""
251+
Parameterized fixture for complex dtypes.
252+
253+
* complex64
254+
* complex128
255+
"""
256+
257+
return request.param
258+
259+
248260
UNSIGNED_INT_DTYPES = ["uint8", "uint16", "uint32", "uint64"]
249261
SIGNED_INT_DTYPES = ["int8", "int16", "int32", "int64"]
250262
ALL_INT_DTYPES = UNSIGNED_INT_DTYPES + SIGNED_INT_DTYPES

pandas/tests/dtypes/test_cast.py

+126-123
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636

3737
class TestMaybeDowncast(object):
3838

39-
def test_downcast_conv(self):
39+
def test_downcast(self):
4040
# test downcasting
4141

4242
arr = np.array([8.5, 8.6, 8.7, 8.8, 8.9999999999995])
@@ -53,33 +53,34 @@ def test_downcast_conv(self):
5353
expected = np.array([8, 8, 8, 8, 9], dtype=np.int64)
5454
tm.assert_numpy_array_equal(result, expected)
5555

56-
# GH16875 coercing of bools
56+
# see gh-16875: coercing of booleans.
5757
ser = Series([True, True, False])
5858
result = maybe_downcast_to_dtype(ser, np.dtype(np.float64))
5959
expected = ser
6060
tm.assert_series_equal(result, expected)
6161

62-
# conversions
63-
62+
@pytest.mark.parametrize("dtype", [np.float64, object, np.int64])
63+
def test_downcast_conversion_no_nan(self, dtype):
6464
expected = np.array([1, 2])
65-
for dtype in [np.float64, object, np.int64]:
66-
arr = np.array([1.0, 2.0], dtype=dtype)
67-
result = maybe_downcast_to_dtype(arr, 'infer')
68-
tm.assert_almost_equal(result, expected, check_dtype=False)
69-
70-
for dtype in [np.float64, object]:
71-
expected = np.array([1.0, 2.0, np.nan], dtype=dtype)
72-
arr = np.array([1.0, 2.0, np.nan], dtype=dtype)
73-
result = maybe_downcast_to_dtype(arr, 'infer')
74-
tm.assert_almost_equal(result, expected)
75-
76-
# empties
77-
for dtype in [np.int32, np.float64, np.float32, np.bool_,
78-
np.int64, object]:
79-
arr = np.array([], dtype=dtype)
80-
result = maybe_downcast_to_dtype(arr, 'int64')
81-
tm.assert_almost_equal(result, np.array([], dtype=np.int64))
82-
assert result.dtype == np.int64
65+
arr = np.array([1.0, 2.0], dtype=dtype)
66+
67+
result = maybe_downcast_to_dtype(arr, "infer")
68+
tm.assert_almost_equal(result, expected, check_dtype=False)
69+
70+
@pytest.mark.parametrize("dtype", [np.float64, object])
71+
def test_downcast_conversion_nan(self, dtype):
72+
expected = np.array([1.0, 2.0, np.nan], dtype=dtype)
73+
arr = np.array([1.0, 2.0, np.nan], dtype=dtype)
74+
75+
result = maybe_downcast_to_dtype(arr, "infer")
76+
tm.assert_almost_equal(result, expected)
77+
78+
@pytest.mark.parametrize("dtype", [np.int32, np.float64, np.float32,
79+
np.bool_, np.int64, object])
80+
def test_downcast_conversion_empty(self, dtype):
81+
arr = np.array([], dtype=dtype)
82+
result = maybe_downcast_to_dtype(arr, "int64")
83+
tm.assert_numpy_array_equal(result, np.array([], dtype=np.int64))
8384

8485
def test_datetimelikes_nan(self):
8586
arr = np.array([1, 2, np.nan])
@@ -104,66 +105,71 @@ def test_datetime_with_timezone(self):
104105

105106
class TestInferDtype(object):
106107

107-
def testinfer_dtype_from_scalar(self):
108-
# Test that infer_dtype_from_scalar is returning correct dtype for int
109-
# and float.
108+
def test_infer_dtype_from_int_scalar(self, any_int_dtype):
109+
# Test that infer_dtype_from_scalar is
110+
# returning correct dtype for int and float.
111+
data = np.dtype(any_int_dtype).type(12)
112+
dtype, val = infer_dtype_from_scalar(data)
113+
assert dtype == type(data)
114+
115+
def test_infer_dtype_from_float_scalar(self, float_dtype):
116+
float_dtype = np.dtype(float_dtype).type
117+
data = float_dtype(12)
110118

111-
for dtypec in [np.uint8, np.int8, np.uint16, np.int16, np.uint32,
112-
np.int32, np.uint64, np.int64]:
113-
data = dtypec(12)
114-
dtype, val = infer_dtype_from_scalar(data)
115-
assert dtype == type(data)
119+
dtype, val = infer_dtype_from_scalar(data)
120+
assert dtype == float_dtype
116121

122+
def test_infer_dtype_from_python_scalar(self):
117123
data = 12
118124
dtype, val = infer_dtype_from_scalar(data)
119125
assert dtype == np.int64
120126

121-
for dtypec in [np.float16, np.float32, np.float64]:
122-
data = dtypec(12)
123-
dtype, val = infer_dtype_from_scalar(data)
124-
assert dtype == dtypec
125-
126127
data = np.float(12)
127128
dtype, val = infer_dtype_from_scalar(data)
128129
assert dtype == np.float64
129130

130-
for data in [True, False]:
131-
dtype, val = infer_dtype_from_scalar(data)
132-
assert dtype == np.bool_
131+
@pytest.mark.parametrize("bool_val", [True, False])
132+
def test_infer_dtype_from_boolean(self, bool_val):
133+
dtype, val = infer_dtype_from_scalar(bool_val)
134+
assert dtype == np.bool_
133135

134-
for data in [np.complex64(1), np.complex128(1)]:
135-
dtype, val = infer_dtype_from_scalar(data)
136-
assert dtype == np.complex_
136+
def test_infer_dtype_from_complex(self, complex_dtype):
137+
data = np.dtype(complex_dtype).type(1)
138+
dtype, val = infer_dtype_from_scalar(data)
139+
assert dtype == np.complex_
137140

138-
for data in [np.datetime64(1, 'ns'), Timestamp(1),
139-
datetime(2000, 1, 1, 0, 0)]:
140-
dtype, val = infer_dtype_from_scalar(data)
141-
assert dtype == 'M8[ns]'
141+
@pytest.mark.parametrize("data", [np.datetime64(1, "ns"), Timestamp(1),
142+
datetime(2000, 1, 1, 0, 0)])
143+
def test_infer_dtype_from_datetime(self, data):
144+
dtype, val = infer_dtype_from_scalar(data)
145+
assert dtype == "M8[ns]"
142146

143-
for data in [np.timedelta64(1, 'ns'), Timedelta(1),
144-
timedelta(1)]:
145-
dtype, val = infer_dtype_from_scalar(data)
146-
assert dtype == 'm8[ns]'
147+
@pytest.mark.parametrize("data", [np.timedelta64(1, "ns"), Timedelta(1),
148+
timedelta(1)])
149+
def test_infer_dtype_from_timedelta(self, data):
150+
dtype, val = infer_dtype_from_scalar(data)
151+
assert dtype == "m8[ns]"
147152

148-
for freq in ['M', 'D']:
149-
p = Period('2011-01-01', freq=freq)
150-
dtype, val = infer_dtype_from_scalar(p, pandas_dtype=True)
151-
assert dtype == 'period[{0}]'.format(freq)
152-
assert val == p.ordinal
153+
@pytest.mark.parametrize("freq", ["M", "D"])
154+
def test_infer_dtype_from_period(self, freq):
155+
p = Period("2011-01-01", freq=freq)
156+
dtype, val = infer_dtype_from_scalar(p, pandas_dtype=True)
153157

154-
dtype, val = infer_dtype_from_scalar(p)
155-
dtype == np.object_
156-
assert val == p
158+
assert dtype == "period[{0}]".format(freq)
159+
assert val == p.ordinal
157160

158-
# misc
159-
for data in [date(2000, 1, 1),
160-
Timestamp(1, tz='US/Eastern'), 'foo']:
161+
dtype, val = infer_dtype_from_scalar(p)
162+
assert dtype == np.object_
163+
assert val == p
161164

162-
dtype, val = infer_dtype_from_scalar(data)
163-
assert dtype == np.object_
165+
@pytest.mark.parametrize("data", [date(2000, 1, 1), "foo",
166+
Timestamp(1, tz="US/Eastern")])
167+
def test_infer_dtype_misc(self, data):
168+
dtype, val = infer_dtype_from_scalar(data)
169+
assert dtype == np.object_
164170

165171
@pytest.mark.parametrize('tz', ['UTC', 'US/Eastern', 'Asia/Tokyo'])
166-
def testinfer_from_scalar_tz(self, tz):
172+
def test_infer_from_scalar_tz(self, tz):
167173
dt = Timestamp(1, tz=tz)
168174
dtype, val = infer_dtype_from_scalar(dt, pandas_dtype=True)
169175
assert dtype == 'datetime64[ns, {0}]'.format(tz)
@@ -173,7 +179,7 @@ def testinfer_from_scalar_tz(self, tz):
173179
assert dtype == np.object_
174180
assert val == dt
175181

176-
def testinfer_dtype_from_scalar_errors(self):
182+
def test_infer_dtype_from_scalar_errors(self):
177183
with pytest.raises(ValueError):
178184
infer_dtype_from_scalar(np.array([1]))
179185

@@ -329,66 +335,63 @@ def test_maybe_convert_objects_copy(self):
329335

330336
class TestCommonTypes(object):
331337

332-
def test_numpy_dtypes(self):
333-
# (source_types, destination_type)
334-
testcases = (
335-
# identity
336-
((np.int64,), np.int64),
337-
((np.uint64,), np.uint64),
338-
((np.float32,), np.float32),
339-
((np.object,), np.object),
340-
341-
# into ints
342-
((np.int16, np.int64), np.int64),
343-
((np.int32, np.uint32), np.int64),
344-
((np.uint16, np.uint64), np.uint64),
345-
346-
# into floats
347-
((np.float16, np.float32), np.float32),
348-
((np.float16, np.int16), np.float32),
349-
((np.float32, np.int16), np.float32),
350-
((np.uint64, np.int64), np.float64),
351-
((np.int16, np.float64), np.float64),
352-
((np.float16, np.int64), np.float64),
353-
354-
# into others
355-
((np.complex128, np.int32), np.complex128),
356-
((np.object, np.float32), np.object),
357-
((np.object, np.int16), np.object),
358-
359-
# bool with int
360-
((np.dtype('bool'), np.int64), np.object),
361-
((np.dtype('bool'), np.int32), np.object),
362-
((np.dtype('bool'), np.int16), np.object),
363-
((np.dtype('bool'), np.int8), np.object),
364-
((np.dtype('bool'), np.uint64), np.object),
365-
((np.dtype('bool'), np.uint32), np.object),
366-
((np.dtype('bool'), np.uint16), np.object),
367-
((np.dtype('bool'), np.uint8), np.object),
368-
369-
# bool with float
370-
((np.dtype('bool'), np.float64), np.object),
371-
((np.dtype('bool'), np.float32), np.object),
372-
373-
((np.dtype('datetime64[ns]'), np.dtype('datetime64[ns]')),
374-
np.dtype('datetime64[ns]')),
375-
((np.dtype('timedelta64[ns]'), np.dtype('timedelta64[ns]')),
376-
np.dtype('timedelta64[ns]')),
377-
378-
((np.dtype('datetime64[ns]'), np.dtype('datetime64[ms]')),
379-
np.dtype('datetime64[ns]')),
380-
((np.dtype('timedelta64[ms]'), np.dtype('timedelta64[ns]')),
381-
np.dtype('timedelta64[ns]')),
382-
383-
((np.dtype('datetime64[ns]'), np.dtype('timedelta64[ns]')),
384-
np.object),
385-
((np.dtype('datetime64[ns]'), np.int64), np.object)
386-
)
387-
for src, common in testcases:
388-
assert find_common_type(src) == common
389-
338+
@pytest.mark.parametrize("source_dtypes,expected_common_dtype", [
339+
((np.int64,), np.int64),
340+
((np.uint64,), np.uint64),
341+
((np.float32,), np.float32),
342+
((np.object,), np.object),
343+
344+
# into ints
345+
((np.int16, np.int64), np.int64),
346+
((np.int32, np.uint32), np.int64),
347+
((np.uint16, np.uint64), np.uint64),
348+
349+
# into floats
350+
((np.float16, np.float32), np.float32),
351+
((np.float16, np.int16), np.float32),
352+
((np.float32, np.int16), np.float32),
353+
((np.uint64, np.int64), np.float64),
354+
((np.int16, np.float64), np.float64),
355+
((np.float16, np.int64), np.float64),
356+
357+
# into others
358+
((np.complex128, np.int32), np.complex128),
359+
((np.object, np.float32), np.object),
360+
((np.object, np.int16), np.object),
361+
362+
# bool with int
363+
((np.dtype('bool'), np.int64), np.object),
364+
((np.dtype('bool'), np.int32), np.object),
365+
((np.dtype('bool'), np.int16), np.object),
366+
((np.dtype('bool'), np.int8), np.object),
367+
((np.dtype('bool'), np.uint64), np.object),
368+
((np.dtype('bool'), np.uint32), np.object),
369+
((np.dtype('bool'), np.uint16), np.object),
370+
((np.dtype('bool'), np.uint8), np.object),
371+
372+
# bool with float
373+
((np.dtype('bool'), np.float64), np.object),
374+
((np.dtype('bool'), np.float32), np.object),
375+
376+
((np.dtype('datetime64[ns]'), np.dtype('datetime64[ns]')),
377+
np.dtype('datetime64[ns]')),
378+
((np.dtype('timedelta64[ns]'), np.dtype('timedelta64[ns]')),
379+
np.dtype('timedelta64[ns]')),
380+
381+
((np.dtype('datetime64[ns]'), np.dtype('datetime64[ms]')),
382+
np.dtype('datetime64[ns]')),
383+
((np.dtype('timedelta64[ms]'), np.dtype('timedelta64[ns]')),
384+
np.dtype('timedelta64[ns]')),
385+
386+
((np.dtype('datetime64[ns]'), np.dtype('timedelta64[ns]')),
387+
np.object),
388+
((np.dtype('datetime64[ns]'), np.int64), np.object)
389+
])
390+
def test_numpy_dtypes(self, source_dtypes, expected_common_dtype):
391+
assert find_common_type(source_dtypes) == expected_common_dtype
392+
393+
def test_raises_empty_input(self):
390394
with pytest.raises(ValueError):
391-
# empty
392395
find_common_type([])
393396

394397
def test_categorical_dtype(self):

0 commit comments

Comments
 (0)