Skip to content

Commit 4ad8996

Browse files
committed
float_dtypes -> real_float_dtypes
Better reflect the specs naming convention and avoid confusion with complex
1 parent fc7c8b7 commit 4ad8996

8 files changed

+24
-24
lines changed

array_api_tests/dtype_helpers.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,15 @@
1313
__all__ = [
1414
"uint_names",
1515
"int_names",
16-
"float_names",
16+
"real_float_names",
1717
"real_names",
1818
"complex_names",
1919
"numeric_names",
2020
"dtype_names",
2121
"int_dtypes",
2222
"uint_dtypes",
2323
"all_int_dtypes",
24-
"float_dtypes",
24+
"real_float_dtypes",
2525
"real_dtypes",
2626
"numeric_dtypes",
2727
"all_dtypes",
@@ -96,8 +96,8 @@ def __repr__(self):
9696

9797
uint_names = ("uint8", "uint16", "uint32", "uint64")
9898
int_names = ("int8", "int16", "int32", "int64")
99-
float_names = ("float32", "float64")
100-
real_names = uint_names + int_names + float_names
99+
real_float_names = ("float32", "float64")
100+
real_names = uint_names + int_names + real_float_names
101101
complex_names = ("complex64", "complex128")
102102
numeric_names = real_names + complex_names
103103
dtype_names = ("bool",) + numeric_names
@@ -126,15 +126,15 @@ def _make_dtype_tuple_from_names(names: List[str]) -> Tuple[DataType]:
126126

127127
uint_dtypes = _make_dtype_tuple_from_names(uint_names)
128128
int_dtypes = _make_dtype_tuple_from_names(int_names)
129-
float_dtypes = _make_dtype_tuple_from_names(float_names)
129+
real_float_dtypes = _make_dtype_tuple_from_names(real_float_names)
130130
all_int_dtypes = uint_dtypes + int_dtypes
131-
real_dtypes = all_int_dtypes + float_dtypes
131+
real_dtypes = all_int_dtypes + real_float_dtypes
132132
complex_dtypes = _make_dtype_tuple_from_names(complex_names)
133133
numeric_dtypes = real_dtypes
134134
if api_version > "2021.12":
135135
numeric_dtypes += complex_dtypes
136136
all_dtypes = (xp.bool,) + numeric_dtypes
137-
all_float_dtypes = float_dtypes
137+
all_float_dtypes = real_float_dtypes
138138
if api_version > "2021.12":
139139
all_float_dtypes += complex_dtypes
140140
bool_and_all_int_dtypes = (xp.bool,) + all_int_dtypes
@@ -145,7 +145,7 @@ def _make_dtype_tuple_from_names(names: List[str]) -> Tuple[DataType]:
145145
"signed integer": int_dtypes,
146146
"unsigned integer": uint_dtypes,
147147
"integral": all_int_dtypes,
148-
"real floating": float_dtypes,
148+
"real floating": real_float_dtypes,
149149
"complex floating": complex_dtypes,
150150
"numeric": numeric_dtypes,
151151
}
@@ -162,7 +162,7 @@ def is_float_dtype(dtype):
162162
# See https://github.com/numpy/numpy/issues/18434
163163
if dtype is None:
164164
return False
165-
valid_dtypes = float_dtypes
165+
valid_dtypes = real_float_dtypes
166166
if api_version > "2021.12":
167167
valid_dtypes += complex_dtypes
168168
return dtype in valid_dtypes
@@ -171,7 +171,7 @@ def is_float_dtype(dtype):
171171
def get_scalar_type(dtype: DataType) -> ScalarType:
172172
if dtype in all_int_dtypes:
173173
return int
174-
elif dtype in float_dtypes:
174+
elif dtype in real_float_dtypes:
175175
return float
176176
elif dtype in complex_dtypes:
177177
return complex
@@ -243,7 +243,7 @@ class MinMax(NamedTuple):
243243
if default_int not in int_dtypes:
244244
warn(f"inferred default int is {default_int!r}, which is not an int")
245245
default_float = xp.asarray(float()).dtype
246-
if default_float not in float_dtypes:
246+
if default_float not in real_float_dtypes:
247247
warn(f"inferred default float is {default_float!r}, which is not a float")
248248
if api_version > "2021.12":
249249
default_complex = xp.asarray(complex()).dtype
@@ -344,7 +344,7 @@ def result_type(*dtypes: DataType):
344344
category_to_dtypes = {
345345
"boolean": (xp.bool,),
346346
"integer": all_int_dtypes,
347-
"floating-point": float_dtypes,
347+
"floating-point": real_float_dtypes,
348348
"numeric": numeric_dtypes,
349349
"integer or boolean": bool_and_all_int_dtypes,
350350
}
@@ -358,7 +358,7 @@ def result_type(*dtypes: DataType):
358358
dtypes = category_to_dtypes[dtype_category]
359359
func_in_dtypes[name] = dtypes
360360
# See https://github.com/data-apis/array-api/pull/413
361-
func_in_dtypes["expm1"] = float_dtypes
361+
func_in_dtypes["expm1"] = real_float_dtypes
362362

363363

