Skip to content

Commit 081d2e9

Browse files
committed
BUG: replace coerces incorrect dtype
1 parent f26b049 commit 081d2e9

File tree

5 files changed

+95
-25
lines changed

5 files changed

+95
-25
lines changed

pandas/core/internals.py

+17-3
Original file line numberDiff line numberDiff line change
@@ -1873,8 +1873,11 @@ def convert(self, *args, **kwargs):
18731873
blocks.append(newb)
18741874

18751875
else:
1876-
values = fn(
1877-
self.values.ravel(), **fn_kwargs).reshape(self.values.shape)
1876+
values = fn(self.values.ravel(), **fn_kwargs)
1877+
try:
1878+
values = values.reshape(self.values.shape)
1879+
except NotImplementedError:
1880+
pass
18781881
blocks.append(make_block(values, ndim=self.ndim,
18791882
placement=self.mgr_locs))
18801883

@@ -3211,6 +3214,16 @@ def comp(s):
32113214
return _possibly_compare(values, getattr(s, 'asm8', s),
32123215
operator.eq)
32133216

3217+
def _cast(block, scalar):
3218+
dtype, val = _infer_dtype_from_scalar(scalar, pandas_dtype=True)
3219+
if not is_dtype_equal(block.dtype, dtype):
3220+
dtype = _find_common_type([block.dtype, dtype])
3221+
block = block.astype(dtype)
3222+
# use original value
3223+
val = scalar
3224+
3225+
return block, val
3226+
32143227
masks = [comp(s) for i, s in enumerate(src_list)]
32153228

32163229
result_blocks = []
@@ -3231,7 +3244,8 @@ def comp(s):
32313244
# particular block
32323245
m = masks[i][b.mgr_locs.indexer]
32333246
if m.any():
3234-
new_rb.extend(b.putmask(m, d, inplace=True))
3247+
b, val = _cast(b, d)
3248+
new_rb.extend(b.putmask(m, val, inplace=True))
32353249
else:
32363250
new_rb.append(b)
32373251
rb = new_rb

pandas/core/missing.py

+4
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ def mask_missing(arr, values_to_mask):
3939
# numpy elementwise comparison warning
4040
if is_numeric_v_string_like(arr, x):
4141
mask = False
42+
# elif is_object_dtype(arr):
43+
# mask = lib.scalar_compare(arr, x, operator.eq)
4244
else:
4345
mask = arr == x
4446

@@ -51,6 +53,8 @@ def mask_missing(arr, values_to_mask):
5153
# numpy elementwise comparison warning
5254
if is_numeric_v_string_like(arr, x):
5355
mask |= False
56+
# elif is_object_dtype(arr):
57+
# mask |= lib.scalar_compare(arr, x, operator.eq)
5458
else:
5559
mask |= arr == x
5660

pandas/tests/indexing/test_coercion.py

+38-12
Original file line numberDiff line numberDiff line change
@@ -1155,12 +1155,27 @@ def setUp(self):
11551155
self.rep['float64'] = [1.1, 2.2]
11561156
self.rep['complex128'] = [1 + 1j, 2 + 2j]
11571157
self.rep['bool'] = [True, False]
1158+
self.rep['datetime64[ns]'] = [pd.Timestamp('2011-01-01'),
1159+
pd.Timestamp('2011-01-03')]
1160+
1161+
for tz in ['UTC', 'US/Eastern']:
1162+
# to test tz => different tz replacement
1163+
key = 'datetime64[ns, {0}]'.format(tz)
1164+
self.rep[key] = [pd.Timestamp('2011-01-01', tz=tz),
1165+
pd.Timestamp('2011-01-03', tz=tz)]
1166+
1167+
self.rep['timedelta64[ns]'] = [pd.Timedelta('1 day'),
1168+
pd.Timedelta('2 day')]
11581169

11591170
def _assert_replace_conversion(self, from_key, to_key, how):
11601171
index = pd.Index([3, 4], name='xxx')
11611172
obj = pd.Series(self.rep[from_key], index=index, name='yyy')
11621173
self.assertEqual(obj.dtype, from_key)
11631174

1175+
if (from_key.startswith('datetime') and to_key.startswith('datetime')):
1176+
# different tz, currently mask_missing raises SystemError
1177+
return
1178+
11641179
if how == 'dict':
11651180
replacer = dict(zip(self.rep[from_key], self.rep[to_key]))
11661181
elif how == 'series':
@@ -1177,17 +1192,10 @@ def _assert_replace_conversion(self, from_key, to_key, how):
11771192
raise nose.SkipTest("windows platform buggy: {0} -> {1}".format
11781193
(from_key, to_key))
11791194

1180-
if ((from_key == 'float64' and
1181-
to_key in ('bool', 'int64')) or
1182-
1195+
if ((from_key == 'float64' and to_key in ('bool', 'int64')) or
11831196
(from_key == 'complex128' and
11841197
to_key in ('bool', 'int64', 'float64')) or
1185-
1186-
(from_key == 'int64' and
1187-
to_key in ('bool')) or
1188-
1189-
# TODO_GH12747 The result must be int?
1190-
(from_key == 'bool' and to_key == 'int64')):
1198+
(from_key == 'int64' and to_key in ('bool'))):
11911199

