Skip to content

Commit 73805ce

Browse files
committed
CLN: add infer_dtype_from_array
1 parent 45e67e4 commit 73805ce

File tree

5 files changed

+109
-40
lines changed

5 files changed

+109
-40
lines changed

doc/source/whatsnew/v0.20.0.txt

+2-3
Original file line numberDiff line numberDiff line change
@@ -883,6 +883,8 @@ Bug Fixes
883883

884884
- Bug in the display of ``.info()`` where a qualifier (+) would always be displayed with a ``MultiIndex`` that contains only non-strings (:issue:`15245`)
885885
- Bug in ``.replace()`` may result in incorrect dtypes. (:issue:`12747`, :issue:`15765`)
886+
- Bug in ``Series.replace`` and ``DataFrame.replace`` which failed on empty replacement dicts (:issue:`15289`)
887+
- Bug in ``Series.replace`` which replaced a numeric by string (:issue:`15743`)
886888

887889
- Bug in ``.asfreq()``, where frequency was not set for empty ``Series`` (:issue:`14320`)
888890

@@ -985,9 +987,6 @@ Bug Fixes
985987

986988
- Bug in ``DataFrame.hist`` where ``plt.tight_layout`` caused an ``AttributeError`` (use ``matplotlib >= 2.0.1``) (:issue:`9351`)
987989
- Bug in ``DataFrame.boxplot`` where ``fontsize`` was not applied to the tick labels on both axes (:issue:`15108`)
988-
- Bug in ``Series.replace`` and ``DataFrame.replace`` which failed on empty replacement dicts (:issue:`15289`)
989990
- Bug in ``pd.melt()`` where passing a tuple value for ``value_vars`` caused a ``TypeError`` (:issue:`15348`)
990991
- Bug in ``.eval()`` which caused multiline evals to fail with local variables not on the first line (:issue:`15342`)
991992
- Bug in ``pd.read_msgpack`` which did not allow to load dataframe with an index of type ``CategoricalIndex`` (:issue:`15487`)
992-
993-
- Bug in ``Series.replace`` which replaced a numeric by string (:issue:`15743`)

pandas/core/missing.py

+13-12
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,16 @@
99

1010
from pandas.compat import range, string_types
1111
from pandas.types.common import (is_numeric_v_string_like,
12-
is_float_dtype, is_datetime64_dtype,
13-
is_datetime64tz_dtype, is_integer_dtype,
14-
_ensure_float64, is_scalar,
15-
needs_i8_conversion, is_integer)
12+
is_float_dtype,
13+
is_datetime64_dtype,
14+
is_datetime64tz_dtype,
15+
is_integer_dtype,
16+
is_scalar,
17+
is_integer,
18+
needs_i8_conversion,
19+
_ensure_float64)
20+
21+
from pandas.types.cast import infer_dtype_from_array
1622
from pandas.types.missing import isnull
1723

1824

