Skip to content

Commit 2d1a6ea

Browse files
committed
Merge pull request #5177 from jreback/infer
API: make allclose comparison on dtype downcasting (GH5174)
2 parents d147881 + ff93c3a commit 2d1a6ea

File tree

3 files changed

+41
-9
lines changed

3 files changed

+41
-9
lines changed

pandas/core/common.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -1053,7 +1053,7 @@ def _possibly_downcast_to_dtype(result, dtype):
10531053
# do a test on the first element, if it fails then we are done
10541054
r = result.ravel()
10551055
arr = np.array([ r[0] ])
1056-
if (arr != arr.astype(dtype)).item():
1056+
if not np.allclose(arr,arr.astype(dtype)):
10571057
return result
10581058

10591059
# a comparable, e.g. a Decimal may slip in here
@@ -1062,8 +1062,14 @@ def _possibly_downcast_to_dtype(result, dtype):
10621062

10631063
if issubclass(result.dtype.type, (np.object_,np.number)) and notnull(result).all():
10641064
new_result = result.astype(dtype)
1065-
if (new_result == result).all():
1066-
return new_result
1065+
try:
1066+
if np.allclose(new_result,result):
1067+
return new_result
1068+
except:
1069+
1070+
# comparison of an object dtype with a number type could hit here
1071+
if (new_result == result).all():
1072+
return new_result
10671073
except:
10681074
pass
10691075

pandas/core/internals.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -376,10 +376,10 @@ def downcast(self, dtypes=None):
376376
dtype = dtypes.get(item, self._downcast_dtype)
377377

378378
if dtype is None:
379-
nv = _block_shape(values[i])
379+
nv = _block_shape(values[i],ndim=self.ndim)
380380
else:
381381
nv = _possibly_downcast_to_dtype(values[i], dtype)
382-
nv = _block_shape(nv)
382+
nv = _block_shape(nv,ndim=self.ndim)
383383

384384
blocks.append(make_block(nv, Index([item]), self.ref_items, ndim=self.ndim, fastpath=True))
385385

pandas/tests/test_generic.py

+30-4
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def _axes(self):
5959
""" return the axes for my object typ """
6060
return self._typ._AXIS_ORDERS
6161

62-
def _construct(self, shape, value=None, **kwargs):
62+
def _construct(self, shape, value=None, dtype=None, **kwargs):
6363
""" construct an object for the given shape
6464
if value is specified use that if its a scalar
6565
if value is an array, repeat it as needed """
@@ -74,7 +74,7 @@ def _construct(self, shape, value=None, **kwargs):
7474
# remove the info axis
7575
kwargs.pop(self._typ._info_axis_name,None)
7676
else:
77-
arr = np.empty(shape)
77+
arr = np.empty(shape,dtype=dtype)
7878
arr.fill(value)
7979
else:
8080
fshape = np.prod(shape)
@@ -184,6 +184,32 @@ def test_numpy_1_7_compat_numeric_methods(self):
184184
if f is not None:
185185
f(o)
186186

187+
def test_downcast(self):
188+
# test close downcasting
189+
190+
o = self._construct(shape=4, value=9, dtype=np.int64)
191+
result = o.copy()
192+
result._data = o._data.downcast(dtypes='infer')
193+
self._compare(result, o)
194+
195+
o = self._construct(shape=4, value=9.)
196+
expected = o.astype(np.int64)
197+
result = o.copy()
198+
result._data = o._data.downcast(dtypes='infer')
199+
self._compare(result, expected)
200+
201+
o = self._construct(shape=4, value=9.5)
202+
result = o.copy()
203+
result._data = o._data.downcast(dtypes='infer')
204+
self._compare(result, o)
205+
206+
# are close
207+
o = self._construct(shape=4, value=9.000000000005)
208+
result = o.copy()
209+
result._data = o._data.downcast(dtypes='infer')
210+
expected = o.astype(np.int64)
211+
self._compare(result, expected)
212+
187213
class TestSeries(unittest.TestCase, Generic):
188214
_typ = Series
189215
_comparator = lambda self, x, y: assert_series_equal(x,y)
@@ -335,7 +361,7 @@ def test_interp_quad(self):
335361
_skip_if_no_scipy()
336362
sq = Series([1, 4, np.nan, 16], index=[1, 2, 3, 4])
337363
result = sq.interpolate(method='quadratic')
338-
expected = Series([1., 4., 9., 16.], index=[1, 2, 3, 4])
364+
expected = Series([1, 4, 9, 16], index=[1, 2, 3, 4])
339365
assert_series_equal(result, expected)
340366

341367
def test_interp_scipy_basic(self):
@@ -589,7 +615,7 @@ def test_spline(self):
589615
_skip_if_no_scipy()
590616
s = Series([1, 2, np.nan, 4, 5, np.nan, 7])
591617
result = s.interpolate(method='spline', order=1)
592-
expected = Series([1., 2, 3, 4, 5, 6, 7]) # dtype?
618+
expected = Series([1, 2, 3, 4, 5, 6, 7])
593619
assert_series_equal(result, expected)
594620

595621

0 commit comments

Comments
 (0)