11921200
# buggy on 32-bit
11931201
if tm.is_platform_32bit():
@@ -1250,13 +1258,31 @@ def test_replace_series_bool(self):
12501258
self._assert_replace_conversion(from_key, to_key, how='series')
12511259

12521260
def test_replace_series_datetime64(self):
1253-
pass
1261+
from_key = 'datetime64[ns]'
1262+
for to_key in self.rep:
1263+
self._assert_replace_conversion(from_key, to_key, how='dict')
1264+
1265+
from_key = 'datetime64[ns]'
1266+
for to_key in self.rep:
1267+
self._assert_replace_conversion(from_key, to_key, how='series')
12541268

12551269
def test_replace_series_datetime64tz(self):
1256-
pass
1270+
from_key = 'datetime64[ns, US/Eastern]'
1271+
for to_key in self.rep:
1272+
self._assert_replace_conversion(from_key, to_key, how='dict')
1273+
1274+
from_key = 'datetime64[ns, US/Eastern]'
1275+
for to_key in self.rep:
1276+
self._assert_replace_conversion(from_key, to_key, how='series')
12571277

12581278
def test_replace_series_timedelta64(self):
1259-
pass
1279+
from_key = 'timedelta64[ns]'
1280+
for to_key in self.rep:
1281+
self._assert_replace_conversion(from_key, to_key, how='dict')
1282+
1283+
from_key = 'timedelta64[ns]'
1284+
for to_key in self.rep:
1285+
self._assert_replace_conversion(from_key, to_key, how='series')
12601286

12611287
def test_replace_series_period(self):
12621288
pass

pandas/tests/series/test_replace.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -134,8 +134,8 @@ def check_replace(to_rep, val, expected):
134134
tm.assert_series_equal(expected, r)
135135
tm.assert_series_equal(expected, sc)
136136

137-
# should NOT upcast to float
138-
e = pd.Series([0, 1, 2, 3, 4])
137+
# MUST upcast to float
138+
e = pd.Series([0., 1., 2., 3., 4.])
139139
tr, v = [3], [3.0]
140140
check_replace(tr, v, e)
141141

pandas/types/cast.py

+34-8
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
_ensure_int32, _ensure_int64,
2020
_NS_DTYPE, _TD_DTYPE, _INT64_DTYPE,
2121
_DATELIKE_DTYPES, _POSSIBLY_CAST_DTYPES)
22-
from .dtypes import ExtensionDtype
22+
from .dtypes import ExtensionDtype, DatetimeTZDtype, PeriodDtype
2323
from .generic import ABCDatetimeIndex, ABCPeriodIndex, ABCSeries
2424
from .missing import isnull, notnull
2525
from .inference import is_list_like
@@ -309,8 +309,17 @@ def _maybe_promote(dtype, fill_value=np.nan):
309309
return dtype, fill_value
310310

311311

312-
def _infer_dtype_from_scalar(val):
313-
""" interpret the dtype from a scalar """
312+
def _infer_dtype_from_scalar(val, pandas_dtype=False):
313+
"""
314+
interpret the dtype from a scalar
315+
316+
Parameters
317+
----------
318+
pandas_dtype : bool, default False
319+
whether to infer dtype including pandas extension types.
320+
If False, scalar belongs to pandas extension types is inferred as
321+
object
322+
"""
314323

315324
dtype = np.object_
316325

@@ -333,13 +342,23 @@ def _infer_dtype_from_scalar(val):
333342

334343
dtype = np.object_
335344

336-
elif isinstance(val, (np.datetime64,
337-
datetime)) and getattr(val, 'tzinfo', None) is None:
338-
val = lib.Timestamp(val).value
339-
dtype = np.dtype('M8[ns]')
345+
elif isinstance(val, (np.datetime64, datetime)):
346+
val = tslib.Timestamp(val)
347+
if val is tslib.NaT or val.tz is None:
348+
dtype = np.dtype('M8[ns]')
349+
else:
350+
if pandas_dtype:
351+
dtype = DatetimeTZDtype(unit='ns', tz=val.tz)
352+
# ToDo: This localization is not needed if
353+
# DatetimeTZBlock doesn't localize internal values
354+
val = val.tz_localize(None)
355+
else:
356+
# return datetimetz as object
357+
return np.object_, val
358+
val = val.value
340359

341360
elif isinstance(val, (np.timedelta64, timedelta)):
342-
val = lib.Timedelta(val).value
361+
val = tslib.Timedelta(val).value
343362
dtype = np.dtype('m8[ns]')
344363

345364
elif is_bool(val):
@@ -360,6 +379,13 @@ def _infer_dtype_from_scalar(val):
360379
elif is_complex(val):
361380
dtype = np.complex_
362381

382+
elif pandas_dtype:
383+
# to do use util
384+
from pandas.tseries.period import Period
385+
if isinstance(val, Period):
386+
dtype = PeriodDtype(freq=val.freq)
387+
val = val.ordinal
388+
363389
return dtype, val
364390

365391

0 commit comments

Comments
 (0)