|
3 | 3 | """
|
4 | 4 | import math
|
5 | 5 | import operator
|
| 6 | +from copy import copy |
6 | 7 | from enum import Enum, auto
|
7 | 8 | from typing import Callable, List, NamedTuple, Optional, Sequence, TypeVar, Union
|
8 | 9 |
|
@@ -103,8 +104,6 @@ def default_filter(s: Scalar) -> bool:
|
103 | 104 | """
|
104 | 105 | if isinstance(s, int): # note bools are ints
|
105 | 106 | return True
|
106 |
| - elif isinstance(s, complex): |
107 |
| - return default_filter(s.real) and default_filter(s.imag) |
108 | 107 | else:
|
109 | 108 | return math.isfinite(s) and s != 0
|
110 | 109 |
|
@@ -255,6 +254,9 @@ def unary_assert_against_refimpl(
|
255 | 254 | m, M = dh.dtype_ranges[dh.dtype_components[res.dtype]]
|
256 | 255 | else:
|
257 | 256 | m, M = dh.dtype_ranges[res.dtype]
|
| 257 | + if in_.dtype in dh.complex_dtypes: |
| 258 | + component_filter = copy(filter_) |
| 259 | + filter_ = lambda s: component_filter(s.real) and component_filter(s.imag) |
258 | 260 | for idx in sh.ndindex(in_.shape):
|
259 | 261 | scalar_i = in_stype(in_[idx])
|
260 | 262 | if not filter_(scalar_i):
|
@@ -313,6 +315,9 @@ def binary_assert_against_refimpl(
|
313 | 315 | if res_stype is None:
|
314 | 316 | res_stype = in_stype
|
315 | 317 | m, M = dh.dtype_ranges.get(res.dtype, (None, None))
|
| 318 | + if left.dtype in dh.complex_dtypes: |
| 319 | + component_filter = copy(filter_) |
| 320 | + filter_ = lambda s: component_filter(s.real) and component_filter(s.imag) |
316 | 321 | for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, res.shape):
|
317 | 322 | scalar_l = in_stype(left[l_idx])
|
318 | 323 | scalar_r = in_stype(right[r_idx])
|
|
0 commit comments