Skip to content

Remove use of _UndefinedStub in dtype_helpers.py #183

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
245 changes: 140 additions & 105 deletions array_api_tests/dtype_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,23 @@
from collections import defaultdict
from collections.abc import Mapping
from functools import lru_cache
from typing import Any, DefaultDict, NamedTuple, Sequence, Tuple, Union
from typing import Any, DefaultDict, Dict, List, NamedTuple, Sequence, Tuple, Union
from warnings import warn

from . import _array_module as xp
from . import api_version
from ._array_module import _UndefinedStub
from ._array_module import mod as _xp
from ._array_module import mod as xp
from .stubs import name_to_func
from .typing import DataType, ScalarType

__all__ = [
"uint_names",
"int_names",
"all_int_names",
"float_names",
"real_names",
"complex_names",
"numeric_names",
"dtype_names",
"int_dtypes",
"uint_dtypes",
"all_int_dtypes",
Expand Down Expand Up @@ -90,27 +96,43 @@ def __repr__(self):
return f"EqualityMapping({self})"


def _filter_stubs(*args):
for a in args:
if not isinstance(a, _UndefinedStub):
yield a
uint_names = ("uint8", "uint16", "uint32", "uint64")
int_names = ("int8", "int16", "int32", "int64")
all_int_names = uint_names + int_names
float_names = ("float32", "float64")
real_names = uint_names + int_names + float_names
complex_names = ("complex64", "complex128")
numeric_names = real_names + complex_names
dtype_names = ("bool",) + numeric_names


_uint_names = ("uint8", "uint16", "uint32", "uint64")
_int_names = ("int8", "int16", "int32", "int64")
_float_names = ("float32", "float64")
_real_names = _uint_names + _int_names + _float_names
_complex_names = ("complex64", "complex128")
_numeric_names = _real_names + _complex_names
_dtype_names = ("bool",) + _numeric_names
_name_to_dtype = {}
for name in dtype_names:
try:
dtype = getattr(xp, name)
except AttributeError:
continue
_name_to_dtype[name] = dtype
dtype_to_name = EqualityMapping([(d, n) for n, d in _name_to_dtype.items()])


uint_dtypes = tuple(getattr(xp, name) for name in _uint_names)
int_dtypes = tuple(getattr(xp, name) for name in _int_names)
float_dtypes = tuple(getattr(xp, name) for name in _float_names)
def _make_dtype_tuple_from_names(names: List[str]) -> Tuple[DataType]:
dtypes = []
for name in names:
try:
dtype = _name_to_dtype[name]
except KeyError:
continue
dtypes.append(dtype)
return tuple(dtypes)


uint_dtypes = _make_dtype_tuple_from_names(uint_names)
int_dtypes = _make_dtype_tuple_from_names(int_names)
float_dtypes = _make_dtype_tuple_from_names(float_names)
all_int_dtypes = uint_dtypes + int_dtypes
real_dtypes = all_int_dtypes + float_dtypes
complex_dtypes = tuple(getattr(xp, name) for name in _complex_names)
complex_dtypes = _make_dtype_tuple_from_names(complex_names)
numeric_dtypes = real_dtypes
if api_version > "2021.12":
numeric_dtypes += complex_dtypes
Expand All @@ -121,16 +143,6 @@ def _filter_stubs(*args):
bool_and_all_int_dtypes = (xp.bool,) + all_int_dtypes


_dtype_name_pairs = []
for name in _dtype_names:
try:
dtype = getattr(_xp, name)
except AttributeError:
continue
_dtype_name_pairs.append((dtype, name))
dtype_to_name = EqualityMapping(_dtype_name_pairs)


dtype_to_scalars = EqualityMapping(
[
(xp.bool, [bool]),
Expand Down Expand Up @@ -179,47 +191,59 @@ def get_scalar_type(dtype: DataType) -> ScalarType:
return bool


def _make_dtype_mapping_from_names(mapping: Dict[str, Any]) -> EqualityMapping:
dtype_value_pairs = []
for name, value in mapping.items():
assert isinstance(name, str) and name in dtype_names # sanity check
try:
dtype = getattr(xp, name)
except AttributeError:
continue
dtype_value_pairs.append((dtype, value))
return EqualityMapping(dtype_value_pairs)


class MinMax(NamedTuple):
min: Union[int, float]
max: Union[int, float]


dtype_ranges = EqualityMapping(
[
(xp.int8, MinMax(-128, +127)),
(xp.int16, MinMax(-32_768, +32_767)),
(xp.int32, MinMax(-2_147_483_648, +2_147_483_647)),
(xp.int64, MinMax(-9_223_372_036_854_775_808, +9_223_372_036_854_775_807)),
(xp.uint8, MinMax(0, +255)),
(xp.uint16, MinMax(0, +65_535)),
(xp.uint32, MinMax(0, +4_294_967_295)),
(xp.uint64, MinMax(0, +18_446_744_073_709_551_615)),
(xp.float32, MinMax(-3.4028234663852886e38, 3.4028234663852886e38)),
(xp.float64, MinMax(-1.7976931348623157e308, 1.7976931348623157e308)),
]
dtype_ranges = _make_dtype_mapping_from_names(
{
"int8": MinMax(-128, +127),
"int16": MinMax(-32_768, +32_767),
"int32": MinMax(-2_147_483_648, +2_147_483_647),
"int64": MinMax(-9_223_372_036_854_775_808, +9_223_372_036_854_775_807),
"uint8": MinMax(0, +255),
"uint16": MinMax(0, +65_535),
"uint32": MinMax(0, +4_294_967_295),
"uint64": MinMax(0, +18_446_744_073_709_551_615),
"float32": MinMax(-3.4028234663852886e38, 3.4028234663852886e38),
"float64": MinMax(-1.7976931348623157e308, 1.7976931348623157e308),
}
)


dtype_nbits = EqualityMapping(
[(d, 8) for d in _filter_stubs(xp.int8, xp.uint8)]
+ [(d, 16) for d in _filter_stubs(xp.int16, xp.uint16)]
+ [(d, 32) for d in _filter_stubs(xp.int32, xp.uint32, xp.float32)]
+ [(d, 64) for d in _filter_stubs(xp.int64, xp.uint64, xp.float64, xp.complex64)]
+ [(d, 128) for d in _filter_stubs(xp.complex128)]
)
r_nbits = re.compile(r"[a-z]+([0-9]+)")
_dtype_nbits: Dict[str, int] = {}
for name in numeric_names:
m = r_nbits.fullmatch(name)
assert m is not None # sanity check / for mypy
_dtype_nbits[name] = int(m.group(1))
dtype_nbits = _make_dtype_mapping_from_names(_dtype_nbits)


dtype_signed = EqualityMapping(
[(d, True) for d in int_dtypes] + [(d, False) for d in uint_dtypes]
dtype_signed = _make_dtype_mapping_from_names(
{**{name: True for name in int_names}, **{name: False for name in uint_names}}
)


dtype_components = EqualityMapping(
[(xp.complex64, xp.float32), (xp.complex128, xp.float64)]
dtype_components = _make_dtype_mapping_from_names(
{"complex64": xp.float32, "complex128": xp.float64}
)


if isinstance(xp.asarray, _UndefinedStub):
if not hasattr(xp, "asarray"):
default_int = xp.int32
default_float = xp.float32
warn(
Expand All @@ -243,60 +267,73 @@ class MinMax(NamedTuple):
else:
default_complex = None
if dtype_nbits[default_int] == 32:
default_uint = xp.uint32
default_uint = getattr(xp, "uint32", None)
else:
default_uint = xp.uint64

default_uint = getattr(xp, "uint64", None)

_numeric_promotions = [
_promotion_table: Dict[Tuple[str, str], str] = {
("bool", "bool"): "bool",
# ints
((xp.int8, xp.int8), xp.int8),
((xp.int8, xp.int16), xp.int16),
((xp.int8, xp.int32), xp.int32),
((xp.int8, xp.int64), xp.int64),
((xp.int16, xp.int16), xp.int16),
((xp.int16, xp.int32), xp.int32),
((xp.int16, xp.int64), xp.int64),
((xp.int32, xp.int32), xp.int32),
((xp.int32, xp.int64), xp.int64),
((xp.int64, xp.int64), xp.int64),
("int8", "int8"): "int8",
("int8", "int16"): "int16",
("int8", "int32"): "int32",
("int8", "int64"): "int64",
("int16", "int16"): "int16",
("int16", "int32"): "int32",
("int16", "int64"): "int64",
("int32", "int32"): "int32",
("int32", "int64"): "int64",
("int64", "int64"): "int64",
# uints
((xp.uint8, xp.uint8), xp.uint8),
((xp.uint8, xp.uint16), xp.uint16),
((xp.uint8, xp.uint32), xp.uint32),
((xp.uint8, xp.uint64), xp.uint64),
((xp.uint16, xp.uint16), xp.uint16),
((xp.uint16, xp.uint32), xp.uint32),
((xp.uint16, xp.uint64), xp.uint64),
((xp.uint32, xp.uint32), xp.uint32),
((xp.uint32, xp.uint64), xp.uint64),
((xp.uint64, xp.uint64), xp.uint64),
("uint8", "uint8"): "uint8",
("uint8", "uint16"): "uint16",
("uint8", "uint32"): "uint32",
("uint8", "uint64"): "uint64",
("uint16", "uint16"): "uint16",
("uint16", "uint32"): "uint32",
("uint16", "uint64"): "uint64",
("uint32", "uint32"): "uint32",
("uint32", "uint64"): "uint64",
("uint64", "uint64"): "uint64",
# ints and uints (mixed sign)
((xp.int8, xp.uint8), xp.int16),
((xp.int8, xp.uint16), xp.int32),
((xp.int8, xp.uint32), xp.int64),
((xp.int16, xp.uint8), xp.int16),
((xp.int16, xp.uint16), xp.int32),
((xp.int16, xp.uint32), xp.int64),
((xp.int32, xp.uint8), xp.int32),
((xp.int32, xp.uint16), xp.int32),
((xp.int32, xp.uint32), xp.int64),
((xp.int64, xp.uint8), xp.int64),
((xp.int64, xp.uint16), xp.int64),
((xp.int64, xp.uint32), xp.int64),
("int8", "uint8"): "int16",
("int8", "uint16"): "int32",
("int8", "uint32"): "int64",
("int16", "uint8"): "int16",
("int16", "uint16"): "int32",
("int16", "uint32"): "int64",
("int32", "uint8"): "int32",
("int32", "uint16"): "int32",
("int32", "uint32"): "int64",
("int64", "uint8"): "int64",
("int64", "uint16"): "int64",
("int64", "uint32"): "int64",
# floats
((xp.float32, xp.float32), xp.float32),
((xp.float32, xp.float64), xp.float64),
((xp.float64, xp.float64), xp.float64),
("float32", "float32"): "float32",
("float32", "float64"): "float64",
("float64", "float64"): "float64",
# complex
((xp.complex64, xp.complex64), xp.complex64),
((xp.complex64, xp.complex128), xp.complex128),
((xp.complex128, xp.complex128), xp.complex128),
]
_numeric_promotions += [((d2, d1), res) for (d1, d2), res in _numeric_promotions]
_promotion_table = list(set(_numeric_promotions))
_promotion_table.insert(0, ((xp.bool, xp.bool), xp.bool))
promotion_table = EqualityMapping(_promotion_table)
("complex64", "complex64"): "complex64",
("complex64", "complex128"): "complex128",
("complex128", "complex128"): "complex128",
}
_promotion_table.update({(d2, d1): res for (d1, d2), res in _promotion_table.items()})
_promotion_table_pairs: List[Tuple[Tuple[DataType, DataType], DataType]] = []
for (in_name1, in_name2), res_name in _promotion_table.items():
try:
in_dtype1 = getattr(xp, in_name1)
except AttributeError:
continue
try:
in_dtype2 = getattr(xp, in_name2)
except AttributeError:
continue
try:
res_dtype = getattr(xp, res_name)
except AttributeError:
continue
_promotion_table_pairs.append(((in_dtype1, in_dtype2), res_dtype))
promotion_table = EqualityMapping(_promotion_table_pairs)


def result_type(*dtypes: DataType):
Expand Down Expand Up @@ -325,6 +362,7 @@ def result_type(*dtypes: DataType):
}
func_in_dtypes: DefaultDict[str, Tuple[DataType, ...]] = defaultdict(lambda: all_dtypes)
for name, func in name_to_func.items():
assert func.__doc__ is not None # for mypy
if m := r_in_dtypes.search(func.__doc__):
dtype_category = m.group(1)
if dtype_category == "numeric" and r_int_note.search(func.__doc__):
Expand Down Expand Up @@ -457,11 +495,10 @@ def result_type(*dtypes: DataType):
}


# Construct func_in_dtypes and func_returns bool
for op, elwise_func in op_to_func.items():
func_in_dtypes[op] = func_in_dtypes[elwise_func]
func_returns_bool[op] = func_returns_bool[elwise_func]


inplace_op_to_symbol = {}
for op, symbol in binary_op_to_symbol.items():
if op == "__matmul__" or func_returns_bool[op]:
Expand All @@ -470,8 +507,6 @@ def result_type(*dtypes: DataType):
inplace_op_to_symbol[iop] = f"{symbol}="
func_in_dtypes[iop] = func_in_dtypes[op]
func_returns_bool[iop] = func_returns_bool[op]


func_in_dtypes["__bool__"] = (xp.bool,)
func_in_dtypes["__int__"] = all_int_dtypes
func_in_dtypes["__index__"] = all_int_dtypes
Expand Down
23 changes: 15 additions & 8 deletions array_api_tests/test_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from . import pytest_helpers as ph
from . import shape_helpers as sh
from . import xps
from ._array_module import mod as _xp
from .typing import DataType, Index, Param, Scalar, ScalarType, Shape

pytestmark = pytest.mark.ci
Expand Down Expand Up @@ -241,21 +242,27 @@ def test_setitem_masking(shape, data):
)


def make_param(method_name: str, dtype: DataType, stype: ScalarType) -> Param:
def make_scalar_casting_param(
method_name: str, dtype_name: DataType, stype: ScalarType
) -> Param:
return pytest.param(
method_name, dtype, stype, id=f"{method_name}({dh.dtype_to_name[dtype]})"
method_name, dtype_name, stype, id=f"{method_name}({dtype_name})"
)


@pytest.mark.parametrize(
"method_name, dtype, stype",
[make_param("__bool__", xp.bool, bool)]
+ [make_param("__int__", d, int) for d in dh._filter_stubs(*dh.all_int_dtypes)]
+ [make_param("__index__", d, int) for d in dh._filter_stubs(*dh.all_int_dtypes)]
+ [make_param("__float__", d, float) for d in dh.float_dtypes],
"method_name, dtype_name, stype",
[make_scalar_casting_param("__bool__", "bool", bool)]
+ [make_scalar_casting_param("__int__", n, int) for n in dh.all_int_names]
+ [make_scalar_casting_param("__index__", n, int) for n in dh.all_int_names]
+ [make_scalar_casting_param("__float__", n, float) for n in dh.float_names],
)
@given(data=st.data())
def test_scalar_casting(method_name, dtype, stype, data):
def test_scalar_casting(method_name, dtype_name, stype, data):
try:
dtype = getattr(_xp, dtype_name)
except AttributeError as e:
pytest.skip(str(e))
x = data.draw(xps.arrays(dtype, shape=()), label="x")
method = getattr(x, method_name)
out = method()
Expand Down
Loading