Skip to content

Commit 239c0ff

Browse files
committed
BUG: replace coerces incorrect dtype
1 parent b895968 commit 239c0ff

File tree

5 files changed

+96
-26
lines changed

5 files changed

+96
-26
lines changed

pandas/core/internals.py

+17-3
Original file line numberDiff line numberDiff line change
@@ -1890,8 +1890,11 @@ def convert(self, *args, **kwargs):
18901890
blocks.append(newb)
18911891

18921892
else:
1893-
values = fn(
1894-
self.values.ravel(), **fn_kwargs).reshape(self.values.shape)
1893+
values = fn(self.values.ravel(), **fn_kwargs)
1894+
try:
1895+
values = values.reshape(self.values.shape)
1896+
except NotImplementedError:
1897+
pass
18951898
blocks.append(make_block(values, ndim=self.ndim,
18961899
placement=self.mgr_locs))
18971900

@@ -3233,6 +3236,16 @@ def comp(s):
32333236
return _possibly_compare(values, getattr(s, 'asm8', s),
32343237
operator.eq)
32353238

3239+
def _cast(block, scalar):
3240+
dtype, val = _infer_dtype_from_scalar(scalar, pandas_dtype=True)
3241+
if not is_dtype_equal(block.dtype, dtype):
3242+
dtype = _find_common_type([block.dtype, dtype])
3243+
block = block.astype(dtype)
3244+
# use original value
3245+
val = scalar
3246+
3247+
return block, val
3248+
32363249
masks = [comp(s) for i, s in enumerate(src_list)]
32373250

32383251
result_blocks = []
@@ -3255,7 +3268,8 @@ def comp(s):
32553268
# particular block
32563269
m = masks[i][b.mgr_locs.indexer]
32573270
if m.any():
3258-
new_rb.extend(b.putmask(m, d, inplace=True))
3271+
b, val = _cast(b, d)
3272+
new_rb.extend(b.putmask(m, val, inplace=True))
32593273
else:
32603274
new_rb.append(b)
32613275
rb = new_rb

pandas/core/missing.py

+4
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ def mask_missing(arr, values_to_mask):
3939
# numpy elementwise comparison warning
4040
if is_numeric_v_string_like(arr, x):
4141
mask = False
42+
# elif is_object_dtype(arr):
43+
# mask = lib.scalar_compare(arr, x, operator.eq)
4244
else:
4345
mask = arr == x
4446

@@ -51,6 +53,8 @@ def mask_missing(arr, values_to_mask):
5153
# numpy elementwise comparison warning
5254
if is_numeric_v_string_like(arr, x):
5355
mask |= False
56+
# elif is_object_dtype(arr):
57+
# mask |= lib.scalar_compare(arr, x, operator.eq)
5458
else:
5559
mask |= arr == x
5660

pandas/tests/indexing/test_coercion.py

+38-12
Original file line numberDiff line numberDiff line change
@@ -1155,12 +1155,27 @@ def setUp(self):
11551155
self.rep['float64'] = [1.1, 2.2]
11561156
self.rep['complex128'] = [1 + 1j, 2 + 2j]
11571157
self.rep['bool'] = [True, False]
1158+
self.rep['datetime64[ns]'] = [pd.Timestamp('2011-01-01'),
1159+
pd.Timestamp('2011-01-03')]
1160+
1161+
for tz in ['UTC', 'US/Eastern']:
1162+
# to test tz => different tz replacement
1163+
key = 'datetime64[ns, {0}]'.format(tz)
1164+
self.rep[key] = [pd.Timestamp('2011-01-01', tz=tz),
1165+
pd.Timestamp('2011-01-03', tz=tz)]
1166+
1167+
self.rep['timedelta64[ns]'] = [pd.Timedelta('1 day'),
1168+
pd.Timedelta('2 day')]
11581169

11591170
def _assert_replace_conversion(self, from_key, to_key, how):
11601171
index = pd.Index([3, 4], name='xxx')
11611172
obj = pd.Series(self.rep[from_key], index=index, name='yyy')
11621173
self.assertEqual(obj.dtype, from_key)
11631174

1175+
if (from_key.startswith('datetime') and to_key.startswith('datetime')):
1176+
# different tz, currently mask_missing raises SystemError
1177+
return
1178+
11641179
if how == 'dict':
11651180
replacer = dict(zip(self.rep[from_key], self.rep[to_key]))
11661181
elif how == 'series':
@@ -1177,17 +1192,10 @@ def _assert_replace_conversion(self, from_key, to_key, how):
11771192
raise nose.SkipTest("windows platform buggy: {0} -> {1}".format
11781193
(from_key, to_key))
11791194

1180-
if ((from_key == 'float64' and
1181-
to_key in ('bool', 'int64')) or
1182-
1195+
if ((from_key == 'float64' and to_key in ('bool', 'int64')) or
11831196
(from_key == 'complex128' and
11841197
to_key in ('bool', 'int64', 'float64')) or
1185-
1186-
(from_key == 'int64' and
1187-
to_key in ('bool')) or
1188-
1189-
# TODO_GH12747 The result must be int?
1190-
(from_key == 'bool' and to_key == 'int64')):
1198+
(from_key == 'int64' and to_key in ('bool'))):
11911199

