Skip to content

Commit 7a1e48e

Browse files
committed
Define elwise filters only for component dtypes
1 parent 82e6312 commit 7a1e48e

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

array_api_tests/test_operators_and_elementwise_functions.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""
44
import math
55
import operator
6+
from copy import copy
67
from enum import Enum, auto
78
from typing import Callable, List, NamedTuple, Optional, Sequence, TypeVar, Union
89

@@ -103,8 +104,6 @@ def default_filter(s: Scalar) -> bool:
103104
"""
104105
if isinstance(s, int): # note bools are ints
105106
return True
106-
elif isinstance(s, complex):
107-
return default_filter(s.real) and default_filter(s.imag)
108107
else:
109108
return math.isfinite(s) and s != 0
110109

@@ -255,6 +254,9 @@ def unary_assert_against_refimpl(
255254
m, M = dh.dtype_ranges[dh.dtype_components[res.dtype]]
256255
else:
257256
m, M = dh.dtype_ranges[res.dtype]
257+
if in_.dtype in dh.complex_dtypes:
258+
component_filter = copy(filter_)
259+
filter_ = lambda s: component_filter(s.real) and component_filter(s.imag)
258260
for idx in sh.ndindex(in_.shape):
259261
scalar_i = in_stype(in_[idx])
260262
if not filter_(scalar_i):
@@ -313,6 +315,9 @@ def binary_assert_against_refimpl(
313315
if res_stype is None:
314316
res_stype = in_stype
315317
m, M = dh.dtype_ranges.get(res.dtype, (None, None))
318+
if left.dtype in dh.complex_dtypes:
319+
component_filter = copy(filter_)
320+
filter_ = lambda s: component_filter(s.real) and component_filter(s.imag)
316321
for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, res.shape):
317322
scalar_l = in_stype(left[l_idx])
318323
scalar_r = in_stype(right[r_idx])

0 commit comments

Comments
 (0)