Skip to content

Commit e7e1fb5

Browse files
committed
Fix torch std() with integral float correction
Fixes #24
1 parent 78dc3e5 commit e7e1fb5

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

array_api_compat/torch/_aliases.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -361,8 +361,10 @@ def std(x: array,
361361
# https://github.com/pytorch/pytorch/issues/61492. We don't try to
362362
# implement it here for now.
363363

364-
# if isinstance(correction, float):
365-
# correction = int(correction)
364+
if isinstance(correction, float):
365+
_correction = int(correction)
366+
if correction != _correction:
367+
raise NotImplementedError("float correction in torch std() is not yet supported")
366368

367369
# https://github.com/pytorch/pytorch/issues/29137
368370
if axis == ():
@@ -372,10 +374,10 @@ def std(x: array,
372374
if axis is None:
373375
# torch doesn't support keepdims with axis=None
374376
# (https://github.com/pytorch/pytorch/issues/71209)
375-
res = torch.std(x, tuple(range(x.ndim)), correction=correction, **kwargs)
377+
res = torch.std(x, tuple(range(x.ndim)), correction=_correction, **kwargs)
376378
res = _axis_none_keepdims(res, x.ndim, keepdims)
377379
return res
378-
return torch.std(x, axis, correction=correction, keepdims=keepdims, **kwargs)
380+
return torch.std(x, axis, correction=_correction, keepdims=keepdims, **kwargs)
379381

380382
def var(x: array,
381383
/,

0 commit comments

Comments
 (0)