Skip to content

Commit 476acdb

Browse files
committed
BUG: sum() of a bool array is integer
1 parent 69be71f commit 476acdb

File tree

3 files changed

+19
-6
lines changed

3 files changed

+19
-6
lines changed

torch_np/_helpers.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -133,11 +133,14 @@ def to_tensors(*inputs):
133133
for value in inputs])
134134

135135

136-
def float_or_default(dtype, self_dtype):
136+
def float_or_default(dtype, self_dtype, enforce_float=False):
137137
"""dtype helper for reductions."""
138138
if dtype is None:
139139
dtype = self_dtype
140-
if _dtypes.is_integer(dtype):
141-
dtype = _dtypes.default_float_type()
140+
if dtype == _dtypes.dtype('bool'):
141+
dtype = _dtypes.default_int_type()
142+
if enforce_float:
143+
if _dtypes.is_integer(dtype):
144+
dtype = _dtypes.default_float_type()
142145
torch_dtype = _dtypes.torch_dtype_from(dtype)
143146
return torch_dtype

torch_np/_ndarray.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -416,7 +416,7 @@ def mean(self, axis=None, dtype=None, out=None, keepdims=NoValue, *, where=NoVal
416416
if where is not None:
417417
raise NotImplementedError
418418

419-
torch_dtype = _helpers.float_or_default(dtype, self.dtype)
419+
torch_dtype = _helpers.float_or_default(dtype, self.dtype, enforce_float=True)
420420
if axis is None:
421421
result = self._tensor.mean(dtype=torch_dtype)
422422
else:
@@ -462,7 +462,7 @@ def std(self, axis=None, dtype=None, out=None, ddof=0, keepdims=NoValue, *,
462462
if where is not None:
463463
raise NotImplementedError
464464

465-
torch_dtype = _helpers.float_or_default(dtype, self.dtype)
465+
torch_dtype = _helpers.float_or_default(dtype, self.dtype, enforce_float=True)
466466
tensor = self._tensor.to(torch_dtype) # XXX: needed?
467467

468468
result = tensor.std(dim=axis, correction=ddof)
@@ -475,7 +475,7 @@ def var(self, axis=None, dtype=None, out=None, ddof=0, keepdims=NoValue, *,
475475
if where is not None:
476476
raise NotImplementedError
477477

478-
torch_dtype = _helpers.float_or_default(dtype, self.dtype)
478+
torch_dtype = _helpers.float_or_default(dtype, self.dtype, enforce_float=True)
479479
tensor = self._tensor.to(torch_dtype) # XXX: needed?
480480

481481
result = tensor.var(dim=axis, correction=ddof)

torch_np/tests/test_reductions.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,16 @@ def test_sum_stability(self):
328328
assert_allclose((a / 10.).sum() - a.size / 10., 0., atol=1.5e-13,
329329
check_dtype=False)
330330

331+
def test_sum_boolean(self):
332+
a = (np.arange(7) % 2 == 0)
333+
res = a.sum()
334+
assert_equal(res, 4)
335+
336+
res_float = a.sum(dtype=np.float64)
337+
assert_allclose(res_float, 4.0, atol=1e-15)
338+
assert res_float.dtype == 'float64'
339+
340+
331341
@pytest.mark.xfail(reason="dtype(value) needs implementing")
332342
def test_sum_dtypes(self):
333343
for dt in (int, np.float16, np.float32, np.float64):

0 commit comments

Comments
 (0)