Skip to content

Commit 8bde21a

Browse files
sinhrksjreback
authored andcommitted
BUG: replace coerces incorrect dtype
closes #12747 Author: sinhrks <[email protected]> This patch had conflicts when merged, resolved by Committer: Jeff Reback <[email protected]> Closes #12780 from sinhrks/replace_type and squashes the following commits: f9154e8 [sinhrks] remove unnecessary comments 279fdf6 [sinhrks] remove import failure de44877 [sinhrks] BUG: replace coerces incorrect dtype
1 parent b1e29db commit 8bde21a

File tree

5 files changed

+88
-24
lines changed

5 files changed

+88
-24
lines changed

doc/source/whatsnew/v0.20.0.txt

+1
Original file line numberDiff line numberDiff line change
@@ -823,6 +823,7 @@ Bug Fixes
823823

824824

825825
- Bug in the display of ``.info()`` where a qualifier (+) would always be displayed with a ``MultiIndex`` that contains only non-strings (:issue:`15245`)
826+
- Bug in ``.replace()`` may result in incorrect dtypes. (:issue:`12747`)
826827

827828
- Bug in ``.asfreq()``, where frequency was not set for empty ``Series`` (:issue:`14320`)
828829

pandas/core/internals.py

+17-3
Original file line numberDiff line numberDiff line change
@@ -1894,8 +1894,11 @@ def convert(self, *args, **kwargs):
18941894
blocks.append(newb)
18951895

18961896
else:
1897-
values = fn(
1898-
self.values.ravel(), **fn_kwargs).reshape(self.values.shape)
1897+
values = fn(self.values.ravel(), **fn_kwargs)
1898+
try:
1899+
values = values.reshape(self.values.shape)
1900+
except NotImplementedError:
1901+
pass
18991902
blocks.append(make_block(values, ndim=self.ndim,
19001903
placement=self.mgr_locs))
19011904

@@ -3238,6 +3241,16 @@ def comp(s):
32383241
return _possibly_compare(values, getattr(s, 'asm8', s),
32393242
operator.eq)
32403243

3244+
def _cast_scalar(block, scalar):
3245+
dtype, val = _infer_dtype_from_scalar(scalar, pandas_dtype=True)
3246+
if not is_dtype_equal(block.dtype, dtype):
3247+
dtype = _find_common_type([block.dtype, dtype])
3248+
block = block.astype(dtype)
3249+
# use original value
3250+
val = scalar
3251+
3252+
return block, val
3253+
32413254
masks = [comp(s) for i, s in enumerate(src_list)]
32423255

32433256
result_blocks = []
@@ -3260,7 +3273,8 @@ def comp(s):
32603273
# particular block
32613274
m = masks[i][b.mgr_locs.indexer]
32623275
if m.any():
3263-
new_rb.extend(b.putmask(m, d, inplace=True))
3276+
b, val = _cast_scalar(b, d)
3277+
new_rb.extend(b.putmask(m, val, inplace=True))
32643278
else:
32653279
new_rb.append(b)
32663280
rb = new_rb

pandas/tests/indexing/test_coercion.py

+39-11
Original file line numberDiff line numberDiff line change
@@ -1153,12 +1153,27 @@ def setUp(self):
11531153
self.rep['float64'] = [1.1, 2.2]
11541154
self.rep['complex128'] = [1 + 1j, 2 + 2j]
11551155
self.rep['bool'] = [True, False]
1156+
self.rep['datetime64[ns]'] = [pd.Timestamp('2011-01-01'),
1157+
pd.Timestamp('2011-01-03')]
1158+
1159+
for tz in ['UTC', 'US/Eastern']:
1160+
# to test tz => different tz replacement
1161+
key = 'datetime64[ns, {0}]'.format(tz)
1162+
self.rep[key] = [pd.Timestamp('2011-01-01', tz=tz),
1163+
pd.Timestamp('2011-01-03', tz=tz)]
1164+
1165+
self.rep['timedelta64[ns]'] = [pd.Timedelta('1 day'),
1166+
pd.Timedelta('2 day')]
11561167

11571168
def _assert_replace_conversion(self, from_key, to_key, how):
11581169
index = pd.Index([3, 4], name='xxx')
11591170
obj = pd.Series(self.rep[from_key], index=index, name='yyy')
11601171
self.assertEqual(obj.dtype, from_key)
11611172

1173+
if (from_key.startswith('datetime') and to_key.startswith('datetime')):
1174+
# different tz, currently mask_missing raises SystemError
1175+
return
1176+
11621177
if how == 'dict':
11631178
replacer = dict(zip(self.rep[from_key], self.rep[to_key]))
11641179
elif how == 'series':
@@ -1175,17 +1190,12 @@ def _assert_replace_conversion(self, from_key, to_key, how):
11751190
pytest.skip("windows platform buggy: {0} -> {1}".format
11761191
(from_key, to_key))
11771192