@@ -21,16 +27,11 @@ def mask_missing(arr, values_to_mask):
2127
Return a masking array of same size/shape as arr
2228
with entries equaling any member of values_to_mask set to True
2329
"""
24-
if isinstance(values_to_mask, np.ndarray):
25-
mask_type = values_to_mask.dtype.type
26-
elif isinstance(values_to_mask, list):
27-
mask_type = type(values_to_mask[0])
28-
else:
29-
mask_type = type(values_to_mask)
30-
values_to_mask = [values_to_mask]
30+
31+
dtype, values_to_mask = infer_dtype_from_array(values_to_mask)
3132

3233
try:
33-
values_to_mask = np.array(values_to_mask, dtype=mask_type)
34+
values_to_mask = np.array(values_to_mask, dtype=dtype)
3435
except Exception:
3536
values_to_mask = np.array(values_to_mask, dtype=object)
3637

pandas/tests/frame/test_replace.py

+15-10
Original file line numberDiff line numberDiff line change
@@ -795,7 +795,7 @@ def test_replace_dtypes(self):
795795
expected = DataFrame({'datetime64': Index([now] * 3)})
796796
assert_frame_equal(result, expected)
797797

798-
def test_replace_input_formats(self):
798+
def test_replace_input_formats_listlike(self):
799799
# both dicts
800800
to_rep = {'A': np.nan, 'B': 0, 'C': ''}
801801
values = {'A': 0, 'B': -1, 'C': 'missing'}
@@ -812,15 +812,6 @@ def test_replace_input_formats(self):
812812
'C': ['', 'asdf', 'fd']})
813813
assert_frame_equal(result, expected)
814814

815-
# dict to scalar
816-
filled = df.replace(to_rep, 0)
817-
expected = {}
818-
for k, v in compat.iteritems(df):
819-
expected[k] = v.replace(to_rep[k], 0)
820-
assert_frame_equal(filled, DataFrame(expected))
821-
822-
self.assertRaises(TypeError, df.replace, to_rep, [np.nan, 0, ''])
823-
824815
# scalar to dict
825816
values = {'A': 0, 'B': -1, 'C': 'missing'}
826817
df = DataFrame({'A': [np.nan, 0, np.nan], 'B': [0, 2, 5],
@@ -842,6 +833,20 @@ def test_replace_input_formats(self):
842833

843834
self.assertRaises(ValueError, df.replace, to_rep, values[1:])
844835

836+
def test_replace_input_formats_scalar(self):
837+
df = DataFrame({'A': [np.nan, 0, np.inf], 'B': [0, 2, 5],
838+
'C': ['', 'asdf', 'fd']})
839+
840+
# dict to scalar
841+
to_rep = {'A': np.nan, 'B': 0, 'C': ''}
842+
filled = df.replace(to_rep, 0)
843+
expected = {}
844+
for k, v in compat.iteritems(df):
845+
expected[k] = v.replace(to_rep[k], 0)
846+
assert_frame_equal(filled, DataFrame(expected))
847+
848+
self.assertRaises(TypeError, df.replace, to_rep, [np.nan, 0, ''])
849+
845850
# list to scalar
846851
to_rep = [np.nan, 0, '']
847852
result = df.replace(to_rep, -1)

pandas/tests/types/test_cast.py

+35-15
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,15 @@
55
66
"""
77

8-
from datetime import datetime
8+
import pytest
9+
from datetime import datetime, timedelta, date
910
import numpy as np
1011

1112
from pandas import Timedelta, Timestamp, DatetimeIndex
1213
from pandas.types.cast import (maybe_downcast_to_dtype,
1314
maybe_convert_objects,
1415
infer_dtype_from_scalar,
16+
infer_dtype_from_array,
1517
maybe_convert_string_to_object,
1618
maybe_convert_scalar,
1719
find_common_type)
@@ -82,7 +84,7 @@ def test_datetime_with_timezone(self):
8284
tm.assert_index_equal(res, exp)
8385

8486

85-
class TestInferDtype(tm.TestCase):
87+
class TestInferDtype(object):
8688

8789
def test_infer_dtype_from_scalar(self):
8890
# Test that _infer_dtype_from_scalar is returning correct dtype for int
@@ -92,44 +94,62 @@ def test_infer_dtype_from_scalar(self):
9294
np.int32, np.uint64, np.int64]:
9395
data = dtypec(12)
9496
dtype, val = infer_dtype_from_scalar(data)
95-
self.assertEqual(dtype, type(data))
97+
assert dtype == type(data)
9698

9799
data = 12
98100
dtype, val = infer_dtype_from_scalar(data)
99-
self.assertEqual(dtype, np.int64)
101+
assert dtype == np.int64
100102

101103
for dtypec in [np.float16, np.float32, np.float64]:
102104
data = dtypec(12)
103105
dtype, val = infer_dtype_from_scalar(data)
104-
self.assertEqual(dtype, dtypec)
106+
assert dtype == dtypec
105107

106108
data = np.float(12)
107109
dtype, val = infer_dtype_from_scalar(data)
108-
self.assertEqual(dtype, np.float64)
110+
assert dtype == np.float64
109111

