diff --git a/.github/workflows/numpy.yml b/.github/workflows/numpy.yml index e3819ca4..5588c51c 100644 --- a/.github/workflows/numpy.yml +++ b/.github/workflows/numpy.yml @@ -43,6 +43,10 @@ jobs: # waiting on NumPy to allow/revert distinct NaNs for np.unique # https://github.com/numpy/numpy/issues/20326#issuecomment-1012380448 array_api_tests/test_set_functions.py + + # missing copy arg + array_api_tests/test_signatures.py::test_func_signature[reshape] + # https://github.com/numpy/numpy/issues/21211 array_api_tests/test_special_cases.py::test_iop[__iadd__(x1_i is -0 and x2_i is -0) -> -0] # https://github.com/numpy/numpy/issues/21213 diff --git a/array-api b/array-api index 2b9c402e..02fa9237 160000 --- a/array-api +++ b/array-api @@ -1 +1 @@ -Subproject commit 2b9c402ebdb9825c2e8787caaabb5c5e3d9cf394 +Subproject commit 02fa9237eab3258120778baec12cd38cfd309ee3 diff --git a/array_api_tests/_array_module.py b/array_api_tests/_array_module.py index 3ec4e3c4..e83cd6ca 100644 --- a/array_api_tests/_array_module.py +++ b/array_api_tests/_array_module.py @@ -1,7 +1,7 @@ import os from importlib import import_module -from . import function_stubs +from . import stubs # Replace this with a specific array module to test it, for example, # @@ -53,38 +53,18 @@ def __repr__(self): __call__ = _raise __getattr__ = _raise -_integer_dtypes = [ - 'int8', - 'int16', - 'int32', - 'int64', - 'uint8', - 'uint16', - 'uint32', - 'uint64', -] - -_floating_dtypes = [ - 'float32', - 'float64', -] - -_numeric_dtypes = [ - *_integer_dtypes, - *_floating_dtypes, -] - -_boolean_dtypes = [ - 'bool', -] - _dtypes = [ - *_boolean_dtypes, - *_numeric_dtypes + "bool", + "uint8", "uint16", "uint32", "uint64", + "int8", "int16", "int32", "int64", + "float32", "float64", ] +_constants = ["e", "inf", "nan", "pi"] +_funcs = [f.__name__ for funcs in stubs.category_to_funcs.values() for f in funcs] +_top_level_attrs = _dtypes + _constants + _funcs + stubs.EXTENSIONS -for func_name in function_stubs.__all__ + _dtypes: +for attr in _top_level_attrs: try: - globals()[func_name] = getattr(mod, func_name) + globals()[attr] = getattr(mod, attr) except AttributeError: - globals()[func_name] = _UndefinedStub(func_name) + globals()[attr] = _UndefinedStub(attr) diff --git a/array_api_tests/function_stubs/__init__.py b/array_api_tests/function_stubs/__init__.py deleted file mode 100644 index 0b560378..00000000 --- a/array_api_tests/function_stubs/__init__.py +++ /dev/null @@ -1,62 +0,0 @@ -""" -Stub definitions for functions defined in the spec - -These are used to test function signatures. - -NOTE: This file is generated automatically by the generate_stubs.py script. Do -not modify it directly. -""" - -__all__ = [] - -from .array_object import __abs__, __add__, __and__, __array_namespace__, __bool__, __dlpack__, __dlpack_device__, __eq__, __float__, __floordiv__, __ge__, __getitem__, __gt__, __index__, __int__, __invert__, __le__, __lshift__, __lt__, __matmul__, __mod__, __mul__, __ne__, __neg__, __or__, __pos__, __pow__, __rshift__, __setitem__, __sub__, __truediv__, __xor__, to_device, __iadd__, __radd__, __iand__, __rand__, __ifloordiv__, __rfloordiv__, __ilshift__, __rlshift__, __imatmul__, __rmatmul__, __imod__, __rmod__, __imul__, __rmul__, __ior__, __ror__, __ipow__, __rpow__, __irshift__, __rrshift__, __isub__, __rsub__, __itruediv__, __rtruediv__, __ixor__, __rxor__, dtype, device, mT, ndim, shape, size, T - -__all__ += ['__abs__', '__add__', '__and__', '__array_namespace__', '__bool__', '__dlpack__', '__dlpack_device__', '__eq__', '__float__', '__floordiv__', '__ge__', '__getitem__', '__gt__', '__index__', '__int__', '__invert__', '__le__', '__lshift__', '__lt__', '__matmul__', '__mod__', '__mul__', '__ne__', '__neg__', '__or__', '__pos__', '__pow__', '__rshift__', '__setitem__', '__sub__', '__truediv__', '__xor__', 'to_device', '__iadd__', '__radd__', '__iand__', '__rand__', '__ifloordiv__', '__rfloordiv__', '__ilshift__', '__rlshift__', '__imatmul__', '__rmatmul__', '__imod__', '__rmod__', '__imul__', '__rmul__', '__ior__', '__ror__', '__ipow__', '__rpow__', '__irshift__', '__rrshift__', '__isub__', '__rsub__', '__itruediv__', '__rtruediv__', '__ixor__', '__rxor__', 'dtype', 'device', 'mT', 'ndim', 'shape', 'size', 'T'] - -from .constants import e, inf, nan, pi - -__all__ += ['e', 'inf', 'nan', 'pi'] - -from .creation_functions import arange, asarray, empty, empty_like, eye, from_dlpack, full, full_like, linspace, meshgrid, ones, ones_like, tril, triu, zeros, zeros_like - -__all__ += ['arange', 'asarray', 'empty', 'empty_like', 'eye', 'from_dlpack', 'full', 'full_like', 'linspace', 'meshgrid', 'ones', 'ones_like', 'tril', 'triu', 'zeros', 'zeros_like'] - -from .data_type_functions import astype, broadcast_arrays, broadcast_to, can_cast, finfo, iinfo, result_type - -__all__ += ['astype', 'broadcast_arrays', 'broadcast_to', 'can_cast', 'finfo', 'iinfo', 'result_type'] - -from .elementwise_functions import abs, acos, acosh, add, asin, asinh, atan, atan2, atanh, bitwise_and, bitwise_left_shift, bitwise_invert, bitwise_or, bitwise_right_shift, bitwise_xor, ceil, cos, cosh, divide, equal, exp, expm1, floor, floor_divide, greater, greater_equal, isfinite, isinf, isnan, less, less_equal, log, log1p, log2, log10, logaddexp, logical_and, logical_not, logical_or, logical_xor, multiply, negative, not_equal, positive, pow, remainder, round, sign, sin, sinh, square, sqrt, subtract, tan, tanh, trunc - -__all__ += ['abs', 'acos', 'acosh', 'add', 'asin', 'asinh', 'atan', 'atan2', 'atanh', 'bitwise_and', 'bitwise_left_shift', 'bitwise_invert', 'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'ceil', 'cos', 'cosh', 'divide', 'equal', 'exp', 'expm1', 'floor', 'floor_divide', 'greater', 'greater_equal', 'isfinite', 'isinf', 'isnan', 'less', 'less_equal', 'log', 'log1p', 'log2', 'log10', 'logaddexp', 'logical_and', 'logical_not', 'logical_or', 'logical_xor', 'multiply', 'negative', 'not_equal', 'positive', 'pow', 'remainder', 'round', 'sign', 'sin', 'sinh', 'square', 'sqrt', 'subtract', 'tan', 'tanh', 'trunc'] - -from .linear_algebra_functions import matmul, matrix_transpose, tensordot, vecdot - -__all__ += ['matmul', 'matrix_transpose', 'tensordot', 'vecdot'] - -from .manipulation_functions import concat, expand_dims, flip, permute_dims, reshape, roll, squeeze, stack - -__all__ += ['concat', 'expand_dims', 'flip', 'permute_dims', 'reshape', 'roll', 'squeeze', 'stack'] - -from .searching_functions import argmax, argmin, nonzero, where - -__all__ += ['argmax', 'argmin', 'nonzero', 'where'] - -from .set_functions import unique_all, unique_counts, unique_inverse, unique_values - -__all__ += ['unique_all', 'unique_counts', 'unique_inverse', 'unique_values'] - -from .sorting_functions import argsort, sort - -__all__ += ['argsort', 'sort'] - -from .statistical_functions import max, mean, min, prod, std, sum, var - -__all__ += ['max', 'mean', 'min', 'prod', 'std', 'sum', 'var'] - -from .utility_functions import all, any - -__all__ += ['all', 'any'] - -from . import linalg - -__all__ += ['linalg'] diff --git a/array_api_tests/function_stubs/_types.py b/array_api_tests/function_stubs/_types.py deleted file mode 100644 index 392a2c84..00000000 --- a/array_api_tests/function_stubs/_types.py +++ /dev/null @@ -1,41 +0,0 @@ -""" -This file defines the types for type annotations. - -The type variables should be replaced with the actual types for a given -library, e.g., for NumPy TypeVar('array') would be replaced with ndarray. -""" - -from dataclasses import dataclass -from typing import Any, List, Literal, Optional, Sequence, Tuple, TypeVar, Union - -array = TypeVar('array') -device = TypeVar('device') -dtype = TypeVar('dtype') -SupportsDLPack = TypeVar('SupportsDLPack') -SupportsBufferProtocol = TypeVar('SupportsBufferProtocol') -PyCapsule = TypeVar('PyCapsule') -# ellipsis cannot actually be imported from anywhere, so include a dummy here -# to keep pyflakes happy. https://github.com/python/typeshed/issues/3556 -ellipsis = TypeVar('ellipsis') - -@dataclass -class finfo_object: - bits: int - eps: float - max: float - min: float - smallest_normal: float - -@dataclass -class iinfo_object: - bits: int - max: int - min: int - -# This should really be recursive, but that isn't supported yet. -NestedSequence = Sequence[Sequence[Any]] - -__all__ = ['Any', 'List', 'Literal', 'NestedSequence', 'Optional', -'PyCapsule', 'SupportsBufferProtocol', 'SupportsDLPack', 'Tuple', 'Union', -'array', 'device', 'dtype', 'ellipsis', 'finfo_object', 'iinfo_object'] - diff --git a/array_api_tests/function_stubs/array_object.py b/array_api_tests/function_stubs/array_object.py deleted file mode 100644 index ec4c33fd..00000000 --- a/array_api_tests/function_stubs/array_object.py +++ /dev/null @@ -1,391 +0,0 @@ -""" -Function stubs for array object. - -NOTE: This file is generated automatically by the generate_stubs.py script. Do -not modify it directly. - -See -https://github.com/data-apis/array-api/blob/master/spec/API_specification/array_object.md -""" - -from __future__ import annotations - -from enum import IntEnum -from ._types import Any, Optional, PyCapsule, Tuple, Union, array, ellipsis - -def __abs__(self: array, /) -> array: - """ - Note: __abs__ is a method of the array object. - """ - pass - -def __add__(self: array, other: Union[int, float, array], /) -> array: - """ - Note: __add__ is a method of the array object. - """ - pass - -def __and__(self: array, other: Union[int, bool, array], /) -> array: - """ - Note: __and__ is a method of the array object. - """ - pass - -def __array_namespace__(self: array, /, *, api_version: Optional[str] = None) -> object: - """ - Note: __array_namespace__ is a method of the array object. - """ - pass - -def __bool__(self: array, /) -> bool: - """ - Note: __bool__ is a method of the array object. - """ - pass - -def __dlpack__(self: array, /, *, stream: Optional[Union[int, Any]] = None) -> PyCapsule: - """ - Note: __dlpack__ is a method of the array object. - """ - pass - -def __dlpack_device__(self: array, /) -> Tuple[IntEnum, int]: - """ - Note: __dlpack_device__ is a method of the array object. - """ - pass - -def __eq__(self: array, other: Union[int, float, bool, array], /) -> array: - """ - Note: __eq__ is a method of the array object. - """ - pass - -def __float__(self: array, /) -> float: - """ - Note: __float__ is a method of the array object. - """ - pass - -def __floordiv__(self: array, other: Union[int, float, array], /) -> array: - """ - Note: __floordiv__ is a method of the array object. - """ - pass - -def __ge__(self: array, other: Union[int, float, array], /) -> array: - """ - Note: __ge__ is a method of the array object. - """ - pass - -def __getitem__(self: array, key: Union[int, slice, ellipsis, Tuple[Union[int, slice, ellipsis], ...], array], /) -> array: - """ - Note: __getitem__ is a method of the array object. - """ - pass - -def __gt__(self: array, other: Union[int, float, array], /) -> array: - """ - Note: __gt__ is a method of the array object. - """ - pass - -def __index__(self: array, /) -> int: - """ - Note: __index__ is a method of the array object. - """ - pass - -def __int__(self: array, /) -> int: - """ - Note: __int__ is a method of the array object. - """ - pass - -def __invert__(self: array, /) -> array: - """ - Note: __invert__ is a method of the array object. - """ - pass - -def __le__(self: array, other: Union[int, float, array], /) -> array: - """ - Note: __le__ is a method of the array object. - """ - pass - -def __lshift__(self: array, other: Union[int, array], /) -> array: - """ - Note: __lshift__ is a method of the array object. - """ - pass - -def __lt__(self: array, other: Union[int, float, array], /) -> array: - """ - Note: __lt__ is a method of the array object. - """ - pass - -def __matmul__(self: array, other: array, /) -> array: - """ - Note: __matmul__ is a method of the array object. - """ - pass - -def __mod__(self: array, other: Union[int, float, array], /) -> array: - """ - Note: __mod__ is a method of the array object. - """ - pass - -def __mul__(self: array, other: Union[int, float, array], /) -> array: - """ - Note: __mul__ is a method of the array object. - """ - pass - -def __ne__(self: array, other: Union[int, float, bool, array], /) -> array: - """ - Note: __ne__ is a method of the array object. - """ - pass - -def __neg__(self: array, /) -> array: - """ - Note: __neg__ is a method of the array object. - """ - pass - -def __or__(self: array, other: Union[int, bool, array], /) -> array: - """ - Note: __or__ is a method of the array object. - """ - pass - -def __pos__(self: array, /) -> array: - """ - Note: __pos__ is a method of the array object. - """ - pass - -def __pow__(self: array, other: Union[int, float, array], /) -> array: - """ - Note: __pow__ is a method of the array object. - """ - pass - -def __rshift__(self: array, other: Union[int, array], /) -> array: - """ - Note: __rshift__ is a method of the array object. - """ - pass - -def __setitem__(self: array, key: Union[int, slice, ellipsis, Tuple[Union[int, slice, ellipsis], ...], array], value: Union[int, float, bool, array], /) -> None: - """ - Note: __setitem__ is a method of the array object. - """ - pass - -def __sub__(self: array, other: Union[int, float, array], /) -> array: - """ - Note: __sub__ is a method of the array object. - """ - pass - -def __truediv__(self: array, other: Union[int, float, array], /) -> array: - """ - Note: __truediv__ is a method of the array object. - """ - pass - -def __xor__(self: array, other: Union[int, bool, array], /) -> array: - """ - Note: __xor__ is a method of the array object. - """ - pass - -def to_device(self: array, device: device, /, *, stream: Optional[Union[int, Any]] = None) -> array: - """ - Note: to_device is a method of the array object. - """ - pass - -def __iadd__(self: array, other: Union[int, float, array], /) -> array: - """ - Note: __iadd__ is a method of the array object. - """ - pass - -def __radd__(self: array, other: Union[int, float, array], /) -> array: - """ - Note: __radd__ is a method of the array object. - """ - pass - -def __iand__(self: array, other: Union[int, bool, array], /) -> array: - """ - Note: __iand__ is a method of the array object. - """ - pass - -def __rand__(self: array, other: Union[int, bool, array], /) -> array: - """ - Note: __rand__ is a method of the array object. - """ - pass - -def __ifloordiv__(self: array, other: Union[int, float, array], /) -> array: - """ - Note: __ifloordiv__ is a method of the array object. - """ - pass - -def __rfloordiv__(self: array, other: Union[int, float, array], /) -> array: - """ - Note: __rfloordiv__ is a method of the array object. - """ - pass - -def __ilshift__(self: array, other: Union[int, array], /) -> array: - """ - Note: __ilshift__ is a method of the array object. - """ - pass - -def __rlshift__(self: array, other: Union[int, array], /) -> array: - """ - Note: __rlshift__ is a method of the array object. - """ - pass - -def __imatmul__(self: array, other: array, /) -> array: - """ - Note: __imatmul__ is a method of the array object. - """ - pass - -def __rmatmul__(self: array, other: array, /) -> array: - """ - Note: __rmatmul__ is a method of the array object. - """ - pass - -def __imod__(self: array, other: Union[int, float, array], /) -> array: - """ - Note: __imod__ is a method of the array object. - """ - pass - -def __rmod__(self: array, other: Union[int, float, array], /) -> array: - """ - Note: __rmod__ is a method of the array object. - """ - pass - -def __imul__(self: array, other: Union[int, float, array], /) -> array: - """ - Note: __imul__ is a method of the array object. - """ - pass - -def __rmul__(self: array, other: Union[int, float, array], /) -> array: - """ - Note: __rmul__ is a method of the array object. - """ - pass - -def __ior__(self: array, other: Union[int, bool, array], /) -> array: - """ - Note: __ior__ is a method of the array object. - """ - pass - -def __ror__(self: array, other: Union[int, bool, array], /) -> array: - """ - Note: __ror__ is a method of the array object. - """ - pass - -def __ipow__(self: array, other: Union[int, float, array], /) -> array: - """ - Note: __ipow__ is a method of the array object. - """ - pass - -def __rpow__(self: array, other: Union[int, float, array], /) -> array: - """ - Note: __rpow__ is a method of the array object. - """ - pass - -def __irshift__(self: array, other: Union[int, array], /) -> array: - """ - Note: __irshift__ is a method of the array object. - """ - pass - -def __rrshift__(self: array, other: Union[int, array], /) -> array: - """ - Note: __rrshift__ is a method of the array object. - """ - pass - -def __isub__(self: array, other: Union[int, float, array], /) -> array: - """ - Note: __isub__ is a method of the array object. - """ - pass - -def __rsub__(self: array, other: Union[int, float, array], /) -> array: - """ - Note: __rsub__ is a method of the array object. - """ - pass - -def __itruediv__(self: array, other: Union[int, float, array], /) -> array: - """ - Note: __itruediv__ is a method of the array object. - """ - pass - -def __rtruediv__(self: array, other: Union[int, float, array], /) -> array: - """ - Note: __rtruediv__ is a method of the array object. - """ - pass - -def __ixor__(self: array, other: Union[int, bool, array], /) -> array: - """ - Note: __ixor__ is a method of the array object. - """ - pass - -def __rxor__(self: array, other: Union[int, bool, array], /) -> array: - """ - Note: __rxor__ is a method of the array object. - """ - pass - -# Note: dtype is an attribute of the array object. -dtype: dtype = None - -# Note: device is an attribute of the array object. -device: device = None - -# Note: mT is an attribute of the array object. -mT: array = None - -# Note: ndim is an attribute of the array object. -ndim: int = None - -# Note: shape is an attribute of the array object. -shape: Tuple[Optional[int], ...] = None - -# Note: size is an attribute of the array object. -size: Optional[int] = None - -# Note: T is an attribute of the array object. -T: array = None - -__all__ = ['__abs__', '__add__', '__and__', '__array_namespace__', '__bool__', '__dlpack__', '__dlpack_device__', '__eq__', '__float__', '__floordiv__', '__ge__', '__getitem__', '__gt__', '__index__', '__int__', '__invert__', '__le__', '__lshift__', '__lt__', '__matmul__', '__mod__', '__mul__', '__ne__', '__neg__', '__or__', '__pos__', '__pow__', '__rshift__', '__setitem__', '__sub__', '__truediv__', '__xor__', 'to_device', '__iadd__', '__radd__', '__iand__', '__rand__', '__ifloordiv__', '__rfloordiv__', '__ilshift__', '__rlshift__', '__imatmul__', '__rmatmul__', '__imod__', '__rmod__', '__imul__', '__rmul__', '__ior__', '__ror__', '__ipow__', '__rpow__', '__irshift__', '__rrshift__', '__isub__', '__rsub__', '__itruediv__', '__rtruediv__', '__ixor__', '__rxor__', 'dtype', 'device', 'mT', 'ndim', 'shape', 'size', 'T'] diff --git a/array_api_tests/function_stubs/constants.py b/array_api_tests/function_stubs/constants.py deleted file mode 100644 index 602f0399..00000000 --- a/array_api_tests/function_stubs/constants.py +++ /dev/null @@ -1,22 +0,0 @@ -""" -Function stubs for constants. - -NOTE: This file is generated automatically by the generate_stubs.py script. Do -not modify it directly. - -See -https://github.com/data-apis/array-api/blob/master/spec/API_specification/constants.md -""" - -from __future__ import annotations - - -e = None - -inf = None - -nan = None - -pi = None - -__all__ = ['e', 'inf', 'nan', 'pi'] diff --git a/array_api_tests/function_stubs/creation_functions.py b/array_api_tests/function_stubs/creation_functions.py deleted file mode 100644 index 09a2ba0d..00000000 --- a/array_api_tests/function_stubs/creation_functions.py +++ /dev/null @@ -1,64 +0,0 @@ -""" -Function stubs for creation functions. - -NOTE: This file is generated automatically by the generate_stubs.py script. Do -not modify it directly. - -See -https://github.com/data-apis/array-api/blob/master/spec/API_specification/creation_functions.md -""" - -from __future__ import annotations - -from ._types import (List, NestedSequence, Optional, SupportsBufferProtocol, Tuple, Union, array, - device, dtype) - -def arange(start: Union[int, float], /, stop: Optional[Union[int, float]] = None, step: Union[int, float] = 1, *, dtype: Optional[dtype] = None, device: Optional[device] = None) -> array: - pass - -def asarray(obj: Union[array, bool, int, float, NestedSequence[bool|int|float], SupportsBufferProtocol], /, *, dtype: Optional[dtype] = None, device: Optional[device] = None, copy: Optional[bool] = None) -> array: - pass - -def empty(shape: Union[int, Tuple[int, ...]], *, dtype: Optional[dtype] = None, device: Optional[device] = None) -> array: - pass - -def empty_like(x: array, /, *, dtype: Optional[dtype] = None, device: Optional[device] = None) -> array: - pass - -def eye(n_rows: int, n_cols: Optional[int] = None, /, *, k: int = 0, dtype: Optional[dtype] = None, device: Optional[device] = None) -> array: - pass - -def from_dlpack(x: object, /) -> array: - pass - -def full(shape: Union[int, Tuple[int, ...]], fill_value: Union[int, float], *, dtype: Optional[dtype] = None, device: Optional[device] = None) -> array: - pass - -def full_like(x: array, /, fill_value: Union[int, float], *, dtype: Optional[dtype] = None, device: Optional[device] = None) -> array: - pass - -def linspace(start: Union[int, float], stop: Union[int, float], /, num: int, *, dtype: Optional[dtype] = None, device: Optional[device] = None, endpoint: bool = True) -> array: - pass - -def meshgrid(*arrays: array, indexing: str = 'xy') -> List[array, ...]: - pass - -def ones(shape: Union[int, Tuple[int, ...]], *, dtype: Optional[dtype] = None, device: Optional[device] = None) -> array: - pass - -def ones_like(x: array, /, *, dtype: Optional[dtype] = None, device: Optional[device] = None) -> array: - pass - -def tril(x: array, /, *, k: int = 0) -> array: - pass - -def triu(x: array, /, *, k: int = 0) -> array: - pass - -def zeros(shape: Union[int, Tuple[int, ...]], *, dtype: Optional[dtype] = None, device: Optional[device] = None) -> array: - pass - -def zeros_like(x: array, /, *, dtype: Optional[dtype] = None, device: Optional[device] = None) -> array: - pass - -__all__ = ['arange', 'asarray', 'empty', 'empty_like', 'eye', 'from_dlpack', 'full', 'full_like', 'linspace', 'meshgrid', 'ones', 'ones_like', 'tril', 'triu', 'zeros', 'zeros_like'] diff --git a/array_api_tests/function_stubs/data_type_functions.py b/array_api_tests/function_stubs/data_type_functions.py deleted file mode 100644 index a5cd4285..00000000 --- a/array_api_tests/function_stubs/data_type_functions.py +++ /dev/null @@ -1,36 +0,0 @@ -""" -Function stubs for data type functions. - -NOTE: This file is generated automatically by the generate_stubs.py script. Do -not modify it directly. - -See -https://github.com/data-apis/array-api/blob/master/spec/API_specification/data_type_functions.md -""" - -from __future__ import annotations - -from ._types import List, Tuple, Union, array, dtype, finfo_object, iinfo_object - -def astype(x: array, dtype: dtype, /, *, copy: bool = True) -> array: - pass - -def broadcast_arrays(*arrays: array) -> List[array]: - pass - -def broadcast_to(x: array, /, shape: Tuple[int, ...]) -> array: - pass - -def can_cast(from_: Union[dtype, array], to: dtype, /) -> bool: - pass - -def finfo(type: Union[dtype, array], /) -> finfo_object: - pass - -def iinfo(type: Union[dtype, array], /) -> iinfo_object: - pass - -def result_type(*arrays_and_dtypes: Union[array, dtype]) -> dtype: - pass - -__all__ = ['astype', 'broadcast_arrays', 'broadcast_to', 'can_cast', 'finfo', 'iinfo', 'result_type'] diff --git a/array_api_tests/function_stubs/elementwise_functions.py b/array_api_tests/function_stubs/elementwise_functions.py deleted file mode 100644 index c6efd7da..00000000 --- a/array_api_tests/function_stubs/elementwise_functions.py +++ /dev/null @@ -1,183 +0,0 @@ -""" -Function stubs for elementwise functions. - -NOTE: This file is generated automatically by the generate_stubs.py script. Do -not modify it directly. - -See -https://github.com/data-apis/array-api/blob/master/spec/API_specification/elementwise_functions.md -""" - -from __future__ import annotations - -from ._types import array - -def abs(x: array, /) -> array: - pass - -def acos(x: array, /) -> array: - pass - -def acosh(x: array, /) -> array: - pass - -def add(x1: array, x2: array, /) -> array: - pass - -def asin(x: array, /) -> array: - pass - -def asinh(x: array, /) -> array: - pass - -def atan(x: array, /) -> array: - pass - -def atan2(x1: array, x2: array, /) -> array: - pass - -def atanh(x: array, /) -> array: - pass - -def bitwise_and(x1: array, x2: array, /) -> array: - pass - -def bitwise_left_shift(x1: array, x2: array, /) -> array: - pass - -def bitwise_invert(x: array, /) -> array: - pass - -def bitwise_or(x1: array, x2: array, /) -> array: - pass - -def bitwise_right_shift(x1: array, x2: array, /) -> array: - pass - -def bitwise_xor(x1: array, x2: array, /) -> array: - pass - -def ceil(x: array, /) -> array: - pass - -def cos(x: array, /) -> array: - pass - -def cosh(x: array, /) -> array: - pass - -def divide(x1: array, x2: array, /) -> array: - pass - -def equal(x1: array, x2: array, /) -> array: - pass - -def exp(x: array, /) -> array: - pass - -def expm1(x: array, /) -> array: - pass - -def floor(x: array, /) -> array: - pass - -def floor_divide(x1: array, x2: array, /) -> array: - pass - -def greater(x1: array, x2: array, /) -> array: - pass - -def greater_equal(x1: array, x2: array, /) -> array: - pass - -def isfinite(x: array, /) -> array: - pass - -def isinf(x: array, /) -> array: - pass - -def isnan(x: array, /) -> array: - pass - -def less(x1: array, x2: array, /) -> array: - pass - -def less_equal(x1: array, x2: array, /) -> array: - pass - -def log(x: array, /) -> array: - pass - -def log1p(x: array, /) -> array: - pass - -def log2(x: array, /) -> array: - pass - -def log10(x: array, /) -> array: - pass - -def logaddexp(x1: array, x2: array) -> array: - pass - -def logical_and(x1: array, x2: array, /) -> array: - pass - -def logical_not(x: array, /) -> array: - pass - -def logical_or(x1: array, x2: array, /) -> array: - pass - -def logical_xor(x1: array, x2: array, /) -> array: - pass - -def multiply(x1: array, x2: array, /) -> array: - pass - -def negative(x: array, /) -> array: - pass - -def not_equal(x1: array, x2: array, /) -> array: - pass - -def positive(x: array, /) -> array: - pass - -def pow(x1: array, x2: array, /) -> array: - pass - -def remainder(x1: array, x2: array, /) -> array: - pass - -def round(x: array, /) -> array: - pass - -def sign(x: array, /) -> array: - pass - -def sin(x: array, /) -> array: - pass - -def sinh(x: array, /) -> array: - pass - -def square(x: array, /) -> array: - pass - -def sqrt(x: array, /) -> array: - pass - -def subtract(x1: array, x2: array, /) -> array: - pass - -def tan(x: array, /) -> array: - pass - -def tanh(x: array, /) -> array: - pass - -def trunc(x: array, /) -> array: - pass - -__all__ = ['abs', 'acos', 'acosh', 'add', 'asin', 'asinh', 'atan', 'atan2', 'atanh', 'bitwise_and', 'bitwise_left_shift', 'bitwise_invert', 'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'ceil', 'cos', 'cosh', 'divide', 'equal', 'exp', 'expm1', 'floor', 'floor_divide', 'greater', 'greater_equal', 'isfinite', 'isinf', 'isnan', 'less', 'less_equal', 'log', 'log1p', 'log2', 'log10', 'logaddexp', 'logical_and', 'logical_not', 'logical_or', 'logical_xor', 'multiply', 'negative', 'not_equal', 'positive', 'pow', 'remainder', 'round', 'sign', 'sin', 'sinh', 'square', 'sqrt', 'subtract', 'tan', 'tanh', 'trunc'] diff --git a/array_api_tests/function_stubs/linalg.py b/array_api_tests/function_stubs/linalg.py deleted file mode 100644 index 07ccde34..00000000 --- a/array_api_tests/function_stubs/linalg.py +++ /dev/null @@ -1,86 +0,0 @@ -""" -Function stubs for linear algebra functions (Extension). - -NOTE: This file is generated automatically by the generate_stubs.py script. Do -not modify it directly. - -See -https://github.com/data-apis/array-api/blob/master/spec/API_specification/linear_algebra_functions.md -""" - -from __future__ import annotations - -from ._types import Literal, Optional, Tuple, Union, array -from .constants import inf -from collections.abc import Sequence - -def cholesky(x: array, /, *, upper: bool = False) -> array: - pass - -def cross(x1: array, x2: array, /, *, axis: int = -1) -> array: - pass - -def det(x: array, /) -> array: - pass - -def diagonal(x: array, /, *, offset: int = 0) -> array: - pass - -def eigh(x: array, /) -> Tuple[array]: - pass - -def eigvalsh(x: array, /) -> array: - pass - -def inv(x: array, /) -> array: - pass - -def matmul(x1: array, x2: array, /) -> array: - pass - -def matrix_norm(x: array, /, *, keepdims: bool = False, ord: Optional[Union[int, float, Literal[inf, -inf, 'fro', 'nuc']]] = 'fro') -> array: - pass - -def matrix_power(x: array, n: int, /) -> array: - pass - -def matrix_rank(x: array, /, *, rtol: Optional[Union[float, array]] = None) -> array: - pass - -def matrix_transpose(x: array, /) -> array: - pass - -def outer(x1: array, x2: array, /) -> array: - pass - -def pinv(x: array, /, *, rtol: Optional[Union[float, array]] = None) -> array: - pass - -def qr(x: array, /, *, mode: Literal['reduced', 'complete'] = 'reduced') -> Tuple[array, array]: - pass - -def slogdet(x: array, /) -> Tuple[array, array]: - pass - -def solve(x1: array, x2: array, /) -> array: - pass - -def svd(x: array, /, *, full_matrices: bool = True) -> Union[array, Tuple[array, ...]]: - pass - -def svdvals(x: array, /) -> array: - pass - -def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2) -> array: - pass - -def trace(x: array, /, *, offset: int = 0) -> array: - pass - -def vecdot(x1: array, x2: array, /, *, axis: int = None) -> array: - pass - -def vector_norm(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ord: Union[int, float, Literal[inf, -inf]] = 2) -> array: - pass - -__all__ = ['cholesky', 'cross', 'det', 'diagonal', 'eigh', 'eigvalsh', 'inv', 'matmul', 'matrix_norm', 'matrix_power', 'matrix_rank', 'matrix_transpose', 'outer', 'pinv', 'qr', 'slogdet', 'solve', 'svd', 'svdvals', 'tensordot', 'trace', 'vecdot', 'vector_norm'] diff --git a/array_api_tests/function_stubs/linear_algebra_functions.py b/array_api_tests/function_stubs/linear_algebra_functions.py deleted file mode 100644 index 29eebdda..00000000 --- a/array_api_tests/function_stubs/linear_algebra_functions.py +++ /dev/null @@ -1,28 +0,0 @@ -""" -Function stubs for linear algebra functions. - -NOTE: This file is generated automatically by the generate_stubs.py script. Do -not modify it directly. - -See -https://github.com/data-apis/array-api/blob/master/spec/API_specification/linear_algebra_functions.md -""" - -from __future__ import annotations - -from ._types import Tuple, Union, array -from collections.abc import Sequence - -def matmul(x1: array, x2: array, /) -> array: - pass - -def matrix_transpose(x: array, /) -> array: - pass - -def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2) -> array: - pass - -def vecdot(x1: array, x2: array, /, *, axis: int = -1) -> array: - pass - -__all__ = ['matmul', 'matrix_transpose', 'tensordot', 'vecdot'] diff --git a/array_api_tests/function_stubs/manipulation_functions.py b/array_api_tests/function_stubs/manipulation_functions.py deleted file mode 100644 index b5b921f1..00000000 --- a/array_api_tests/function_stubs/manipulation_functions.py +++ /dev/null @@ -1,39 +0,0 @@ -""" -Function stubs for manipulation functions. - -NOTE: This file is generated automatically by the generate_stubs.py script. Do -not modify it directly. - -See -https://github.com/data-apis/array-api/blob/master/spec/API_specification/manipulation_functions.md -""" - -from __future__ import annotations - -from ._types import List, Optional, Tuple, Union, array - -def concat(arrays: Union[Tuple[array, ...], List[array]], /, *, axis: Optional[int] = 0) -> array: - pass - -def expand_dims(x: array, /, *, axis: int) -> array: - pass - -def flip(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> array: - pass - -def permute_dims(x: array, /, axes: Tuple[int, ...]) -> array: - pass - -def reshape(x: array, /, shape: Tuple[int, ...]) -> array: - pass - -def roll(x: array, /, shift: Union[int, Tuple[int, ...]], *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> array: - pass - -def squeeze(x: array, /, axis: Union[int, Tuple[int, ...]]) -> array: - pass - -def stack(arrays: Union[Tuple[array, ...], List[array]], /, *, axis: int = 0) -> array: - pass - -__all__ = ['concat', 'expand_dims', 'flip', 'permute_dims', 'reshape', 'roll', 'squeeze', 'stack'] diff --git a/array_api_tests/function_stubs/searching_functions.py b/array_api_tests/function_stubs/searching_functions.py deleted file mode 100644 index 283ac74c..00000000 --- a/array_api_tests/function_stubs/searching_functions.py +++ /dev/null @@ -1,27 +0,0 @@ -""" -Function stubs for searching functions. - -NOTE: This file is generated automatically by the generate_stubs.py script. Do -not modify it directly. - -See -https://github.com/data-apis/array-api/blob/master/spec/API_specification/searching_functions.md -""" - -from __future__ import annotations - -from ._types import Optional, Tuple, array - -def argmax(x: array, /, *, axis: Optional[int] = None, keepdims: bool = False) -> array: - pass - -def argmin(x: array, /, *, axis: Optional[int] = None, keepdims: bool = False) -> array: - pass - -def nonzero(x: array, /) -> Tuple[array, ...]: - pass - -def where(condition: array, x1: array, x2: array, /) -> array: - pass - -__all__ = ['argmax', 'argmin', 'nonzero', 'where'] diff --git a/array_api_tests/function_stubs/set_functions.py b/array_api_tests/function_stubs/set_functions.py deleted file mode 100644 index efc1ff52..00000000 --- a/array_api_tests/function_stubs/set_functions.py +++ /dev/null @@ -1,27 +0,0 @@ -""" -Function stubs for set functions. - -NOTE: This file is generated automatically by the generate_stubs.py script. Do -not modify it directly. - -See -https://github.com/data-apis/array-api/blob/master/spec/API_specification/set_functions.md -""" - -from __future__ import annotations - -from ._types import Tuple, array - -def unique_all(x: array, /) -> Tuple[array, array, array, array]: - pass - -def unique_counts(x: array, /) -> Tuple[array, array]: - pass - -def unique_inverse(x: array, /) -> Tuple[array, array]: - pass - -def unique_values(x: array, /) -> array: - pass - -__all__ = ['unique_all', 'unique_counts', 'unique_inverse', 'unique_values'] diff --git a/array_api_tests/function_stubs/sorting_functions.py b/array_api_tests/function_stubs/sorting_functions.py deleted file mode 100644 index 2040de54..00000000 --- a/array_api_tests/function_stubs/sorting_functions.py +++ /dev/null @@ -1,21 +0,0 @@ -""" -Function stubs for sorting functions. - -NOTE: This file is generated automatically by the generate_stubs.py script. Do -not modify it directly. - -See -https://github.com/data-apis/array-api/blob/master/spec/API_specification/sorting_functions.md -""" - -from __future__ import annotations - -from ._types import array - -def argsort(x: array, /, *, axis: int = -1, descending: bool = False, stable: bool = True) -> array: - pass - -def sort(x: array, /, *, axis: int = -1, descending: bool = False, stable: bool = True) -> array: - pass - -__all__ = ['argsort', 'sort'] diff --git a/array_api_tests/function_stubs/statistical_functions.py b/array_api_tests/function_stubs/statistical_functions.py deleted file mode 100644 index fa62f710..00000000 --- a/array_api_tests/function_stubs/statistical_functions.py +++ /dev/null @@ -1,36 +0,0 @@ -""" -Function stubs for statistical functions. - -NOTE: This file is generated automatically by the generate_stubs.py script. Do -not modify it directly. - -See -https://github.com/data-apis/array-api/blob/master/spec/API_specification/statistical_functions.md -""" - -from __future__ import annotations - -from ._types import Optional, Tuple, Union, array, dtype - -def max(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array: - pass - -def mean(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array: - pass - -def min(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array: - pass - -def prod(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype: Optional[dtype] = None, keepdims: bool = False) -> array: - pass - -def std(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, correction: Union[int, float] = 0.0, keepdims: bool = False) -> array: - pass - -def sum(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype: Optional[dtype] = None, keepdims: bool = False) -> array: - pass - -def var(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, correction: Union[int, float] = 0.0, keepdims: bool = False) -> array: - pass - -__all__ = ['max', 'mean', 'min', 'prod', 'std', 'sum', 'var'] diff --git a/array_api_tests/function_stubs/utility_functions.py b/array_api_tests/function_stubs/utility_functions.py deleted file mode 100644 index ae427401..00000000 --- a/array_api_tests/function_stubs/utility_functions.py +++ /dev/null @@ -1,21 +0,0 @@ -""" -Function stubs for utility functions. - -NOTE: This file is generated automatically by the generate_stubs.py script. Do -not modify it directly. - -See -https://github.com/data-apis/array-api/blob/master/spec/API_specification/utility_functions.md -""" - -from __future__ import annotations - -from ._types import Optional, Tuple, Union, array - -def all(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array: - pass - -def any(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array: - pass - -__all__ = ['all', 'any'] diff --git a/array_api_tests/hypothesis_helpers.py b/array_api_tests/hypothesis_helpers.py index 7e436082..20cc0e03 100644 --- a/array_api_tests/hypothesis_helpers.py +++ b/array_api_tests/hypothesis_helpers.py @@ -16,7 +16,7 @@ from ._array_module import _UndefinedStub from ._array_module import bool as bool_dtype from ._array_module import broadcast_to, eye, float32, float64, full -from .function_stubs import elementwise_functions +from .stubs import category_to_funcs from .pytest_helpers import nargs from .typing import Array, DataType, Shape @@ -110,7 +110,7 @@ def mutually_promotable_dtypes( # will both correspond to the same function. # TODO: Extend this to all functions, not just elementwise -elementwise_functions_names = shared(sampled_from(elementwise_functions.__all__)) +elementwise_functions_names = shared(sampled_from([f.__name__ for f in category_to_funcs["elementwise"]])) array_functions_names = elementwise_functions_names multiarg_array_functions_names = array_functions_names.filter( lambda func_name: nargs(func_name) > 1) diff --git a/array_api_tests/meta/test_utils.py b/array_api_tests/meta/test_utils.py index 588cfb1b..1fe3ca66 100644 --- a/array_api_tests/meta/test_utils.py +++ b/array_api_tests/meta/test_utils.py @@ -13,15 +13,6 @@ oneway_broadcastable_shapes, oneway_promotable_dtypes, ) -from ..test_signatures import extension_module - - -def test_extension_module_is_extension(): - assert extension_module("linalg") - - -def test_extension_func_is_not_extension(): - assert not extension_module("linalg.cross") @pytest.mark.parametrize( diff --git a/array_api_tests/pytest_helpers.py b/array_api_tests/pytest_helpers.py index 989b486f..5a96b27f 100644 --- a/array_api_tests/pytest_helpers.py +++ b/array_api_tests/pytest_helpers.py @@ -5,8 +5,8 @@ from . import _array_module as xp from . import array_helpers as ah from . import dtype_helpers as dh -from . import function_stubs from . import shape_helpers as sh +from . import stubs from .typing import Array, DataType, Scalar, ScalarType, Shape __all__ = [ @@ -65,7 +65,7 @@ def doesnt_raise(function, message=""): def nargs(func_name): - return len(getfullargspec(getattr(function_stubs, func_name)).args) + return len(getfullargspec(stubs.name_to_func[func_name]).args) def fmt_kw(kw: Dict[str, Any]) -> str: diff --git a/array_api_tests/stubs.py b/array_api_tests/stubs.py index 15fb7646..1ff1e1b6 100644 --- a/array_api_tests/stubs.py +++ b/array_api_tests/stubs.py @@ -1,3 +1,4 @@ +import inspect import sys from importlib import import_module from importlib.util import find_spec @@ -5,7 +6,13 @@ from types import FunctionType, ModuleType from typing import Dict, List -__all__ = ["category_to_funcs", "array", "extension_to_funcs"] +__all__ = [ + "name_to_func", + "array_methods", + "category_to_funcs", + "EXTENSIONS", + "extension_to_funcs", +] spec_dir = Path(__file__).parent.parent / "array-api" / "spec" / "API_specification" @@ -22,6 +29,11 @@ name = path.name.replace(".py", "") name_to_mod[name] = import_module(f"signatures.{name}") +array = name_to_mod["array_object"].array +array_methods = [ + f for n, f in inspect.getmembers(array, predicate=inspect.isfunction) + if n != "__init__" # probably exists for Sphinx +] category_to_funcs: Dict[str, List[FunctionType]] = {} for name, mod in name_to_mod.items(): @@ -31,14 +43,15 @@ assert all(isinstance(o, FunctionType) for o in objects) category_to_funcs[category] = objects - -array = name_to_mod["array_object"].array - - -EXTENSIONS = ["linalg"] +EXTENSIONS: str = ["linalg"] extension_to_funcs: Dict[str, List[FunctionType]] = {} for ext in EXTENSIONS: mod = name_to_mod[ext] objects = [getattr(mod, name) for name in mod.__all__] assert all(isinstance(o, FunctionType) for o in objects) extension_to_funcs[ext] = objects + +all_funcs = [] +for funcs in [array_methods, *category_to_funcs.values(), *extension_to_funcs.values()]: + all_funcs.extend(funcs) +name_to_func: Dict[str, FunctionType] = {f.__name__: f for f in all_funcs} diff --git a/array_api_tests/test_signatures.py b/array_api_tests/test_signatures.py index c4d24ff2..2e197ee9 100644 --- a/array_api_tests/test_signatures.py +++ b/array_api_tests/test_signatures.py @@ -1,4 +1,5 @@ import inspect +from itertools import chain import pytest @@ -6,44 +7,38 @@ from .pytest_helpers import raises, doesnt_raise from . import dtype_helpers as dh -from . import function_stubs +from . import stubs -submodules = [m for m in dir(function_stubs) if - inspect.ismodule(getattr(function_stubs, m)) and not - m.startswith('_')] +def extension_module(name) -> bool: + for funcs in stubs.extension_to_funcs.values(): + for func in funcs: + if name == func.__name__: + return True + else: + return False -def stub_module(name): - for m in submodules: - if name in getattr(function_stubs, m).__all__: - return m -def extension_module(name): - return name in submodules and name in function_stubs.__all__ +params = [] +for name in [f.__name__ for funcs in stubs.category_to_funcs.values() for f in funcs]: + if name in ["where", "expand_dims", "reshape"]: + params.append(pytest.param(name, marks=pytest.mark.skip(reason="faulty test"))) + else: + params.append(name) -extension_module_names = [] -for n in function_stubs.__all__: - if extension_module(n): - extension_module_names.extend([f'{n}.{i}' for i in getattr(function_stubs, n).__all__]) - -params = [] -for name in function_stubs.__all__: - marks = [] - if extension_module(name): - marks.append(pytest.mark.xp_extension(name)) - params.append(pytest.param(name, marks=marks)) -for name in extension_module_names: - ext = name.split('.')[0] - mark = pytest.mark.xp_extension(ext) - params.append(pytest.param(name, marks=[mark])) +for ext, name in [(ext, f.__name__) for ext, funcs in stubs.extension_to_funcs.items() for f in funcs]: + params.append(pytest.param(name, marks=pytest.mark.xp_extension(ext))) -def array_method(name): - return stub_module(name) == 'array_object' +def array_method(name) -> bool: + return name in [f.__name__ for f in stubs.array_methods] -def function_category(name): - return stub_module(name).rsplit('_', 1)[0].replace('_', ' ') +def function_category(name) -> str: + for category, funcs in chain(stubs.category_to_funcs.items(), stubs.extension_to_funcs.items()): + for func in funcs: + if name == func.__name__: + return category def example_argument(arg, func_name, dtype): """ @@ -141,7 +136,7 @@ def example_argument(arg, func_name, dtype): return ones((3,), dtype=dtype) # Linear algebra functions tend to error if the input isn't "nice" as # a matrix - elif arg.startswith('x') and func_name in function_stubs.linalg.__all__: + elif arg.startswith('x') and func_name in [f.__name__ for f in stubs.extension_to_funcs["linalg"]]: return eye(3) return known_args[arg] else: @@ -150,13 +145,15 @@ def example_argument(arg, func_name, dtype): @pytest.mark.parametrize('name', params) def test_has_names(name): if extension_module(name): - assert hasattr(mod, name), f'{mod_name} is missing the {name} extension' - elif '.' in name: - extension_mod, name = name.split('.') - assert hasattr(getattr(mod, extension_mod), name), f"{mod_name} is missing the {function_category(name)} extension function {name}()" + ext = next( + ext for ext, funcs in stubs.extension_to_funcs.items() + if name in [f.__name__ for f in funcs] + ) + ext_mod = getattr(mod, ext) + assert hasattr(ext_mod, name), f"{mod_name} is missing the {function_category(name)} extension function {name}()" elif array_method(name): arr = ones((1, 1)) - if getattr(function_stubs.array_object, name) is None: + if name not in [f.__name__ for f in stubs.array_methods]: assert hasattr(arr, name), f"The array object is missing the attribute {name}" else: assert hasattr(arr, name), f"The array object is missing the method {name}()" @@ -195,14 +192,12 @@ def test_function_positional_args(name): _mod = ones((), dtype=float64) else: _mod = example_argument('self', name, dtype) - stub_func = getattr(function_stubs, name) elif '.' in name: extension_module_name, name = name.split('.') _mod = getattr(mod, extension_module_name) - stub_func = getattr(getattr(function_stubs, extension_module_name), name) else: _mod = mod - stub_func = getattr(function_stubs, name) + stub_func = stubs.name_to_func[name] if not hasattr(_mod, name): pytest.skip(f"{mod_name} does not have {name}(), skipping.") @@ -248,14 +243,12 @@ def test_function_keyword_only_args(name): if array_method(name): _mod = ones((1, 1)) - stub_func = getattr(function_stubs, name) elif '.' in name: extension_module_name, name = name.split('.') _mod = getattr(mod, extension_module_name) - stub_func = getattr(getattr(function_stubs, extension_module_name), name) else: _mod = mod - stub_func = getattr(function_stubs, name) + stub_func = stubs.name_to_func[name] if not hasattr(_mod, name): pytest.skip(f"{mod_name} does not have {name}(), skipping.") diff --git a/array_api_tests/test_type_promotion.py b/array_api_tests/test_type_promotion.py index 575e9011..9bbaf930 100644 --- a/array_api_tests/test_type_promotion.py +++ b/array_api_tests/test_type_promotion.py @@ -13,7 +13,7 @@ from . import hypothesis_helpers as hh from . import pytest_helpers as ph from . import xps -from .function_stubs import elementwise_functions +from .stubs import category_to_funcs from .typing import DataType, Param, ScalarType bitwise_shift_funcs = [ @@ -52,7 +52,7 @@ def mark_stubbed_dtypes(*dtypes): func_params: List[Param[str, Tuple[DataType, ...], DataType]] = [] -for func_name in elementwise_functions.__all__: +for func_name in [f.__name__ for f in category_to_funcs["elementwise"]]: valid_in_dtypes = dh.func_in_dtypes[func_name] ndtypes = ph.nargs(func_name) if ndtypes == 1: diff --git a/generate_stubs.py b/generate_stubs.py deleted file mode 100755 index 3dcde542..00000000 --- a/generate_stubs.py +++ /dev/null @@ -1,943 +0,0 @@ -#!/usr/bin/env python -""" -Generate stub files for the tests. - -To run the script, first clone the https://github.com/data-apis/array-api -repo, then run - -./generate_stubs.py path/to/clone/of/array-api - -This will update the stub files in array_api_tests/function_stubs/ -""" -from __future__ import annotations - -import argparse -import os -import ast -import itertools -from collections import defaultdict -from typing import DefaultDict, Dict, List -from pathlib import Path - -import regex -from removestar.removestar import fix_code - -FUNCTION_HEADER_RE = regex.compile(r'\(function-(.*?)\)') -METHOD_HEADER_RE = regex.compile(r'\(method-(.*?)\)') -HEADER_RE = regex.compile(r'\((?:function-linalg|function|method|constant|attribute)-(.*?)\)') -FUNCTION_RE = regex.compile(r'\(function-.*\)=\n#+ ?(.*\(.*\))') -METHOD_RE = regex.compile(r'\(method-.*\)=\n#+ ?(.*\(.*\))') -CONSTANT_RE = regex.compile(r'\(constant-.*\)=\n#+ ?(.*)') -ATTRIBUTE_RE = regex.compile(r'\(attribute-.*\)=\n#+ ?(.*)') -IN_PLACE_OPERATOR_RE = regex.compile(r'- +`.*`. May be implemented via `__i(.*)__`.') -REFLECTED_OPERATOR_RE = regex.compile(r'- +`__r(.*)__`') -ALIAS_RE = regex.compile(r'Alias for {ref}`function-(.*)`.') - -OPS = [ - '__abs__', - '__add__', - '__and__', - '__bool__', - '__eq__', - '__float__', - '__floordiv__', - '__ge__', - '__getitem__', - '__gt__', - '__invert__', - '__le__', - '__lshift__', - '__lt__', - '__matmul__', - '__mod__', - '__mul__', - '__ne__', - '__neg__', - '__or__', - '__pos__', - '__pow__', - '__rshift__', - '__sub__', - '__truediv__', - '__xor__' -] -IOPS = [ - '__iadd__', - '__isub__', - '__imul__', - '__itruediv__', - '__ifloordiv__', - '__ipow__', - '__imod__', - '__imatmul__', - '__iand__', - '__ior__', - '__ixor__', - '__ilshift__', - '__irshift__' -] - -NAME_RE = regex.compile(r'(.*?)\(.*\)') - -STUB_FILE_HEADER = '''\ -""" -Function stubs for {title}. - -NOTE: This file is generated automatically by the generate_stubs.py script. Do -not modify it directly. - -See -https://github.com/data-apis/array-api/blob/master/spec/API_specification/{filename} -""" - -from __future__ import annotations - -from enum import * -from ._types import * -from .constants import * -from collections.abc import * -''' -# ^ Constants are used in some of the type annotations - -INIT_HEADER = '''\ -""" -Stub definitions for functions defined in the spec - -These are used to test function signatures. - -NOTE: This file is generated automatically by the generate_stubs.py script. Do -not modify it directly. -""" - -__all__ = [] -''' - -SPECIAL_CASES_HEADER = '''\ -""" -Special cases tests for {func}. - -These tests are generated from the special cases listed in the spec. - -NOTE: This file is generated automatically by the generate_stubs.py script. Do -not modify it directly. -""" - -from ..array_helpers import * -from ..hypothesis_helpers import numeric_arrays -from .._array_module import {func} - -from hypothesis import given - -''' - -OP_SPECIAL_CASES_HEADER = '''\ -""" -Special cases tests for {func}. - -These tests are generated from the special cases listed in the spec. - -NOTE: This file is generated automatically by the generate_stubs.py script. Do -not modify it directly. -""" - -from ..array_helpers import * -from ..hypothesis_helpers import numeric_arrays - -from hypothesis import given - -''' - - -IOP_SPECIAL_CASES_HEADER = '''\ -""" -Special cases tests for {func}. - -These tests are generated from the special cases listed in the spec. - -NOTE: This file is generated automatically by the generate_stubs.py script. Do -not modify it directly. -""" - -from operator import {operator} - -from ..array_helpers import * -from ..hypothesis_helpers import numeric_arrays - -from hypothesis import given - -''' - - -TYPES_HEADER = '''\ -""" -This file defines the types for type annotations. - -The type variables should be replaced with the actual types for a given -library, e.g., for NumPy TypeVar('array') would be replaced with ndarray. -""" - -from dataclasses import dataclass -from typing import Any, List, Literal, Optional, Sequence, Tuple, TypeVar, Union - -array = TypeVar('array') -device = TypeVar('device') -dtype = TypeVar('dtype') -SupportsDLPack = TypeVar('SupportsDLPack') -SupportsBufferProtocol = TypeVar('SupportsBufferProtocol') -PyCapsule = TypeVar('PyCapsule') -# ellipsis cannot actually be imported from anywhere, so include a dummy here -# to keep pyflakes happy. https://github.com/python/typeshed/issues/3556 -ellipsis = TypeVar('ellipsis') - -@dataclass -class finfo_object: - bits: int - eps: float - max: float - min: float - smallest_normal: float - -@dataclass -class iinfo_object: - bits: int - max: int - min: int - -# This should really be recursive, but that isn't supported yet. -NestedSequence = Sequence[Sequence[Any]] - -__all__ = ['Any', 'List', 'Literal', 'NestedSequence', 'Optional', -'PyCapsule', 'SupportsBufferProtocol', 'SupportsDLPack', 'Tuple', 'Union', -'array', 'device', 'dtype', 'ellipsis', 'finfo_object', 'iinfo_object'] - -''' -def main(): - parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) - parser.add_argument('array_api_repo', help="Path to clone of the array-api repository") - parser.add_argument('--no-write', help="Print what it would do but don't write any files", action='store_false', dest='write') - parser.add_argument('-v', '--verbose', help="Print verbose output to the terminal", action='store_true') - args = parser.parse_args() - - types_path = os.path.join('array_api_tests', 'function_stubs', '_types.py') - if args.write: - with open(types_path, 'w') as f: - f.write(TYPES_HEADER) - - special_cases_dir = Path('array_api_tests/special_cases') - special_cases_dir.mkdir(exist_ok=True) - (special_cases_dir / '__init__.py').touch() - - spec_dir = os.path.join(args.array_api_repo, 'spec', 'API_specification') - extensions_dir = os.path.join(args.array_api_repo, 'spec', 'extensions') - files = sorted([os.path.join(spec_dir, f) for f in os.listdir(spec_dir)] - + [os.path.join(extensions_dir, f) for f in os.listdir(extensions_dir)]) - modules = {} - all_annotations = {} - for file in files: - filename = os.path.basename(file) - with open(file) as f: - text = f.read() - functions = FUNCTION_RE.findall(text) - methods = METHOD_RE.findall(text) - constants = CONSTANT_RE.findall(text) - attributes = ATTRIBUTE_RE.findall(text) - if not (functions or methods or constants or attributes): - continue - if args.verbose: - print(f"Found signatures in {filename}") - - title = filename.replace('.md', '').replace('_', ' ') - if 'extensions' in file: - if filename == 'index.md': - continue - elif filename != 'linear_algebra_functions.md': - raise RuntimeError(f"Don't know how to handle extension file {filename}") - py_file = 'linalg.py' - title += " (Extension)" - else: - py_file = filename.replace('.md', '.py') - py_path = os.path.join('array_api_tests', 'function_stubs', py_file) - module_name = py_file.replace('.py', '') - modules[module_name] = [] - if args.verbose: - print(f"Writing {py_path}") - - annotations = parse_annotations(text, all_annotations, verbose=args.verbose) - all_annotations.update(annotations) - - if filename == 'array_object.md': - in_place_operators = IN_PLACE_OPERATOR_RE.findall(text) - reflected_operators = REFLECTED_OPERATOR_RE.findall(text) - if sorted(in_place_operators) != sorted(reflected_operators): - raise RuntimeError(f"Unexpected in-place or reflected operator(s): {set(in_place_operators).symmetric_difference(set(reflected_operators))}") - - sigs = {} - code = "" - code += STUB_FILE_HEADER.format(filename=filename, title=title) - for sig in itertools.chain(functions, methods): - ismethod = sig in methods - sig = sig.replace(r'\_', '_') - func_name = NAME_RE.match(sig).group(1) - if '.' in func_name: - mod, func_name = func_name.split('.', 2) - if mod != 'linalg': - raise RuntimeError(f"Unexpected namespace prefix {mod!r}") - sig = sig.replace(mod + '.', '') - doc = "" - if ismethod: - doc = f''' - """ - Note: {func_name} is a method of the array object. - """''' - if func_name not in annotations: - print(f"Warning: No annotations found for {func_name}") - annotated_sig = sig - else: - annotated_sig = add_annotation(sig, annotations[func_name]) - if args.verbose: - print(f"Writing stub for {annotated_sig}") - code += f""" -def {annotated_sig}:{doc} - pass -""" - modules[module_name].append(func_name) - sigs[func_name] = sig - - if (filename == 'array_object.md' and func_name.startswith('__') - and (op := func_name[2:-2]) in in_place_operators): - normal_op = func_name - iop = f'__i{op}__' - rop = f'__r{op}__' - for func_name in [iop, rop]: - methods.append(sigs[normal_op].replace(normal_op, func_name)) - annotation = annotations[normal_op].copy() - for k, v in annotation.items(): - annotation[k] = v.replace(normal_op, func_name) - annotations[func_name] = annotation - - for const in constants: - if args.verbose: - print(f"Writing stub for {const}") - code += f"\n{const} = None\n" - modules[module_name].append(const) - - for attr in attributes: - annotation = annotations[attr]['return'] - code += f"\n# Note: {attr} is an attribute of the array object." - code += f"\n{attr}: {annotation} = None\n" - modules[module_name].append(attr) - - code += '\n__all__ = [' - code += ', '.join(f"'{i}'" for i in modules[module_name]) - code += ']\n' - - if args.write: - with open(py_path, 'w') as f: - f.write(code) - code = fix_code(code, file=py_path, verbose=False, quiet=False) - if args.write: - with open(py_path, 'w') as f: - f.write(code) - - if filename == 'elementwise_functions.md': - special_cases = parse_special_cases(text, verbose=args.verbose) - for func in special_cases: - py_path = os.path.join('array_api_tests', 'special_cases', f'test_{func}.py') - tests = make_special_case_tests(func, special_cases, sigs) - if tests: - code = SPECIAL_CASES_HEADER.format(func=func) + '\n'.join(tests) - # quiet=False will make it print a warning if a name is not found (indicating an error) - code = fix_code(code, file=py_path, verbose=False, quiet=False) - if args.write: - with open(py_path, 'w') as f: - f.write(code) - elif filename == 'array_object.md': - op_special_cases = parse_special_cases(text, verbose=args.verbose) - for func in op_special_cases: - py_path = os.path.join('array_api_tests', 'special_cases', f'test_dunder_{func[2:-2]}.py') - tests = make_special_case_tests(func, op_special_cases, sigs) - if tests: - code = OP_SPECIAL_CASES_HEADER.format(func=func) + '\n'.join(tests) - code = fix_code(code, file=py_path, verbose=False, quiet=False) - if args.write: - with open(py_path, 'w') as f: - f.write(code) - iop_special_cases = {} - for name in IN_PLACE_OPERATOR_RE.findall(text): - op = f"__{name}__" - iop = f"__i{name}__" - iop_special_cases[iop] = op_special_cases[op] - for func in iop_special_cases: - py_path = os.path.join('array_api_tests', 'special_cases', f'test_dunder_{func[2:-2]}.py') - tests = make_special_case_tests(func, iop_special_cases, sigs) - if tests: - code = IOP_SPECIAL_CASES_HEADER.format(func=func, operator=func[2:-2]) + '\n'.join(tests) - code = fix_code(code, file=py_path, verbose=False, quiet=False) - if args.write: - with open(py_path, 'w') as f: - f.write(code) - - init_path = os.path.join('array_api_tests', 'function_stubs', '__init__.py') - if args.write: - with open(init_path, 'w') as f: - f.write(INIT_HEADER) - for module_name in modules: - if module_name == 'linalg': - f.write(f'\nfrom . import {module_name}\n') - f.write("\n__all__ += ['linalg']\n") - continue - f.write(f"\nfrom .{module_name} import ") - f.write(', '.join(modules[module_name])) - f.write('\n\n') - f.write('__all__ += [') - f.write(', '.join(f"'{i}'" for i in modules[module_name])) - f.write(']\n') - -# (?|...) is a branch reset (regex module only feature). It works like (?:...) -# except only the matched alternative is assigned group numbers, so \1, \2, and -# so on will always refer to a single match from _value. -_value = r"(?|`([^`]*)`|a (finite) number|a (positive \(i\.e\., greater than `0`\) finite) number|a (negative \(i\.e\., less than `0`\) finite) number|(finite)|(positive)|(negative)|(nonzero)|(?:a )?(nonzero finite) numbers?|an (integer) value|already (integer)-valued|an (odd integer) value|(even integer closest to `x_i`)|an implementation-dependent approximation to `([^`]*)`(?: \(rounded\))?|a (signed (?:infinity|zero)) with the mathematical sign determined by the rule already stated above|(positive mathematical sign)|(negative mathematical sign))" -SPECIAL_CASE_REGEXS = dict( - ONE_ARG_EQUAL = regex.compile(rf'^- +If `x_i` is {_value}, the result is {_value}\.$'), - ONE_ARG_GREATER = regex.compile(rf'^- +If `x_i` is greater than {_value}, the result is {_value}\.$'), - ONE_ARG_LESS = regex.compile(rf'^- +If `x_i` is less than {_value}, the result is {_value}\.$'), - ONE_ARG_EITHER = regex.compile(rf'^- +If `x_i` is either {_value} or {_value}, the result is {_value}\.$'), - ONE_ARG_TWO_INTEGERS_EQUALLY_CLOSE = regex.compile(rf'^- +If two integers are equally close to `x_i`, the result is the {_value}\.$'), - - TWO_ARGS_EQUAL__EQUAL = regex.compile(rf'^- +If `x1_i` is {_value} and `x2_i` is {_value}, the result is {_value}\.$'), - TWO_ARGS_GREATER__EQUAL = regex.compile(rf'^- +If `x1_i` is greater than {_value} and `x2_i` is {_value}, the result is {_value}\.$'), - TWO_ARGS_GREATER_EQUAL__EQUAL = regex.compile(rf'^- +If `x1_i` is greater than {_value}, `x1_i` is {_value}, and `x2_i` is {_value}, the result is {_value}\.$'), - TWO_ARGS_LESS__EQUAL = regex.compile(rf'^- +If `x1_i` is less than {_value} and `x2_i` is {_value}, the result is {_value}\.$'), - TWO_ARGS_LESS_EQUAL__EQUAL = regex.compile(rf'^- +If `x1_i` is less than {_value}, `x1_i` is {_value}, and `x2_i` is {_value}, the result is {_value}\.$'), - TWO_ARGS_LESS_EQUAL__EQUAL_NOTEQUAL = regex.compile(rf'^- +If `x1_i` is less than {_value}, `x1_i` is {_value}, `x2_i` is {_value}, and `x2_i` is not {_value}, the result is {_value}\.$'), - TWO_ARGS_EQUAL__GREATER = regex.compile(rf'^- +If `x1_i` is {_value} and `x2_i` is greater than {_value}, the result is {_value}\.$'), - TWO_ARGS_EQUAL__LESS = regex.compile(rf'^- +If `x1_i` is {_value} and `x2_i` is less than {_value}, the result is {_value}\.$'), - TWO_ARGS_EQUAL__NOTEQUAL = regex.compile(rf'^- +If `x1_i` is {_value} and `x2_i` is not (?:equal to )?{_value}, the result is {_value}\.$'), - TWO_ARGS_EQUAL__LESS_EQUAL = regex.compile(rf'^- +If `x1_i` is {_value}, `x2_i` is less than {_value}, and `x2_i` is {_value}, the result is {_value}\.$'), - TWO_ARGS_EQUAL__LESS_NOTEQUAL = regex.compile(rf'^- +If `x1_i` is {_value}, `x2_i` is less than {_value}, and `x2_i` is not {_value}, the result is {_value}\.$'), - TWO_ARGS_EQUAL__GREATER_EQUAL = regex.compile(rf'^- +If `x1_i` is {_value}, `x2_i` is greater than {_value}, and `x2_i` is {_value}, the result is {_value}\.$'), - TWO_ARGS_EQUAL__GREATER_NOTEQUAL = regex.compile(rf'^- +If `x1_i` is {_value}, `x2_i` is greater than {_value}, and `x2_i` is not {_value}, the result is {_value}\.$'), - TWO_ARGS_NOTEQUAL__EQUAL = regex.compile(rf'^- +If `x1_i` is not (?:equal to )?{_value} and `x2_i` is {_value}, the result is {_value}\.$'), - TWO_ARGS_ABSEQUAL__EQUAL = regex.compile(rf'^- +If `abs\(x1_i\)` is {_value} and `x2_i` is {_value}, the result is {_value}\.$'), - TWO_ARGS_ABSGREATER__EQUAL = regex.compile(rf'^- +If `abs\(x1_i\)` is greater than {_value} and `x2_i` is {_value}, the result is {_value}\.$'), - TWO_ARGS_ABSLESS__EQUAL = regex.compile(rf'^- +If `abs\(x1_i\)` is less than {_value} and `x2_i` is {_value}, the result is {_value}\.$'), - TWO_ARGS_EITHER = regex.compile(rf'^- +If either `x1_i` or `x2_i` is {_value}, the result is {_value}\.$'), - TWO_ARGS_EITHER__EQUAL = regex.compile(rf'^- +If `x1_i` is either {_value} or {_value} and `x2_i` is {_value}, the result is {_value}\.$'), - TWO_ARGS_EQUAL__EITHER = regex.compile(rf'^- +If `x1_i` is {_value} and `x2_i` is either {_value} or {_value}, the result is {_value}\.$'), - TWO_ARGS_EITHER__EITHER = regex.compile(rf'^- +If `x1_i` is either {_value} or {_value} and `x2_i` is either {_value} or {_value}, the result is {_value}\.$'), - TWO_ARGS_SAME_SIGN = regex.compile(rf'^- +If `x1_i` and `x2_i` have the same mathematical sign, the result has a {_value}\.$'), - TWO_ARGS_SAME_SIGN_EXCEPT = regex.compile(rf'^- +If `x1_i` and `x2_i` have the same mathematical sign, the result has a {_value}, unless the result is {_value}\. If the result is {_value}, the "sign" of {_value} is implementation-defined\.$'), - TWO_ARGS_SAME_SIGN_BOTH = regex.compile(rf'^- +If `x1_i` and `x2_i` have the same mathematical sign and are both {_value}, the result has a {_value}\.$'), - TWO_ARGS_DIFFERENT_SIGNS = regex.compile(rf'^- +If `x1_i` and `x2_i` have different mathematical signs, the result has a {_value}\.$'), - TWO_ARGS_DIFFERENT_SIGNS_EXCEPT = regex.compile(rf'^- +If `x1_i` and `x2_i` have different mathematical signs, the result has a {_value}, unless the result is {_value}\. If the result is {_value}, the "sign" of {_value} is implementation-defined\.$'), - TWO_ARGS_DIFFERENT_SIGNS_BOTH = regex.compile(rf'^- +If `x1_i` and `x2_i` have different mathematical signs and are both {_value}, the result has a {_value}\.$'), - TWO_ARGS_EVEN_IF = regex.compile(rf'^- +If `x2_i` is {_value}, the result is {_value}, even if `x1_i` is {_value}\.$'), - - REMAINING = regex.compile(r"^- +In the remaining cases, (.*)$"), -) - - -def parse_value(value, arg): - if value == 'NaN': - return f"NaN({arg}.shape, {arg}.dtype)" - elif value == "+infinity": - return f"infinity({arg}.shape, {arg}.dtype)" - elif value == "-infinity": - return f"-infinity({arg}.shape, {arg}.dtype)" - elif value in ["0", "+0"]: - return f"zero({arg}.shape, {arg}.dtype)" - elif value == "-0": - return f"-zero({arg}.shape, {arg}.dtype)" - elif value in ["1", "+1"]: - return f"one({arg}.shape, {arg}.dtype)" - elif value == "-1": - return f"-one({arg}.shape, {arg}.dtype)" - # elif value == 'signed infinity': - elif value == 'signed zero': - return f"zero({arg}.shape, {arg}.dtype))" - elif 'π' in value: - value = regex.sub(r'(\d+)π', r'\1*π', value) - return value.replace('π', f'π({arg}.shape, {arg}.dtype)') - elif 'x1_i' in value or 'x2_i' in value: - return value - elif value.startswith('where('): - return value - elif value in ['finite', 'nonzero', 'nonzero finite', - "integer", "odd integer", "positive", - "negative", "positive mathematical sign", - "negative mathematical sign"]: - return value - # There's no way to remove the parenthetical from the matching group in - # the regular expression. - elif value == "positive (i.e., greater than `0`) finite": - return "positive finite" - elif value == 'negative (i.e., less than `0`) finite': - return "negative finite" - else: - raise RuntimeError(f"Unexpected input value {value!r}") - -def _check_exactly_equal(typ, value): - if not typ == 'exactly_equal': - raise RuntimeError(f"Unexpected mask type {typ}: {value}") - -def get_mask(typ, arg, value): - if typ.startswith("not"): - if value.startswith('zero('): - return f"notequal({arg}, {value})" - return f"logical_not({get_mask(typ[len('not'):], arg, value)})" - if typ.startswith("abs"): - return get_mask(typ[len("abs"):], f"abs({arg})", value) - if value == 'finite': - _check_exactly_equal(typ, value) - return f"isfinite({arg})" - elif value == 'nonzero': - _check_exactly_equal(typ, value) - return f"non_zero({arg})" - elif value == 'positive finite': - _check_exactly_equal(typ, value) - return f"logical_and(isfinite({arg}), ispositive({arg}))" - elif value == 'negative finite': - _check_exactly_equal(typ, value) - return f"logical_and(isfinite({arg}), isnegative({arg}))" - elif value == 'nonzero finite': - _check_exactly_equal(typ, value) - return f"logical_and(isfinite({arg}), non_zero({arg}))" - elif value == 'positive': - _check_exactly_equal(typ, value) - return f"ispositive({arg})" - elif value == 'positive mathematical sign': - _check_exactly_equal(typ, value) - return f"positive_mathematical_sign({arg})" - elif value == 'negative': - _check_exactly_equal(typ, value) - return f"isnegative({arg})" - elif value == 'negative mathematical sign': - _check_exactly_equal(typ, value) - return f"negative_mathematical_sign({arg})" - elif value == 'integer': - _check_exactly_equal(typ, value) - return f"isintegral({arg})" - elif value == 'odd integer': - _check_exactly_equal(typ, value) - return f"isodd({arg})" - elif 'x_i' in value: - return f"{typ}({arg}, {value.replace('x_i', 'arg1')})" - elif 'x1_i' in value: - return f"{typ}({arg}, {value.replace('x1_i', 'arg1')})" - elif 'x2_i' in value: - return f"{typ}({arg}, {value.replace('x2_i', 'arg2')})" - return f"{typ}({arg}, {value})" - -def get_assert(typ, result): - # TODO: Refactor this so typ is actually what it should be - if result == "signed infinity": - _check_exactly_equal(typ, result) - return "assert_isinf(res[mask])" - elif result == "positive": - _check_exactly_equal(typ, result) - return "assert_positive(res[mask])" - elif result == "positive mathematical sign": - _check_exactly_equal(typ, result) - return "assert_positive_mathematical_sign(res[mask])" - elif result == "negative": - _check_exactly_equal(typ, result) - return "assert_negative(res[mask])" - elif result == "negative mathematical sign": - _check_exactly_equal(typ, result) - return "assert_negative_mathematical_sign(res[mask])" - elif result == 'even integer closest to `x_i`': - _check_exactly_equal(typ, result) - return "assert_iseven(res[mask])\n assert_positive(subtract(one(arg1[mask].shape, arg1[mask].dtype), abs(subtract(arg1[mask], res[mask]))))" - elif 'x_i' in result: - return f"assert_{typ}(res[mask], ({result.replace('x_i', 'arg1')})[mask])" - elif 'x1_i' in result: - return f"assert_{typ}(res[mask], ({result.replace('x1_i', 'arg1')})[mask])" - elif 'x2_i' in result: - return f"assert_{typ}(res[mask], ({result.replace('x2_i', 'arg2')})[mask])" - - # TODO: Get use something better than arg1 here for the arg - result = parse_value(result, "arg1") - try: - # This won't catch all unknown values, but will catch some. - ast.parse(result) - except SyntaxError: - raise RuntimeError(f"Unexpected result value {result!r} for {typ} (bad syntax)") - return f"assert_{typ}(res[mask], ({result})[mask])" - -ONE_ARG_TEMPLATE = """ -{decorator} -def test_{func}_special_cases_{test_name_extra}(arg1): - {doc} - res = {func}(arg1) - mask = {mask} - {assertion} -""" - -TWO_ARGS_TEMPLATE = """ -{decorator} -def test_{func}_special_cases_{test_name_extra}(arg1, arg2): - {doc} - res = {func}(arg1, arg2) - mask = {mask} - {assertion} -""" - -OP_ONE_ARG_TEMPLATE = """ -{decorator} -def test_{op}_special_cases_{test_name_extra}(arg1): - {doc} - res = (arg1).{func}() - mask = {mask} - {assertion} -""" - -OP_TWO_ARGS_TEMPLATE = """ -{decorator} -def test_{op}_special_cases_{test_name_extra}(arg1, arg2): - {doc} - res = arg1.{func}(arg2) - mask = {mask} - {assertion} -""" - -IOP_TWO_ARGS_TEMPLATE = """ -{decorator} -def test_{op}_special_cases_{test_name_extra}(arg1, arg2): - {doc} - res = asarray(arg1, copy=True) - {op}(res, arg2) - mask = {mask} - {assertion} -""" - -REMAINING_TEMPLATE = """# TODO: Implement REMAINING test for: -# {text} -""" - -def generate_special_case_test(func, typ, m, test_name_extra, sigs): - doc = f'''""" - Special case test for `{sigs[func]}`: - - {m.group(0)} - - """''' - if typ.startswith("ONE_ARG"): - decorator = "@given(numeric_arrays)" - if typ == "ONE_ARG_EQUAL": - value1, result = m.groups() - value1 = parse_value(value1, 'arg1') - mask = get_mask("exactly_equal", "arg1", value1) - elif typ == "ONE_ARG_GREATER": - value1, result = m.groups() - value1 = parse_value(value1, 'arg1') - mask = get_mask("greater", "arg1", value1) - elif typ == "ONE_ARG_LESS": - value1, result = m.groups() - value1 = parse_value(value1, 'arg1') - mask = get_mask("less", "arg1", value1) - elif typ == "ONE_ARG_EITHER": - value1, value2, result = m.groups() - value1 = parse_value(value1, 'arg1') - value2 = parse_value(value2, 'arg1') - mask1 = get_mask("exactly_equal", "arg1", value1) - mask2 = get_mask("exactly_equal", "arg1", value2) - mask = f"logical_or({mask1}, {mask2})" - elif typ == "ONE_ARG_ALREADY_INTEGER_VALUED": - result, = m.groups() - mask = parse_value("integer", "arg1") - elif typ == "ONE_ARG_TWO_INTEGERS_EQUALLY_CLOSE": - result, = m.groups() - mask = "logical_and(not_equal(floor(arg1), ceil(arg1)), equal(subtract(arg1, floor(arg1)), subtract(ceil(arg1), arg1)))" - else: - raise ValueError(f"Unrecognized special value type {typ}") - assertion = get_assert("exactly_equal", result) - if func in OPS: - return OP_ONE_ARG_TEMPLATE.format( - decorator=decorator, - func=func, - op=func[2:-2], - test_name_extra=test_name_extra, - doc=doc, - mask=mask, - assertion=assertion, - ) - else: - return ONE_ARG_TEMPLATE.format( - decorator=decorator, - func=func, - test_name_extra=test_name_extra, - doc=doc, - mask=mask, - assertion=assertion, - ) - - elif typ.startswith("TWO_ARGS"): - decorator = "@given(numeric_arrays, numeric_arrays)" - if typ in [ - "TWO_ARGS_EQUAL__EQUAL", - "TWO_ARGS_GREATER__EQUAL", - "TWO_ARGS_LESS__EQUAL", - "TWO_ARGS_EQUAL__GREATER", - "TWO_ARGS_EQUAL__LESS", - "TWO_ARGS_EQUAL__NOTEQUAL", - "TWO_ARGS_NOTEQUAL__EQUAL", - "TWO_ARGS_ABSEQUAL__EQUAL", - "TWO_ARGS_ABSGREATER__EQUAL", - "TWO_ARGS_ABSLESS__EQUAL", - "TWO_ARGS_GREATER_EQUAL__EQUAL", - "TWO_ARGS_LESS_EQUAL__EQUAL", - "TWO_ARGS_EQUAL__LESS_EQUAL", - "TWO_ARGS_EQUAL__LESS_NOTEQUAL", - "TWO_ARGS_EQUAL__GREATER_EQUAL", - "TWO_ARGS_EQUAL__GREATER_NOTEQUAL", - "TWO_ARGS_LESS_EQUAL__EQUAL_NOTEQUAL", - "TWO_ARGS_EITHER__EQUAL", - "TWO_ARGS_EQUAL__EITHER", - "TWO_ARGS_EITHER__EITHER", - ]: - arg1typs, arg2typs = [i.split('_') for i in typ[len("TWO_ARGS_"):].split("__")] - if arg1typs == ["EITHER"]: - arg1typs = ["EITHER_EQUAL", "EITHER_EQUAL"] - if arg2typs == ["EITHER"]: - arg2typs = ["EITHER_EQUAL", "EITHER_EQUAL"] - *values, result = m.groups() - if len(values) != len(arg1typs) + len(arg2typs): - raise RuntimeError(f"Unexpected number of parsed values for {typ}: len({values}) != len({arg1typs}) + len({arg2typs})") - arg1values, arg2values = values[:len(arg1typs)], values[len(arg1typs):] - arg1values = [parse_value(value, 'arg1') for value in arg1values] - arg2values = [parse_value(value, 'arg2') for value in arg2values] - - tomask = lambda t: t.lower().replace("either_equal", "equal").replace("equal", "exactly_equal") - value1masks = [get_mask(tomask(t), 'arg1', v) for t, v in - zip(arg1typs, arg1values)] - value2masks = [get_mask(tomask(t), 'arg2', v) for t, v in - zip(arg2typs, arg2values)] - if len(value1masks) > 1: - if arg1typs[0] == "EITHER_EQUAL": - mask1 = f"logical_or({value1masks[0]}, {value1masks[1]})" - else: - mask1 = f"logical_and({value1masks[0]}, {value1masks[1]})" - else: - mask1 = value1masks[0] - if len(value2masks) > 1: - if arg2typs[0] == "EITHER_EQUAL": - mask2 = f"logical_or({value2masks[0]}, {value2masks[1]})" - else: - mask2 = f"logical_and({value2masks[0]}, {value2masks[1]})" - else: - mask2 = value2masks[0] - - mask = f"logical_and({mask1}, {mask2})" - assertion = get_assert("exactly_equal", result) - - elif typ == "TWO_ARGS_EITHER": - value, result = m.groups() - value = parse_value(value, "arg1") - mask1 = get_mask("exactly_equal", "arg1", value) - mask2 = get_mask("exactly_equal", "arg2", value) - mask = f"logical_or({mask1}, {mask2})" - assertion = get_assert("exactly_equal", result) - elif typ == "TWO_ARGS_SAME_SIGN": - result, = m.groups() - mask = "same_sign(arg1, arg2)" - assertion = get_assert("exactly_equal", result) - elif typ == "TWO_ARGS_SAME_SIGN_EXCEPT": - result, value, value1, value2 = m.groups() - assert value == value1 == value2 - value = parse_value(value, "res") - mask = f"logical_and(same_sign(arg1, arg2), logical_not(exactly_equal(res, {value})))" - assertion = get_assert("exactly_equal", result) - elif typ == "TWO_ARGS_SAME_SIGN_BOTH": - value, result = m.groups() - mask1 = get_mask("exactly_equal", "arg1", value) - mask2 = get_mask("exactly_equal", "arg2", value) - mask = f"logical_and(same_sign(arg1, arg2), logical_and({mask1}, {mask2}))" - assertion = get_assert("exactly_equal", result) - elif typ == "TWO_ARGS_DIFFERENT_SIGNS": - result, = m.groups() - mask = "logical_not(same_sign(arg1, arg2))" - assertion = get_assert("exactly_equal", result) - elif typ == "TWO_ARGS_DIFFERENT_SIGNS_EXCEPT": - result, value, value1, value2 = m.groups() - assert value == value1 == value2 - value = parse_value(value, "res") - mask = f"logical_and(logical_not(same_sign(arg1, arg2)), logical_not(exactly_equal(res, {value})))" - assertion = get_assert("exactly_equal", result) - elif typ == "TWO_ARGS_DIFFERENT_SIGNS_BOTH": - value, result = m.groups() - mask1 = get_mask("exactly_equal", "arg1", value) - mask2 = get_mask("exactly_equal", "arg2", value) - mask = f"logical_and(logical_not(same_sign(arg1, arg2)), logical_and({mask1}, {mask2}))" - assertion = get_assert("exactly_equal", result) - elif typ == "TWO_ARGS_EVEN_IF": - value1, result, value2 = m.groups() - value1 = parse_value(value1, "arg2") - mask = get_mask("exactly_equal", "arg2", value1) - assertion = get_assert("exactly_equal", result) - else: - raise ValueError(f"Unrecognized special value type {typ}") - - if func in OPS: - return OP_TWO_ARGS_TEMPLATE.format( - decorator=decorator, - func=func, - op=func[2:-2], - test_name_extra=test_name_extra, - doc=doc, - mask=mask, - assertion=assertion, - ) - elif func in IOPS: - return IOP_TWO_ARGS_TEMPLATE.format( - decorator=decorator, - func=func, - op=func[2:-2], - test_name_extra=test_name_extra, - doc=doc, - mask=mask, - assertion=assertion, - ) - else: - return TWO_ARGS_TEMPLATE.format( - decorator=decorator, - func=func, - test_name_extra=test_name_extra, - doc=doc, - mask=mask, - assertion=assertion, - ) - - elif typ == "REMAINING": - return REMAINING_TEMPLATE.format(text=m.group(0)) - - else: - raise RuntimeError(f"Unexpected type {typ}") - -def parse_special_cases(spec_text, verbose=False) -> Dict[str, DefaultDict[str, List[regex.Match]]]: - special_cases = {} - in_block = False - name = None - for line in spec_text.splitlines(): - func_m = FUNCTION_HEADER_RE.match(line) - meth_m = METHOD_HEADER_RE.match(line) - if func_m or meth_m: - name = func_m.group(1) if func_m else meth_m.group(1) - special_cases[name] = defaultdict(list) - continue - if line == '#### Special Cases': - in_block = True - continue - elif line.startswith('#'): - in_block = False - continue - if in_block: - if '- ' not in line or name is None: - continue - for typ, reg in SPECIAL_CASE_REGEXS.items(): - m = reg.match(line) - if m: - if verbose: - print(f"Matched {typ} for {name}: {m.groups()}") - special_cases[name][typ].append(m) - break - else: - raise ValueError(f"Unrecognized special case string for '{name}':\n{line}") - - return special_cases - -def make_special_case_tests(func, special_cases: Dict[str, DefaultDict[str, List[regex.Match]]], sigs) -> List[str]: - tests = [] - for typ in special_cases[func]: - multiple = len(special_cases[func][typ]) > 1 - for i, m in enumerate(special_cases[func][typ], 1): - test_name_extra = typ.lower() - if multiple: - test_name_extra += f"_{i}" - test = generate_special_case_test(func, typ, m, test_name_extra, sigs) - assert test is not None # sanity check - tests.append(test) - return tests - - -PARAMETER_RE = regex.compile(r"- +\*\*(.*)\*\*: _(.*)_") -def parse_annotations(spec_text, all_annotations, verbose=False): - annotations = defaultdict(dict) - in_block = False - is_returns = False - for line in spec_text.splitlines(): - m = HEADER_RE.match(line) - if m: - name = m.group(1).replace('-', '_') - continue - m = ALIAS_RE.match(line) - if m: - alias_name = m.group(1).replace('-', '_') - if alias_name not in all_annotations: - print(f"Warning: No annotations for aliased function {name}") - else: - annotations[name] = all_annotations[alias_name] - continue - if line == '#### Parameters': - in_block = True - continue - elif line == '#### Returns': - in_block = True - is_returns = True - continue - elif line.startswith('#'): - in_block = False - continue - if in_block: - if not line.startswith('- '): - continue - m = PARAMETER_RE.match(line) - if m: - param, typ = m.groups() - if is_returns: - param = 'return' - is_returns = False - if name == '__setitem__': - # setitem returns None so it doesn't have a Returns - # section in the spec - annotations[name]['return'] = 'None' - typ = clean_type(typ) - if verbose: - print(f"Matched parameter for {name}: {param}: {typ}") - annotations[name][param] = typ - else: - raise ValueError(f"Unrecognized annotation for '{name}':\n{line}") - - return annotations - -def clean_type(typ): - typ = regex.sub(r'<(.*?)>', lambda m: m.group(1).replace(' ', '_'), typ) - typ = typ.replace('\\', '') - typ = typ.replace(' ', '') - typ = typ.replace(',', ', ') - typ = typ.replace('enum.', '') - return typ - -def add_annotation(sig, annotation): - if 'return' not in annotation: - raise RuntimeError(f"No return annotation for {sig}") - if 'out' in annotation: - raise RuntimeError(f"Error parsing annotations for {sig}") - for param, typ in annotation.items(): - if param == 'return': - sig = f"{sig} -> {typ}" - continue - PARAM_DEFAULT = regex.compile(rf"([\( ]{param})=") - sig2 = PARAM_DEFAULT.sub(rf'\1: {typ} = ', sig) - if sig2 != sig: - sig = sig2 - continue - PARAM = regex.compile(rf"([\( ]\*?{param})([,\)])") - sig2 = PARAM.sub(rf'\1: {typ}\2', sig) - if sig2 != sig: - sig = sig2 - continue - raise RuntimeError(f"Parameter {param} not found in {sig}") - return sig - -if __name__ == '__main__': - main() diff --git a/requirements.txt b/requirements.txt index 95a49cfa..fbc3fca3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,3 @@ pytest hypothesis>=6.31.1 ndindex>=1.6 -regex -removestar