Skip to content

Commit 24cdf45

Browse files
varunkumar-dev authored and jreback committed
BUG: #11638 return correct dtype for int and float
1 parent 784445c commit 24cdf45

File tree

4 files changed

+73
-5
lines changed

4 files changed

+73
-5
lines changed

doc/source/whatsnew/v0.17.1.txt

+1
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,7 @@ Bug Fixes
168168
- Bug in ``DataFrame.pct_change()`` not propagating ``axis`` keyword on ``.fillna`` method (:issue:`11150`)
169169
- Bug in ``.to_csv()`` when a mix of integer and string column names are passed as the ``columns`` parameter (:issue:`11637`)
170170
- Bug in indexing with a ``range``, (:issue:`11652`)
171+
- Bug in inference of numpy scalars and preserving dtype when setting columns (:issue:`11638`)
171172
- Bug in ``to_sql`` using unicode column names giving UnicodeEncodeError with (:issue:`11431`).
172173
- Fix regression in setting of ``xticks`` in ``plot`` (:issue:`11529`).
173174
- Bug in ``holiday.dates`` where observance rules could not be applied to holiday and doc enhancement (:issue:`11477`, :issue:`11533`)

pandas/core/common.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -1002,8 +1002,7 @@ def _infer_fill_value(val):
10021002

10031003

10041004
def _infer_dtype_from_scalar(val):
1005-
""" interpret the dtype from a scalar, upcast floats and ints
1006-
return the new value and the dtype """
1005+
""" interpret the dtype from a scalar """
10071006

10081007
dtype = np.object_
10091008

@@ -1037,12 +1036,17 @@ def _infer_dtype_from_scalar(val):
10371036
elif is_bool(val):
10381037
dtype = np.bool_
10391038

1040-
# provide implicity upcast on scalars
10411039
elif is_integer(val):
1042-
dtype = np.int64
1040+
if isinstance(val, int):
1041+
dtype = np.int64
1042+
else:
1043+
dtype = type(val)
10431044

10441045
elif is_float(val):
1045-
dtype = np.float64
1046+
if isinstance(val, float):
1047+
dtype = np.float64
1048+
else:
1049+
dtype = type(val)
10461050

10471051
elif is_complex(val):
10481052
dtype = np.complex_

pandas/tests/test_common.py

+56
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,62 @@ def test_abc_types(self):
9898
self.assertIsInstance(pd.Period('2012', freq='A-DEC'), com.ABCPeriod)
9999

100100

101+
class TestInferDtype(tm.TestCase):
102+
103+
def test_infer_dtype_from_scalar(self):
104+
# Test that _infer_dtype_from_scalar is returning correct dtype for int and float.
105+
106+
for dtypec in [ np.uint8, np.int8,
107+
np.uint16, np.int16,
108+
np.uint32, np.int32,
109+
np.uint64, np.int64 ]:
110+
data = dtypec(12)
111+
dtype, val = com._infer_dtype_from_scalar(data)
112+
self.assertEqual(dtype, dtypec)
113+
114+
data = 12
115+
dtype, val = com._infer_dtype_from_scalar(data)
116+
self.assertEqual(dtype, np.int64)
117+
118+
for dtypec in [ np.float16, np.float32, np.float64 ]:
119+
data = dtypec(12)
120+
dtype, val = com._infer_dtype_from_scalar(data)
121+
self.assertEqual(dtype, dtypec)
122+
123+
data = np.float(12)
124+
dtype, val = com._infer_dtype_from_scalar(data)
125+
self.assertEqual(dtype, np.float64)
126+
127+
for data in [ True, False ]:
128+
dtype, val = com._infer_dtype_from_scalar(data)
129+
self.assertEqual(dtype, np.bool_)
130+
131+
for data in [ np.complex64(1), np.complex128(1) ]:
132+
dtype, val = com._infer_dtype_from_scalar(data)
133+
self.assertEqual(dtype, np.complex_)
134+
135+
import datetime
136+
for data in [ np.datetime64(1,'ns'),
137+
pd.Timestamp(1),
138+
datetime.datetime(2000,1,1,0,0)
139+
]:
140+
dtype, val = com._infer_dtype_from_scalar(data)
141+
self.assertEqual(dtype, 'M8[ns]')
142+
143+
for data in [ np.timedelta64(1,'ns'),
144+
pd.Timedelta(1),
145+
datetime.timedelta(1)
146+
]:
147+
dtype, val = com._infer_dtype_from_scalar(data)
148+
self.assertEqual(dtype, 'm8[ns]')
149+
150+
for data in [ datetime.date(2000,1,1),
151+
pd.Timestamp(1,tz='US/Eastern'),
152+
'foo'
153+
]:
154+
dtype, val = com._infer_dtype_from_scalar(data)
155+
self.assertEqual(dtype, np.object_)
156+
101157
def test_notnull():
102158
assert notnull(1.)
103159
assert not notnull(None)

pandas/tests/test_frame.py

+7
Original file line numberDiff line numberDiff line change
@@ -683,6 +683,13 @@ def test_setitem_cast(self):
683683
expected = Series({'float64' : 3, 'object' : 1 }).sort_values()
684684
assert_series_equal(result, expected)
685685

686+
# Test that data type is preserved . #5782
687+
df = DataFrame({'one': np.arange(6, dtype=np.int8)})
688+
df.loc[1, 'one'] = 6
689+
self.assertEqual(df.dtypes.one, np.dtype(np.int8))
690+
df.one = np.int8(7)
691+
self.assertEqual(df.dtypes.one, np.dtype(np.int8))
692+
686693
def test_setitem_boolean_column(self):
687694
expected = self.frame.copy()
688695
mask = self.frame['A'] > 0

0 commit comments

Comments
 (0)