Skip to content

Commit d70e9aa

Browse files
committed
MAINT: address review comments
1 parent f26148f commit d70e9aa

File tree

3 files changed

+8
-4
lines changed

3 files changed

+8
-4
lines changed

torch_np/_detail/_reductions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ def var(tensor, axis=None, dtype=None, ddof=0, *, where=NoValue):
173173
# 2. axis=None ravels (cf concatenate)
174174

175175

176-
def cumprod(tensor, axis=None, dtype=None):
176+
def cumprod(tensor, axis, dtype=None):
177177
if dtype == torch.bool:
178178
dtype = _scalar_types.default_int_type.dtype
179179
if dtype is None:
@@ -184,7 +184,7 @@ def cumprod(tensor, axis=None, dtype=None):
184184
return result
185185

186186

187-
def cumsum(tensor, axis=None, dtype=None):
187+
def cumsum(tensor, axis, dtype=None):
188188
if dtype == torch.bool:
189189
dtype = _scalar_types.default_int_type.dtype
190190
if dtype is None:

torch_np/_detail/_util.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,8 @@ def expand_shape(arr_shape, axis):
111111
def apply_keepdims(tensor, axis, ndim):
112112
if axis is None:
113113
# tensor was a scalar
114-
tensor = torch.full((1,) * ndim, fill_value=tensor.item(), dtype=tensor.dtype)
114+
shape = (1,) * ndim
115+
tensor = tensor.expand(shape).contiguous() # avoid CUDA synchronization
115116
else:
116117
shape = expand_shape(tensor.shape, axis)
117118
tensor = tensor.reshape(shape)

torch_np/_helpers.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,10 @@ def result_or_out(result_tensor, out_array=None, promote_scalar=False):
5656
if promote_scalar and can_fit:
5757
result_tensor = result_tensor.squeeze()
5858
else:
59-
raise ValueError("Bad size of the out array.")
59+
raise ValueError(
60+
f"Bad size of the out array: out.shape = {out.shape} "
61+
f"while result.shape = {result_tensor.shape}."
62+
)
6063
out_tensor = out_array.get()
6164
out_tensor.copy_(result_tensor)
6265
return out_array

0 commit comments

Comments
 (0)