diff --git a/torch_np/_detail/_reductions.py b/torch_np/_detail/_reductions.py
index 1b99af28..72b356cd 100644
--- a/torch_np/_detail/_reductions.py
+++ b/torch_np/_detail/_reductions.py
@@ -4,17 +4,13 @@
 Anything here only deals with torch objects, e.g. "dtype" is a torch.dtype instance etc
 """
+import functools
 import typing

 import torch

 from . import _dtypes_impl, _util

-NoValue = _util.NoValue
-
-
-import functools
-
 ############# XXX
 ### From _util.axis_expand_func
@@ -51,7 +47,7 @@ def wrapped(tensor, axis, *args, **kwds):
 def emulate_keepdims(func):
     @functools.wraps(func)
-    def wrapped(tensor, axis=None, keepdims=NoValue, *args, **kwds):
+    def wrapped(tensor, axis=None, keepdims=None, *args, **kwds):
         result = func(tensor, axis=axis, *args, **kwds)
         if keepdims:
             result = _util.apply_keepdims(result, axis, tensor.ndim)
@@ -133,10 +129,7 @@ def argmin(tensor, axis=None):
 @emulate_keepdims
 @deco_axis_expand
-def any(tensor, axis=None, *, where=NoValue):
-    if where is not NoValue:
-        raise NotImplementedError
-
+def any(tensor, axis=None, *, where=None):
     axis = _util.allow_only_single_axis(axis)

     if axis is None:
@@ -148,10 +141,7 @@ def any(tensor, axis=None, *, where=NoValue):
 @emulate_keepdims
 @deco_axis_expand
-def all(tensor, axis=None, *, where=NoValue):
-    if where is not NoValue:
-        raise NotImplementedError
-
+def all(tensor, axis=None, *, where=None):
     axis = _util.allow_only_single_axis(axis)

     if axis is None:
@@ -163,37 +153,25 @@ def all(tensor, axis=None, *, where=NoValue):
 @emulate_keepdims
 @deco_axis_expand
-def max(tensor, axis=None, initial=NoValue, where=NoValue):
-    if initial is not NoValue or where is not NoValue:
-        raise NotImplementedError
-
-    result = tensor.amax(axis)
-    return result
+def max(tensor, axis=None, initial=None, where=None):
+    return tensor.amax(axis)


 @emulate_keepdims
 @deco_axis_expand
-def min(tensor, axis=None, initial=NoValue, where=NoValue):
-    if initial is not NoValue or where is not NoValue:
-        raise NotImplementedError
-
-    result = tensor.amin(axis)
-    return result
+def min(tensor, axis=None, initial=None, where=None):
+    return tensor.amin(axis)


 @emulate_keepdims
 @deco_axis_expand
 def ptp(tensor, axis=None):
-    result = tensor.amax(axis) - tensor.amin(axis)
-    return result
+    return tensor.amax(axis) - tensor.amin(axis)


 @emulate_keepdims
 @deco_axis_expand
-def sum(tensor, axis=None, dtype=None, initial=NoValue, where=NoValue):
-    if initial is not NoValue or where is not NoValue:
-        raise NotImplementedError
-
+def sum(tensor, axis=None, dtype=None, initial=None, where=None):
     assert dtype is None or isinstance(dtype, torch.dtype)

     if dtype == torch.bool:
@@ -209,10 +187,7 @@ def sum(tensor, axis=None, dtype=None, initial=NoValue, where=NoValue):
 @emulate_keepdims
 @deco_axis_expand
-def prod(tensor, axis=None, dtype=None, initial=NoValue, where=NoValue):
-    if initial is not NoValue or where is not NoValue:
-        raise NotImplementedError
-
+def prod(tensor, axis=None, dtype=None, initial=None, where=None):
     axis = _util.allow_only_single_axis(axis)

     if dtype == torch.bool:
@@ -228,10 +203,7 @@ def prod(tensor, axis=None, dtype=None, initial=NoValue, where=NoValue):
 @emulate_keepdims
 @deco_axis_expand
-def mean(tensor, axis=None, dtype=None, *, where=NoValue):
-    if where is not NoValue:
-        raise NotImplementedError
-
+def mean(tensor, axis=None, dtype=None, *, where=None):
     dtype = _atleast_float(dtype, tensor.dtype)

     is_half = dtype == torch.float16
@@ -252,10 +224,7 @@ def mean(tensor, axis=None, dtype=None, *, where=NoValue):
 @emulate_keepdims
 @deco_axis_expand
-def std(tensor, axis=None, dtype=None, ddof=0, *, where=NoValue):
-    if where is not NoValue:
-        raise NotImplementedError
-
+def std(tensor, axis=None, dtype=None, ddof=0, *, where=None):
     dtype = _atleast_float(dtype, tensor.dtype)
     tensor = _util.cast_if_needed(tensor, dtype)
     result = tensor.std(dim=axis, correction=ddof)
@@ -265,10 +234,7 @@ def std(tensor, axis=None, dtype=None, ddof=0, *, where=NoValue):
 @emulate_keepdims
 @deco_axis_expand
-def var(tensor, axis=None, dtype=None, ddof=0, *, where=NoValue):
-    if where is not NoValue:
-        raise NotImplementedError
-
+def var(tensor, axis=None, dtype=None, ddof=0, *, where=None):
     dtype = _atleast_float(dtype, tensor.dtype)
     tensor = _util.cast_if_needed(tensor, dtype)
     result = tensor.var(dim=axis, correction=ddof)
@@ -387,9 +353,6 @@ def quantile(
         # Here we choose to work out-of-place because why not.
         pass

-    if interpolation is not None:
-        raise ValueError("'interpolation' argument is deprecated; use 'method' instead")
-
     if not a.dtype.is_floating_point:
         dtype = _dtypes_impl.default_float_dtype
         a = a.to(dtype)
diff --git a/torch_np/_detail/_util.py b/torch_np/_detail/_util.py
index 33b99322..8d4c0be2 100644
--- a/torch_np/_detail/_util.py
+++ b/torch_np/_detail/_util.py
@@ -7,7 +7,6 @@
 from . import _dtypes_impl

-NoValue = None

 # https://github.com/numpy/numpy/blob/v1.23.0/numpy/distutils/misc_util.py#L497-L504
 def is_sequence(seq):
diff --git a/torch_np/_funcs_impl.py b/torch_np/_funcs_impl.py
index cdf7e7b3..decb649b 100644
--- a/torch_np/_funcs_impl.py
+++ b/torch_np/_funcs_impl.py
@@ -23,26 +23,26 @@
     AxisLike,
     DTypeLike,
     NDArray,
+    NotImplementedType,
     OutArray,
-    SubokLike,
     normalize_array_like,
 )

-NoValue = _util.NoValue
-
-
 ###### array creation routines


-def copy(a: ArrayLike, order="K", subok: SubokLike = False):
-    if order != "K":
-        raise NotImplementedError
+def copy(
+    a: ArrayLike, order: NotImplementedType = "K", subok: NotImplementedType = False
+):
     return a.clone()


-def copyto(dst: NDArray, src: ArrayLike, casting="same_kind", where=NoValue):
-    if where is not NoValue:
-        raise NotImplementedError
+def copyto(
+    dst: NDArray,
+    src: ArrayLike,
+    casting="same_kind",
+    where: NotImplementedType = None,
+):
     (src,) = _util.typecast_tensors((src,), dst.dtype, casting=casting)
     dst.copy_(src)
@@ -323,7 +323,7 @@ def arange(
     step: Optional[ArrayLike] = 1,
     dtype: DTypeLike = None,
     *,
-    like: SubokLike = None,
+    like: NotImplementedType = None,
 ):
     if step == 0:
         raise ZeroDivisionError
@@ -365,9 +365,13 @@ def arange(
 # ### zeros/ones/empty/full ###


-def empty(shape, dtype: DTypeLike = float, order="C", *, like: SubokLike = None):
-    if order != "C":
-        raise NotImplementedError
+def empty(
+    shape,
+    dtype: DTypeLike = float,
+    order: NotImplementedType = "C",
+    *,
+    like: NotImplementedType = None,
+):
     if dtype is None:
         dtype = _dtypes_impl.default_float_dtype
     return torch.empty(shape, dtype=dtype)
@@ -380,12 +384,10 @@ def empty(shape, dtype: DTypeLike = float, order="C", *, like: SubokLike = None)
 def empty_like(
     prototype: ArrayLike,
     dtype: DTypeLike = None,
-    order="K",
-    subok: SubokLike = False,
+    order: NotImplementedType = "K",
+    subok: NotImplementedType = False,
     shape=None,
 ):
-    if order != "K":
-        raise NotImplementedError
     result = torch.empty_like(prototype, dtype=dtype)
     if shape is not None:
         result = result.reshape(shape)
@@ -396,14 +398,12 @@ def full(
     shape,
     fill_value: ArrayLike,
     dtype: DTypeLike = None,
-    order="C",
+    order: NotImplementedType = "C",
     *,
-    like: SubokLike = None,
+    like: NotImplementedType = None,
 ):
     if isinstance(shape, int):
         shape = (shape,)
-    if order != "C":
-        raise NotImplementedError
     if dtype is None:
         dtype = fill_value.dtype
     if not isinstance(shape, (tuple, list)):
@@ -415,12 +415,10 @@ def full_like(
     a: ArrayLike,
     fill_value,
     dtype: DTypeLike = None,
-    order="K",
-    subok: SubokLike = False,
+    order: NotImplementedType = "K",
+    subok: NotImplementedType = False,
     shape=None,
 ):
-    if order != "K":
-        raise NotImplementedError
     # XXX: fill_value broadcasts
     result = torch.full_like(a, fill_value, dtype=dtype)
     if shape is not None:
@@ -428,9 +426,13 @@ def full_like(
     return result


-def ones(shape, dtype: DTypeLike = None, order="C", *, like: SubokLike = None):
-    if order != "C":
-        raise NotImplementedError
+def ones(
+    shape,
+    dtype: DTypeLike = None,
+    order: NotImplementedType = "C",
+    *,
+    like: NotImplementedType = None,
+):
     if dtype is None:
         dtype = _dtypes_impl.default_float_dtype
     return torch.ones(shape, dtype=dtype)
@@ -439,21 +441,23 @@ def ones(shape, dtype: DTypeLike = None, order="C", *, like: SubokLike = None):
 def ones_like(
     a: ArrayLike,
     dtype: DTypeLike = None,
-    order="K",
-    subok: SubokLike = False,
+    order: NotImplementedType = "K",
+    subok: NotImplementedType = False,
     shape=None,
 ):
-    if order != "K":
-        raise NotImplementedError
     result = torch.ones_like(a, dtype=dtype)
     if shape is not None:
         result = result.reshape(shape)
     return result


-def zeros(shape, dtype: DTypeLike = None, order="C", *, like: SubokLike = None):
-    if order != "C":
-        raise NotImplementedError
+def zeros(
+    shape,
+    dtype: DTypeLike = None,
+    order: NotImplementedType = "C",
+    *,
+    like: NotImplementedType = None,
+):
     if dtype is None:
         dtype = _dtypes_impl.default_float_dtype
     return torch.zeros(shape, dtype=dtype)
@@ -462,12 +466,10 @@ def zeros(shape, dtype: DTypeLike = None, order="C", *, like: SubokLike = None):
 def zeros_like(
     a: ArrayLike,
     dtype: DTypeLike = None,
-    order="K",
-    subok: SubokLike = False,
+    order: NotImplementedType = "K",
+    subok: NotImplementedType = False,
     shape=None,
 ):
-    if order != "K":
-        raise NotImplementedError
     result = torch.zeros_like(a, dtype=dtype)
     if shape is not None:
         result = result.reshape(shape)
     return result
@@ -506,8 +508,8 @@ def corrcoef(
     x: ArrayLike,
     y: Optional[ArrayLike] = None,
     rowvar=True,
-    bias=NoValue,
-    ddof=NoValue,
+    bias=None,
+    ddof=None,
     *,
     dtype: DTypeLike = None,
 ):
@@ -648,14 +650,14 @@ def rot90(m: ArrayLike, k=1, axes=(0, 1)):
 # ### broadcasting and indices ###


-def broadcast_to(array: ArrayLike, shape, subok: SubokLike = False):
+def broadcast_to(array: ArrayLike, shape, subok: NotImplementedType = False):
     return torch.broadcast_to(array, size=shape)


 from torch import broadcast_shapes


-def broadcast_arrays(*args: ArrayLike, subok: SubokLike = False):
+def broadcast_arrays(*args: ArrayLike, subok: NotImplementedType = False):
     return torch.broadcast_tensors(*args)
@@ -741,7 +743,7 @@ def triu_indices_from(arr: ArrayLike, k=0):
     return tuple(result)


-def tri(N, M=None, k=0, dtype: DTypeLike = float, *, like: SubokLike = None):
+def tri(N, M=None, k=0, dtype: DTypeLike = float, *, like: NotImplementedType = None):
     if M is None:
         M = N
     tensor = torch.ones((N, M), dtype=dtype)
@@ -757,13 +759,11 @@ def nanmean(
     axis=None,
     dtype: DTypeLike = None,
     out: Optional[OutArray] = None,
-    keepdims=NoValue,
+    keepdims=None,
     *,
-    where=NoValue,
+    where: NotImplementedType = None,
 ):
     # XXX: this needs to be rewritten
-    if where is not NoValue:
-        raise NotImplementedError
     if dtype is None:
         dtype = a.dtype
     if axis is None:
@@ -895,11 +895,8 @@ def take(
     indices: ArrayLike,
     axis=None,
     out: Optional[OutArray] = None,
-    mode="raise",
+    mode: NotImplementedType = "raise",
 ):
-    if mode != "raise":
-        raise NotImplementedError(f"{mode=}")
-
     (a,), axis = _util.axis_none_ravel(a, axis=axis)
     axis = _util.normalize_axis_index(axis, a.ndim)
     idx = (slice(None),) * axis + (indices, ...)
@@ -929,16 +926,13 @@ def put_along_axis(arr: ArrayLike, indices: ArrayLike, values: ArrayLike, axis):
 def unique(
     ar: ArrayLike,
-    return_index=False,
+    return_index: NotImplementedType = False,
     return_inverse=False,
     return_counts=False,
     axis=None,
     *,
-    equal_nan=True,
+    equal_nan: NotImplementedType = True,
 ):
-    if return_index or not equal_nan:
-        raise NotImplementedError
-
     if axis is None:
         ar = ar.ravel()
         axis = 0
@@ -1078,9 +1072,15 @@ def trace(
     return result


-def eye(N, M=None, k=0, dtype: DTypeLike = float, order="C", *, like: SubokLike = None):
-    if order != "C":
-        raise NotImplementedError
+def eye(
+    N,
+    M=None,
+    k=0,
+    dtype: DTypeLike = float,
+    order: NotImplementedType = "C",
+    *,
+    like: NotImplementedType = None,
+):
     if M is None:
         M = N
     z = torch.zeros(N, M, dtype=dtype)
@@ -1088,7 +1088,7 @@ def eye(N, M=None, k=0, dtype: DTypeLike = float, order="C", *, like: SubokLike
     return z


-def identity(n, dtype: DTypeLike = None, *, like: SubokLike = None):
+def identity(n, dtype: DTypeLike = None, *, like: NotImplementedType = None):
     return torch.eye(n, dtype=dtype)
@@ -1225,12 +1225,6 @@ def outer(a: ArrayLike, b: ArrayLike, out: Optional[OutArray] = None):
 def _sort_helper(tensor, axis, kind, order):
-    if order is not None:
-        # only relevant for structured dtypes; not supported
-        raise NotImplementedError(
-            "'order' keyword is only relevant for structured dtypes"
-        )
-
     (tensor,), axis = _util.axis_none_ravel(tensor, axis=axis)
     axis = _util.normalize_axis_index(axis, tensor.ndim)

@@ -1239,13 +1233,14 @@ def _sort_helper(tensor, axis, kind, order):
     return tensor, axis, stable


-def sort(a: ArrayLike, axis=-1, kind=None, order=None):
+def sort(a: ArrayLike, axis=-1, kind=None, order: NotImplementedType = None):
+    # `order` keyword arg is only relevant for structured dtypes; so not supported here.
     a, axis, stable = _sort_helper(a, axis, kind, order)
     result = torch.sort(a, dim=axis, stable=stable)
     return result.values


-def argsort(a: ArrayLike, axis=-1, kind=None, order=None):
+def argsort(a: ArrayLike, axis=-1, kind=None, order: NotImplementedType = None):
     a, axis, stable = _sort_helper(a, axis, kind, order)
     return torch.argsort(a, dim=axis, stable=stable)
@@ -1324,9 +1319,7 @@ def squeeze(a: ArrayLike, axis=None):
     return result


-def reshape(a: ArrayLike, newshape, order="C"):
-    if order != "C":
-        raise NotImplementedError
+def reshape(a: ArrayLike, newshape, order: NotImplementedType = "C"):
     # if sh = (1, 2, 3), numpy allows both .reshape(sh) and .reshape(*sh)
     newshape = newshape[0] if len(newshape) == 1 else newshape
     return a.reshape(newshape)
@@ -1352,18 +1345,14 @@ def transpose(a: ArrayLike, axes=None):
     return result


-def ravel(a: ArrayLike, order="C"):
-    if order != "C":
-        raise NotImplementedError
+def ravel(a: ArrayLike, order: NotImplementedType = "C"):
     return torch.ravel(a)


 # leading underscore since arr.flatten exists but np.flatten does not
-def _flatten(a: ArrayLike, order="C"):
-    if order != "C":
-        raise NotImplementedError
+def _flatten(a: ArrayLike, order: NotImplementedType = "C"):
     # may return a copy
     return torch.flatten(a)
@@ -1411,9 +1400,9 @@ def sum(
     axis: AxisLike = None,
     dtype: DTypeLike = None,
     out: Optional[OutArray] = None,
-    keepdims=NoValue,
-    initial=NoValue,
-    where=NoValue,
+    keepdims=None,
+    initial: NotImplementedType = None,
+    where: NotImplementedType = None,
 ):
     result = _impl.sum(
         a, axis=axis, dtype=dtype, initial=initial, where=where, keepdims=keepdims
@@ -1426,9 +1415,9 @@ def prod(
     axis: AxisLike = None,
     dtype: DTypeLike = None,
     out: Optional[OutArray] = None,
-    keepdims=NoValue,
-    initial=NoValue,
-    where=NoValue,
+    keepdims=None,
+    initial: NotImplementedType = None,
+    where: NotImplementedType = None,
 ):
     result = _impl.prod(
         a, axis=axis, dtype=dtype, initial=initial, where=where, keepdims=keepdims
@@ -1444,11 +1433,11 @@ def mean(
     axis: AxisLike = None,
     dtype: DTypeLike = None,
     out: Optional[OutArray] = None,
-    keepdims=NoValue,
+    keepdims=None,
     *,
-    where=NoValue,
+    where: NotImplementedType = None,
 ):
-    result = _impl.mean(a, axis=axis, dtype=dtype, where=NoValue, keepdims=keepdims)
+    result = _impl.mean(a, axis=axis, dtype=dtype, where=None, keepdims=keepdims)
     return result
@@ -1458,9 +1447,9 @@ def var(
     dtype: DTypeLike = None,
     out: Optional[OutArray] = None,
     ddof=0,
-    keepdims=NoValue,
+    keepdims=None,
     *,
-    where=NoValue,
+    where: NotImplementedType = None,
 ):
     result = _impl.var(
         a, axis=axis, dtype=dtype, ddof=ddof, where=where, keepdims=keepdims
@@ -1474,9 +1463,9 @@ def std(
     dtype: DTypeLike = None,
     out: Optional[OutArray] = None,
     ddof=0,
-    keepdims=NoValue,
+    keepdims=None,
     *,
-    where=NoValue,
+    where: NotImplementedType = None,
 ):
     result = _impl.std(
         a, axis=axis, dtype=dtype, ddof=ddof, where=where, keepdims=keepdims
@@ -1489,7 +1478,7 @@ def argmin(
     axis: AxisLike = None,
     out: Optional[OutArray] = None,
     *,
-    keepdims=NoValue,
+    keepdims=None,
 ):
     result = _impl.argmin(a, axis=axis, keepdims=keepdims)
     return result
@@ -1500,7 +1489,7 @@ def argmax(
     axis: AxisLike = None,
     out: Optional[OutArray] = None,
     *,
-    keepdims=NoValue,
+    keepdims=None,
 ):
     result = _impl.argmax(a, axis=axis, keepdims=keepdims)
     return result
@@ -1510,9 +1499,9 @@ def amax(
     a: ArrayLike,
     axis: AxisLike = None,
     out: Optional[OutArray] = None,
-    keepdims=NoValue,
-    initial=NoValue,
-    where=NoValue,
+    keepdims=None,
+    initial: NotImplementedType = None,
+    where: NotImplementedType = None,
 ):
     result = _impl.max(a, axis=axis, initial=initial, where=where, keepdims=keepdims)
     return result
@@ -1525,9 +1514,9 @@ def amin(
     a: ArrayLike,
     axis: AxisLike = None,
     out: Optional[OutArray] = None,
-    keepdims=NoValue,
-    initial=NoValue,
-    where=NoValue,
+    keepdims=None,
+    initial: NotImplementedType = None,
+    where: NotImplementedType = None,
 ):
     result = _impl.min(a, axis=axis, initial=initial, where=where, keepdims=keepdims)
     return result
@@ -1540,7 +1529,7 @@ def ptp(
     a: ArrayLike,
     axis: AxisLike = None,
     out: Optional[OutArray] = None,
-    keepdims=NoValue,
+    keepdims=None,
 ):
     result = _impl.ptp(a, axis=axis, keepdims=keepdims)
     return result
@@ -1550,9 +1539,9 @@ def all(
     a: ArrayLike,
     axis: AxisLike = None,
     out: Optional[OutArray] = None,
-    keepdims=NoValue,
+    keepdims=None,
     *,
-    where=NoValue,
+    where: NotImplementedType = None,
 ):
     result = _impl.all(a, axis=axis, where=where, keepdims=keepdims)
     return result
@@ -1562,9 +1551,9 @@ def any(
     a: ArrayLike,
     axis: AxisLike = None,
     out: Optional[OutArray] = None,
-    keepdims=NoValue,
+    keepdims=None,
     *,
-    where=NoValue,
+    where: NotImplementedType = None,
 ):
     result = _impl.any(a, axis=axis, where=where, keepdims=keepdims)
     return result
@@ -1607,7 +1596,7 @@ def quantile(
     method="linear",
     keepdims=False,
     *,
-    interpolation=None,
+    interpolation: NotImplementedType = None,
 ):
     result = _impl.quantile(
         a,
@@ -1630,7 +1619,7 @@ def percentile(
     method="linear",
     keepdims=False,
     *,
-    interpolation=None,
+    interpolation: NotImplementedType = None,
 ):
     result = _impl.percentile(
         a,
@@ -1667,7 +1656,7 @@ def average(
     weights: ArrayLike = None,
     returned=False,
     *,
-    keepdims=NoValue,
+    keepdims=None,
 ):
     result, wsum = _impl.average(a, axis, weights, returned=returned, keepdims=keepdims)
     if returned:
@@ -1680,8 +1669,8 @@ def diff(
     a: ArrayLike,
     n=1,
     axis=-1,
-    prepend: Optional[ArrayLike] = NoValue,
-    append: Optional[ArrayLike] = NoValue,
+    prepend: Optional[ArrayLike] = None,
+    append: Optional[ArrayLike] = None,
 ):
     axis = _util.normalize_axis_index(axis, a.ndim)
diff --git a/torch_np/_normalizations.py b/torch_np/_normalizations.py
index 6a3bb02f..5bd16525 100644
--- a/torch_np/_normalizations.py
+++ b/torch_np/_normalizations.py
@@ -12,7 +12,6 @@
 ArrayLike = typing.TypeVar("ArrayLike")
 DTypeLike = typing.TypeVar("DTypeLike")
-SubokLike = typing.TypeVar("SubokLike")
 AxisLike = typing.TypeVar("AxisLike")
 NDArray = typing.TypeVar("NDarray")
@@ -29,29 +28,34 @@
 # OutArray = typing.TypeVar("OutArray")

+try:
+    from typing import NotImplementedType
+except ImportError:
+    NotImplementedType = typing.TypeVar("NotImplementedType")
+
 import inspect

 from . import _dtypes

-def normalize_array_like(x, name=None):
+def normalize_array_like(x, parm=None):
     from ._ndarray import asarray

     return asarray(x).tensor


-def normalize_optional_array_like(x, name=None):
+def normalize_optional_array_like(x, parm=None):
     # This explicit normalizer is needed because otherwise normalize_array_like
     # does not run for a parameter annotated as Optional[ArrayLike]
-    return None if x is None else normalize_array_like(x, name)
+    return None if x is None else normalize_array_like(x, parm)


-def normalize_seq_array_like(x, name=None):
+def normalize_seq_array_like(x, parm=None):
     return tuple(normalize_array_like(value) for value in x)


-def normalize_dtype(dtype, name=None):
+def normalize_dtype(dtype, parm=None):
     # cf _decorators.dtype_to_torch
     torch_dtype = None
     if dtype is not None:
@@ -60,12 +64,12 @@
     return torch_dtype


-def normalize_subok_like(arg, name="subok"):
-    if arg:
-        raise ValueError(f"'{name}' parameter is not supported.")
+def normalize_not_implemented(arg, parm):
+    if arg != parm.default:
+        raise NotImplementedError(f"'{parm.name}' parameter is not supported.")


-def normalize_axis_like(arg, name=None):
+def normalize_axis_like(arg, parm=None):
     from ._ndarray import ndarray

     if isinstance(arg, ndarray):
@@ -73,7 +77,7 @@
     return arg


-def normalize_ndarray(arg, name=None):
+def normalize_ndarray(arg, parm=None):
     # check the arg is an ndarray, extract its tensor attribute
     if arg is None:
         return arg
@@ -81,11 +85,11 @@
     from ._ndarray import ndarray

     if not isinstance(arg, ndarray):
-        raise TypeError(f"'{name}' must be an array")
+        raise TypeError(f"'{parm.name}' must be an array")
     return arg.tensor


-def normalize_outarray(arg, name=None):
+def normalize_outarray(arg, parm=None):
     # almost normalize_ndarray, only return the array, not its tensor
     if arg is None:
         return arg
@@ -93,7 +97,7 @@
     from ._ndarray import ndarray

     if not isinstance(arg, ndarray):
-        raise TypeError("'out' must be an array")
+        raise TypeError(f"'{parm.name}' must be an array")
     return arg
@@ -105,15 +109,15 @@ def normalize_outarray(arg, name=None):
     "Optional[OutArray]": normalize_outarray,
     "NDArray": normalize_ndarray,
     "DTypeLike": normalize_dtype,
-    "SubokLike": normalize_subok_like,
     "AxisLike": normalize_axis_like,
+    "NotImplementedType": normalize_not_implemented,
 }


 def maybe_normalize(arg, parm):
     """Normalize arg if a normalizer is registred."""
     normalizer = normalizers.get(parm.annotation, None)
-    return normalizer(arg, parm.name) if normalizer else arg
+    return normalizer(arg, parm) if normalizer else arg


 # ### Return value helpers ###
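The `_normalizations.py` hunks above are the core of the change: instead of a dedicated `SubokLike` sentinel, any parameter annotated `NotImplementedType` is routed through `normalize_not_implemented`, which raises `NotImplementedError` as soon as a caller supplies anything other than the declared default. Below is a simplified, self-contained sketch of how such annotation-driven normalization works; it is an illustration only, not the actual torch_np code (the real `@normalizer` handles many more annotation types and argument forms):

```python
# Simplified sketch of annotation-driven argument normalization.
# Names mirror _normalizations.py, but this is not the real implementation.
from __future__ import annotations  # annotations stay as strings

import functools
import inspect
import typing

NotImplementedType = typing.TypeVar("NotImplementedType")


def normalize_not_implemented(arg, parm):
    # Accept only the declared default; any other value is an unsupported request.
    if arg != parm.default:
        raise NotImplementedError(f"'{parm.name}' parameter is not supported.")


normalizers = {"NotImplementedType": normalize_not_implemented}


def normalizer(func):
    sig = inspect.signature(func)

    @functools.wraps(func)
    def wrapped(*args, **kwds):
        bound = sig.bind(*args, **kwds)
        bound.apply_defaults()
        for name, value in bound.arguments.items():
            parm = sig.parameters[name]
            norm = normalizers.get(parm.annotation)  # annotation is a string here
            if norm:
                norm(value, parm)
        return func(*bound.args, **bound.kwargs)

    return wrapped


@normalizer
def reshape(a, newshape, order: NotImplementedType = "C"):
    return (a, newshape)


reshape([1, 2, 3], (3, 1))              # fine: `order` keeps its default
# reshape([1, 2, 3], (3, 1), order="F")   # raises NotImplementedError
```

Keying the `normalizers` table on the annotation *string* (enabled by `from __future__ import annotations`) is also why a plain `TypeVar` works as a stand-in when `typing` does not provide a `NotImplementedType`.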
""" + @normalizer def wrapped( x1: ArrayLike, x2: ArrayLike, @@ -54,7 +58,7 @@ def wrapped( casting="same_kind", order="K", dtype: DTypeLike = None, - subok: SubokLike = False, + subok: NotImplementedType = False, signature=None, extobj=None, ): @@ -85,20 +89,17 @@ def matmul( out: Optional[OutArray] = None, *, casting="same_kind", - order="K", + order: NotImplementedType = "K", dtype: DTypeLike = None, - subok: SubokLike = False, - signature=None, - extobj=None, - axes=None, - axis=None, + subok: NotImplementedType = False, + signature: NotImplementedType = None, + extobj: NotImplementedType = None, + axes: NotImplementedType = None, + axis: NotImplementedType = None, ): tensors = _ufunc_preprocess( (x1, x2), True, casting, order, dtype, subok, signature, extobj ) - if axis is not None or axes is not None: - raise NotImplementedError - result = _binary_ufuncs_impl.matmul(*tensors) result = _ufunc_postprocess(result, out, casting) @@ -117,13 +118,13 @@ def divmod( /, out: tuple[Optional[OutArray], Optional[OutArray]] = (None, None), *, - where=True, + where: NotImplementedType = True, casting="same_kind", - order="K", + order: NotImplementedType = "K", dtype: DTypeLike = None, - subok: SubokLike = False, - signature=None, - extobj=None, + subok: NotImplementedType = False, + signature: NotImplementedType = None, + extobj: NotImplementedType = None, ): # make sure we either have no out arrays at all, or there is either # out1, out2, or out=tuple, but not both @@ -151,13 +152,11 @@ def divmod( # -# For each torch ufunc implementation, decorate and attach the decorated name -# to this module. Its contents is then exported to the public namespace in __init__.py +# Attach ufuncs to this module, for a further export to the public namespace in __init__.py # for name in _binary: ufunc = getattr(_binary_ufuncs_impl, name) - decorated = normalizer(deco_binary_ufunc(ufunc)) - vars()[name] = decorated + vars()[name] = deco_binary_ufunc(ufunc) def modf(x, /, *args, **kwds): @@ -185,6 +184,7 @@ def deco_unary_ufunc(torch_func): the pytorch functions for the actual work. """ + @normalizer def wrapped( x: ArrayLike, /, @@ -194,7 +194,7 @@ def wrapped( casting="same_kind", order="K", dtype: DTypeLike = None, - subok: SubokLike = False, + subok: NotImplementedType = False, signature=None, extobj=None, ): @@ -212,13 +212,11 @@ def wrapped( # -# For each torch ufunc implementation, decorate and attach the decorated name -# to this module. Its contents is then exported to the public namespace in __init__.py +# Attach ufuncs to this module, for a further export to the public namespace in __init__.py # for name in _unary: ufunc = getattr(_unary_ufuncs_impl, name) - decorated = normalizer(deco_unary_ufunc(ufunc)) - vars()[name] = decorated + vars()[name] = deco_unary_ufunc(ufunc) __all__ = _binary + _unary diff --git a/torch_np/tests/test_basic.py b/torch_np/tests/test_basic.py index e77a8a07..66705597 100644 --- a/torch_np/tests/test_basic.py +++ b/torch_np/tests/test_basic.py @@ -500,3 +500,10 @@ def test_divmod_out_both_pos_and_kw(self): o = w.empty(1) with assert_raises(TypeError): w.divmod(1, 2, o, o, out=(o, o)) + + +class TestSmokeNotImpl: + def test_basic(self): + # smoke test that the "NotImplemented" annotation is picked up + with assert_raises(NotImplementedError): + w.empty(3, like="ooops")