Skip to content

Commit 6d67d46

Browse files
committed
BUG: reject ndarrays in binary operators
1 parent d086c61 commit 6d67d46

File tree

2 files changed

+19
-0
lines changed

2 files changed

+19
-0
lines changed

array_api_strict/_array_object.py

+11
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,8 @@ def _check_device(self, other):
234234
elif isinstance(other, Array):
235235
if self.device != other.device:
236236
raise ValueError(f"Arrays from two different devices ({self.device} and {other.device}) can not be combined.")
237+
else:
238+
raise TypeError(f"Cannot combine an Array with {type(other)}.")
237239

238240
# Helper function to match the type promotion rules in the spec
239241
def _promote_scalar(self, scalar):
@@ -1066,6 +1068,7 @@ def __imod__(self: Array, other: Union[int, float, Array], /) -> Array:
10661068
"""
10671069
Performs the operation __imod__.
10681070
"""
1071+
self._check_device(other)
10691072
other = self._check_allowed_dtypes(other, "real numeric", "__imod__")
10701073
if other is NotImplemented:
10711074
return other
@@ -1088,6 +1091,7 @@ def __imul__(self: Array, other: Union[int, float, Array], /) -> Array:
10881091
"""
10891092
Performs the operation __imul__.
10901093
"""
1094+
self._check_device(other)
10911095
other = self._check_allowed_dtypes(other, "numeric", "__imul__")
10921096
if other is NotImplemented:
10931097
return other
@@ -1110,6 +1114,7 @@ def __ior__(self: Array, other: Union[int, bool, Array], /) -> Array:
11101114
"""
11111115
Performs the operation __ior__.
11121116
"""
1117+
self._check_device(other)
11131118
other = self._check_allowed_dtypes(other, "integer or boolean", "__ior__")
11141119
if other is NotImplemented:
11151120
return other
@@ -1132,6 +1137,7 @@ def __ipow__(self: Array, other: Union[int, float, Array], /) -> Array:
11321137
"""
11331138
Performs the operation __ipow__.
11341139
"""
1140+
self._check_device(other)
11351141
other = self._check_allowed_dtypes(other, "numeric", "__ipow__")
11361142
if other is NotImplemented:
11371143
return other
@@ -1144,6 +1150,7 @@ def __rpow__(self: Array, other: Union[int, float, Array], /) -> Array:
11441150
"""
11451151
from ._elementwise_functions import pow
11461152

1153+
self._check_device(other)
11471154
other = self._check_allowed_dtypes(other, "numeric", "__rpow__")
11481155
if other is NotImplemented:
11491156
return other
@@ -1155,6 +1162,7 @@ def __irshift__(self: Array, other: Union[int, Array], /) -> Array:
11551162
"""
11561163
Performs the operation __irshift__.
11571164
"""
1165+
self._check_device(other)
11581166
other = self._check_allowed_dtypes(other, "integer", "__irshift__")
11591167
if other is NotImplemented:
11601168
return other
@@ -1177,6 +1185,7 @@ def __isub__(self: Array, other: Union[int, float, Array], /) -> Array:
11771185
"""
11781186
Performs the operation __isub__.
11791187
"""
1188+
self._check_device(other)
11801189
other = self._check_allowed_dtypes(other, "numeric", "__isub__")
11811190
if other is NotImplemented:
11821191
return other
@@ -1199,6 +1208,7 @@ def __itruediv__(self: Array, other: Union[float, Array], /) -> Array:
11991208
"""
12001209
Performs the operation __itruediv__.
12011210
"""
1211+
self._check_device(other)
12021212
other = self._check_allowed_dtypes(other, "floating-point", "__itruediv__")
12031213
if other is NotImplemented:
12041214
return other
@@ -1221,6 +1231,7 @@ def __ixor__(self: Array, other: Union[int, bool, Array], /) -> Array:
12211231
"""
12221232
Performs the operation __ixor__.
12231233
"""
1234+
self._check_device(other)
12241235
other = self._check_allowed_dtypes(other, "integer or boolean", "__ixor__")
12251236
if other is NotImplemented:
12261237
return other

array_api_strict/tests/test_array_object.py

+8
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,14 @@ def _array_vals():
212212
else:
213213
assert_raises(TypeError, lambda: getattr(x, _op)(y))
214214

215+
# finally, test that array op ndarray raises
216+
# XXX: as long as there is __array__, __rop__s still
217+
# return ndarrays
218+
if not _op.startswith("__r"):
219+
with assert_raises(TypeError):
220+
getattr(x, _op)(y._array)
221+
222+
215223
unary_op_dtypes = {
216224
"__abs__": "numeric",
217225
"__invert__": "integer_or_boolean",

0 commit comments

Comments
 (0)