diff --git a/src/array_api_extra/_funcs.py b/src/array_api_extra/_funcs.py index 7a9ba40e..3d961e2e 100644 --- a/src/array_api_extra/_funcs.py +++ b/src/array_api_extra/_funcs.py @@ -543,6 +543,8 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array: raise ValueError(err_msg) # no scalars in `where` - array-api#807 y = xp.pi * xp.where( - x, x, xp.asarray(xp.finfo(x.dtype).eps, dtype=x.dtype, device=x.device) + xp.astype(x, xp.bool), + x, + xp.asarray(xp.finfo(x.dtype).eps, dtype=x.dtype, device=x.device), ) return xp.sin(y) / y