1178-
if ((from_key == 'float64' and
1179-
to_key in ('bool', 'int64')) or
1180-
1193+
if ((from_key == 'float64' and to_key in ('bool', 'int64')) or
11811194
(from_key == 'complex128' and
11821195
to_key in ('bool', 'int64', 'float64')) or
11831196

1184-
(from_key == 'int64' and
1185-
to_key in ('bool')) or
1186-
1187-
# TODO_GH12747 The result must be int?
1188-
(from_key == 'bool' and to_key == 'int64')):
1197+
# GH12747 The result must be int?
1198+
(from_key == 'int64' and to_key in ('bool'))):
11891199

11901200
# buggy on 32-bit
11911201
if tm.is_platform_32bit():
@@ -1248,13 +1258,31 @@ def test_replace_series_bool(self):
12481258
self._assert_replace_conversion(from_key, to_key, how='series')
12491259

12501260
def test_replace_series_datetime64(self):
1251-
pass
1261+
from_key = 'datetime64[ns]'
1262+
for to_key in self.rep:
1263+
self._assert_replace_conversion(from_key, to_key, how='dict')
1264+
1265+
from_key = 'datetime64[ns]'
1266+
for to_key in self.rep:
1267+
self._assert_replace_conversion(from_key, to_key, how='series')
12521268

12531269
def test_replace_series_datetime64tz(self):
1254-
pass
1270+
from_key = 'datetime64[ns, US/Eastern]'
1271+
for to_key in self.rep:
1272+
self._assert_replace_conversion(from_key, to_key, how='dict')
1273+
1274+
from_key = 'datetime64[ns, US/Eastern]'
1275+
for to_key in self.rep:
1276+
self._assert_replace_conversion(from_key, to_key, how='series')
12551277

12561278
def test_replace_series_timedelta64(self):
1257-
pass
1279+
from_key = 'timedelta64[ns]'
1280+
for to_key in self.rep:
1281+
self._assert_replace_conversion(from_key, to_key, how='dict')
1282+
1283+
from_key = 'timedelta64[ns]'
1284+
for to_key in self.rep:
1285+
self._assert_replace_conversion(from_key, to_key, how='series')
12581286

12591287
def test_replace_series_period(self):
12601288
pass

pandas/tests/series/test_replace.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -132,8 +132,8 @@ def check_replace(to_rep, val, expected):
132132
tm.assert_series_equal(expected, r)
133133
tm.assert_series_equal(expected, sc)
134134

135-
# should NOT upcast to float
136-
e = pd.Series([0, 1, 2, 3, 4])
135+
# MUST upcast to float
136+
e = pd.Series([0., 1., 2., 3., 4.])
137137
tr, v = [3], [3.0]
138138
check_replace(tr, v, e)
139139

pandas/types/cast.py

+29-8
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
_ensure_int32, _ensure_int64,
2222
_NS_DTYPE, _TD_DTYPE, _INT64_DTYPE,
2323
_POSSIBLY_CAST_DTYPES)
24-
from .dtypes import ExtensionDtype
24+
from .dtypes import ExtensionDtype, DatetimeTZDtype, PeriodDtype
2525
from .generic import ABCDatetimeIndex, ABCPeriodIndex, ABCSeries
2626
from .missing import isnull, notnull
2727
from .inference import is_list_like
@@ -312,8 +312,17 @@ def _maybe_promote(dtype, fill_value=np.nan):
312312
return dtype, fill_value
313313

314314

315-
def _infer_dtype_from_scalar(val):
316-
""" interpret the dtype from a scalar """
315+
def _infer_dtype_from_scalar(val, pandas_dtype=False):
316+
"""
317+
interpret the dtype from a scalar
318+
319+
Parameters
320+
----------
321+
pandas_dtype : bool, default False
322+
whether to infer dtype including pandas extension types.
323+
If False, scalar belongs to pandas extension types is inferred as
324+
object
325+
"""
317326

318327
dtype = np.object_
319328

@@ -336,13 +345,20 @@ def _infer_dtype_from_scalar(val):
336345

337346
dtype = np.object_
338347

339-
elif isinstance(val, (np.datetime64,
340-
datetime)) and getattr(val, 'tzinfo', None) is None:
341-
val = lib.Timestamp(val).value
342-
dtype = np.dtype('M8[ns]')
348+
elif isinstance(val, (np.datetime64, datetime)):
349+
val = tslib.Timestamp(val)
350+
if val is tslib.NaT or val.tz is None:
351+
dtype = np.dtype('M8[ns]')
352+
else:
353+
if pandas_dtype:
354+
dtype = DatetimeTZDtype(unit='ns', tz=val.tz)
355+
else:
356+
# return datetimetz as object
357+
return np.object_, val
358+
val = val.value
343359

344360
elif isinstance(val, (np.timedelta64, timedelta)):
345-
val = lib.Timedelta(val).value
361+
val = tslib.Timedelta(val).value
346362
dtype = np.dtype('m8[ns]')
347363

348364
elif is_bool(val):
@@ -363,6 +379,11 @@ def _infer_dtype_from_scalar(val):
363379
elif is_complex(val):
364380
dtype = np.complex_
365381

382+
elif pandas_dtype:
383+
if lib.is_period(val):
384+
dtype = PeriodDtype(freq=val.freq)
385+
val = val.ordinal
386+
366387
return dtype, val
367388

368389

0 commit comments

Comments
 (0)