Skip to content

Commit 9e2e9f2

Browse files
committed
MAINT: split bincount
1 parent b9a8ab5 commit 9e2e9f2

File tree

2 files changed

+10
-4
lines changed

2 files changed

+10
-4
lines changed

torch_np/_detail/implementations.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,3 +323,12 @@ def meshgrid(*xi_tensors, copy=True, sparse=False, indexing="xy"):
323323
output = [x.clone() for x in output]
324324

325325
return output
326+
327+
328+
329+
def bincount(x_tensor, /, weights_tensor=None, minlength=0):
330+
int_dtype = _dtypes_impl.default_int_dtype
331+
(x_tensor,) = _util.cast_dont_broadcast((x_tensor,), int_dtype, casting="safe")
332+
333+
result = torch.bincount(x_tensor, weights_tensor, minlength)
334+
return result

torch_np/_wrapper.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -498,10 +498,7 @@ def bincount(x, /, weights=None, minlength=0):
498498
x = asarray([], dtype=int)
499499

500500
x_tensor, weights_tensor = _helpers.to_tensors_or_none(x, weights)
501-
int_dtype = _dtypes_impl.default_int_dtype
502-
(x_tensor,) = _util.cast_dont_broadcast((x_tensor,), int_dtype, casting="safe")
503-
504-
result = torch.bincount(x_tensor, weights_tensor, minlength)
501+
result = _impl.bincount(x_tensor, weights_tensor, minlength)
505502
return asarray(result)
506503

507504

0 commit comments

Comments
 (0)