110112
for data in [True, False]:
111113
dtype, val = infer_dtype_from_scalar(data)
112-
self.assertEqual(dtype, np.bool_)
114+
assert dtype == np.bool_
113115

114116
for data in [np.complex64(1), np.complex128(1)]:
115117
dtype, val = infer_dtype_from_scalar(data)
116-
self.assertEqual(dtype, np.complex_)
118+
assert dtype == np.complex_
117119

118-
import datetime
119120
for data in [np.datetime64(1, 'ns'), Timestamp(1),
120-
datetime.datetime(2000, 1, 1, 0, 0)]:
121+
datetime(2000, 1, 1, 0, 0)]:
121122
dtype, val = infer_dtype_from_scalar(data)
122-
self.assertEqual(dtype, 'M8[ns]')
123+
assert dtype == 'M8[ns]'
123124

124125
for data in [np.timedelta64(1, 'ns'), Timedelta(1),
125-
datetime.timedelta(1)]:
126+
timedelta(1)]:
126127
dtype, val = infer_dtype_from_scalar(data)
127-
self.assertEqual(dtype, 'm8[ns]')
128+
assert dtype == 'm8[ns]'
128129

129-
for data in [datetime.date(2000, 1, 1),
130+
for data in [date(2000, 1, 1),
130131
Timestamp(1, tz='US/Eastern'), 'foo']:
131132
dtype, val = infer_dtype_from_scalar(data)
132-
self.assertEqual(dtype, np.object_)
133+
assert dtype == np.object_
134+
135+
@pytest.mark.parametrize(
136+
"arr, expected",
137+
[('foo', np.object_),
138+
(b'foo', np.object_),
139+
(1, np.int_),
140+
(1.5, np.float_),
141+
([1], np.int_),
142+
(np.array([1]), np.int_),
143+
([np.nan, 1, ''], np.object_),
144+
(np.array([[1.0, 2.0]]), np.float_),
145+
(Timestamp('20160101'), np.object_),
146+
(np.datetime64('2016-01-01'), np.dtype('<M8[D]')),
147+
])
148+
def test_infer_dtype_from_array(self, arr, expected):
149+
150+
# these infer specifically to numpy dtypes
151+
dtype, _ = infer_dtype_from_array(arr)
152+
assert dtype == expected
133153

134154

135155
class TestMaybe(tm.TestCase):

pandas/types/cast.py

+44
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,50 @@ def infer_dtype_from_scalar(val, pandas_dtype=False):
387387
return dtype, val
388388

389389

390+
def infer_dtype_from_array(arr):
391+
"""
392+
infer the dtype from a scalar or array
393+
394+
Parameters
395+
----------
396+
arr : scalar or array
397+
398+
Returns
399+
-------
400+
tuple (numpy-compat dtype, array)
401+
402+
Notes
403+
-----
404+
These infer to numpy dtypes exactly
405+
with the exception that mixed / object dtypes
406+
are not coerced by stringifying or conversion
407+
408+
Examples
409+
--------
410+
>>> np.asarray([1, '1'])
411+
array(['1', '1'], dtype='<U21')
412+
413+
>>> infer_dtype_from_array([1, '1'])
414+
(numpy.object_, [1, '1'])
415+
416+
"""
417+
418+
if isinstance(arr, np.ndarray):
419+
return arr.dtype, arr
420+
421+
if not is_list_like(arr):
422+
arr = [arr]
423+
424+
# don't force numpy coerce with nan's
425+
inferred = lib.infer_dtype(arr)
426+
if inferred in ['string', 'bytes', 'unicode',
427+
'mixed', 'mixed-integer']:
428+
return (np.object_, arr)
429+
430+
arr = np.asarray(arr)
431+
return arr.dtype, arr
432+
433+
390434
def maybe_upcast(values, fill_value=np.nan, dtype=None, copy=False):
391435
""" provide explict type promotion and coercion
392436

0 commit comments

Comments
 (0)