364364
func_returns_bool = {
@@ -498,7 +498,7 @@ def result_type(*dtypes: DataType):
498498
func_in_dtypes["__bool__"] = (xp.bool,)
499499
func_in_dtypes["__int__"] = all_int_dtypes
500500
func_in_dtypes["__index__"] = all_int_dtypes
501-
func_in_dtypes["__float__"] = float_dtypes
501+
func_in_dtypes["__float__"] = real_float_dtypes
502502
func_in_dtypes["from_dlpack"] = numeric_dtypes
503503
func_in_dtypes["__dlpack__"] = numeric_dtypes
504504

array_api_tests/hypothesis_helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
shared_dtypes = shared(dtypes, key="dtype")
4040
shared_floating_dtypes = shared(floating_dtypes, key="dtype")
4141

42-
_dtype_categories = [(xp.bool,), dh.uint_dtypes, dh.int_dtypes, dh.float_dtypes, dh.complex_dtypes]
42+
_dtype_categories = [(xp.bool,), dh.uint_dtypes, dh.int_dtypes, dh.real_float_dtypes, dh.complex_dtypes]
4343
_sorted_dtypes = [d for category in _dtype_categories for d in category]
4444

4545
def _dtypes_sorter(dtype_pair: Tuple[DataType, DataType]):

array_api_tests/meta/test_hypothesis_helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
UNDEFINED_DTYPES = any(isinstance(d, _UndefinedStub) for d in dh.all_dtypes)
1616
pytestmark = [pytest.mark.skipif(UNDEFINED_DTYPES, reason="undefined dtypes")]
1717

18-
@given(hh.mutually_promotable_dtypes(dtypes=dh.float_dtypes))
18+
@given(hh.mutually_promotable_dtypes(dtypes=dh.real_float_dtypes))
1919
def test_mutually_promotable_dtypes(pair):
2020
assert pair in (
2121
(xp.float32, xp.float32),

array_api_tests/pytest_helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -446,7 +446,7 @@ def assert_array_elements(
446446
dh.result_type(out.dtype, expected.dtype) # sanity check
447447
assert_shape(func_name, out_shape=out.shape, expected=expected.shape, kw=kw) # sanity check
448448
f_func = f"[{func_name}({fmt_kw(kw)})]"
449-
if out.dtype in dh.float_dtypes:
449+
if out.dtype in dh.real_float_dtypes:
450450
for idx in sh.ndindex(out.shape):
451451
at_out = out[idx]
452452
at_expected = expected[idx]

array_api_tests/test_array_object.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,7 @@ def make_param(method_name: str, dtype: DataType, stype: ScalarType) -> Param:
252252
[make_param("__bool__", xp.bool, bool)]
253253
+ [make_param("__int__", d, int) for d in dh.all_int_dtypes]
254254
+ [make_param("__index__", d, int) for d in dh.all_int_dtypes]
255-
+ [make_param("__float__", d, float) for d in dh.float_dtypes],
255+
+ [make_param("__float__", d, float) for d in dh.real_float_dtypes],
256256
)
257257
@given(data=st.data())
258258
def test_scalar_casting(method_name, dtype, stype, data):

array_api_tests/test_data_type_functions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def test_can_cast(_from, to, data):
123123
expected = to == xp.bool
124124
else:
125125
same_family = None
126-
for dtypes in [dh.all_int_dtypes, dh.float_dtypes, dh.complex_dtypes]:
126+
for dtypes in [dh.all_int_dtypes, dh.real_float_dtypes, dh.complex_dtypes]:
127127
if _from in dtypes:
128128
same_family = to in dtypes
129129
break
@@ -145,7 +145,7 @@ def make_dtype_id(dtype: DataType) -> str:
145145
return dh.dtype_to_name[dtype]
146146

147147

148-
@pytest.mark.parametrize("dtype", dh.float_dtypes, ids=make_dtype_id)
148+
@pytest.mark.parametrize("dtype", dh.real_float_dtypes, ids=make_dtype_id)
149149
def test_finfo(dtype):
150150
out = xp.finfo(dtype)
151151
f_func = f"[finfo({dh.dtype_to_name[dtype]})]"

array_api_tests/test_operators_and_elementwise_functions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -776,7 +776,7 @@ def test_atan(x):
776776
unary_assert_against_refimpl("atan", x, out, math.atan)
777777

778778

779-
@given(*hh.two_mutual_arrays(dh.float_dtypes))
779+
@given(*hh.two_mutual_arrays(dh.real_float_dtypes))
780780
def test_atan2(x1, x2):
781781
out = xp.atan2(x1, x2)
782782
ph.assert_dtype("atan2", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype)
@@ -1204,7 +1204,7 @@ def logaddexp(l: float, r: float) -> float:
12041204
return math.log(math.exp(l) + math.exp(r))
12051205

12061206

1207-
@given(*hh.two_mutual_arrays(dh.float_dtypes))
1207+
@given(*hh.two_mutual_arrays(dh.real_float_dtypes))
12081208
def test_logaddexp(x1, x2):
12091209
out = xp.logaddexp(x1, x2)
12101210
ph.assert_dtype("logaddexp", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype)

array_api_tests/test_special_cases.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1231,7 +1231,7 @@ def test_unary(func_name, func, case, x, data):
12311231

12321232

12331233
x1_strat, x2_strat = hh.two_mutual_arrays(
1234-
dtypes=dh.float_dtypes,
1234+
dtypes=dh.real_float_dtypes,
12351235
two_shapes=hh.mutually_broadcastable_shapes(2, min_side=1),
12361236
)
12371237

@@ -1277,7 +1277,7 @@ def test_binary(func_name, func, case, x1, x2, data):
12771277

12781278
@pytest.mark.parametrize("iop_name, iop, case", iop_params)
12791279
@given(
1280-
oneway_dtypes=hh.oneway_promotable_dtypes(dh.float_dtypes),
1280+
oneway_dtypes=hh.oneway_promotable_dtypes(dh.real_float_dtypes),
12811281
oneway_shapes=hh.oneway_broadcastable_shapes(),
12821282
data=st.data(),
12831283
)

0 commit comments

Comments
 (0)