11921200
# buggy on 32-bit
11931201
if tm.is_platform_32bit():
@@ -1250,13 +1258,31 @@ def test_replace_series_bool(self):
12501258
self._assert_replace_conversion(from_key, to_key, how='series')
12511259

12521260
def test_replace_series_datetime64(self):
1253-
pass
1261+
from_key = 'datetime64[ns]'
1262+
for to_key in self.rep:
1263+
self._assert_replace_conversion(from_key, to_key, how='dict')
1264+
1265+
from_key = 'datetime64[ns]'
1266+
for to_key in self.rep:
1267+
self._assert_replace_conversion(from_key, to_key, how='series')
12541268

12551269
def test_replace_series_datetime64tz(self):
1256-
pass
1270+
from_key = 'datetime64[ns, US/Eastern]'
1271+
for to_key in self.rep:
1272+
self._assert_replace_conversion(from_key, to_key, how='dict')
1273+
1274+
from_key = 'datetime64[ns, US/Eastern]'
1275+
for to_key in self.rep:
1276+
self._assert_replace_conversion(from_key, to_key, how='series')
12571277

12581278
def test_replace_series_timedelta64(self):
1259-
pass
1279+
from_key = 'timedelta64[ns]'
1280+
for to_key in self.rep:
1281+
self._assert_replace_conversion(from_key, to_key, how='dict')
1282+
1283+
from_key = 'timedelta64[ns]'
1284+
for to_key in self.rep:
1285+
self._assert_replace_conversion(from_key, to_key, how='series')
12601286

12611287
def test_replace_series_period(self):
12621288
pass

pandas/tests/series/test_replace.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -134,8 +134,8 @@ def check_replace(to_rep, val, expected):
134134
tm.assert_series_equal(expected, r)
135135
tm.assert_series_equal(expected, sc)
136136

137-
# should NOT upcast to float
138-
e = pd.Series([0, 1, 2, 3, 4])
137+
# MUST upcast to float
138+
e = pd.Series([0., 1., 2., 3., 4.])
139139
tr, v = [3], [3.0]
140140
check_replace(tr, v, e)
141141

pandas/types/cast.py

+35-9
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
_ensure_int8, _ensure_int16,
2020
_ensure_int32, _ensure_int64,
2121
_NS_DTYPE, _TD_DTYPE, _INT64_DTYPE,
22-
_POSSIBLY_CAST_DTYPES)
23-
from .dtypes import ExtensionDtype
22+
_DATELIKE_DTYPES, _POSSIBLY_CAST_DTYPES)
23+
from .dtypes import ExtensionDtype, DatetimeTZDtype, PeriodDtype
2424
from .generic import ABCDatetimeIndex, ABCPeriodIndex, ABCSeries
2525
from .missing import isnull, notnull
2626
from .inference import is_list_like
@@ -310,8 +310,17 @@ def _maybe_promote(dtype, fill_value=np.nan):
310310
return dtype, fill_value
311311

312312

313-
def _infer_dtype_from_scalar(val):
314-
""" interpret the dtype from a scalar """
313+
def _infer_dtype_from_scalar(val, pandas_dtype=False):
314+
"""
315+
interpret the dtype from a scalar
316+
317+
Parameters
318+
----------
319+
pandas_dtype : bool, default False
320+
whether to infer dtype including pandas extension types.
321+
If False, scalar belongs to pandas extension types is inferred as
322+
object
323+
"""
315324

316325
dtype = np.object_
317326

@@ -334,13 +343,23 @@ def _infer_dtype_from_scalar(val):
334343

335344
dtype = np.object_
336345

337-
elif isinstance(val, (np.datetime64,
338-
datetime)) and getattr(val, 'tzinfo', None) is None:
339-
val = lib.Timestamp(val).value
340-
dtype = np.dtype('M8[ns]')
346+
elif isinstance(val, (np.datetime64, datetime)):
347+
val = tslib.Timestamp(val)
348+
if val is tslib.NaT or val.tz is None:
349+
dtype = np.dtype('M8[ns]')
350+
else:
351+
if pandas_dtype:
352+
dtype = DatetimeTZDtype(unit='ns', tz=val.tz)
353+
# ToDo: This localization is not needed if
354+
# DatetimeTZBlock doesn't localize internal values
355+
val = val.tz_localize(None)
356+
else:
357+
# return datetimetz as object
358+
return np.object_, val
359+
val = val.value
341360

342361
elif isinstance(val, (np.timedelta64, timedelta)):
343-
val = lib.Timedelta(val).value
362+
val = tslib.Timedelta(val).value
344363
dtype = np.dtype('m8[ns]')
345364

346365
elif is_bool(val):
@@ -361,6 +380,13 @@ def _infer_dtype_from_scalar(val):
361380
elif is_complex(val):
362381
dtype = np.complex_
363382

383+
elif pandas_dtype:
384+
# to do use util
385+
from pandas.tseries.period import Period
386+
if isinstance(val, Period):
387+
dtype = PeriodDtype(freq=val.freq)
388+
val = val.ordinal
389+
364390
return dtype, val
365391

366392

0 commit comments

Comments
 (0)