Skip to content

Commit bf0c944

Browse files
authored
TYP: Compact Python scalar types (#149)
* TYP: Compact Python scalar types * Update array_api_strict/_array_object.py
1 parent 17c7c40 commit bf0c944

8 files changed

+62
-75
lines changed

array_api_strict/_array_object.py

+40-40
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ def __array__(
191191
# NumPy behavior
192192

193193
def _check_allowed_dtypes(
194-
self, other: Array | bool | int | float | complex, dtype_category: str, op: str
194+
self, other: Array | complex, dtype_category: str, op: str
195195
) -> Array:
196196
"""
197197
Helper function for operators to only allow specific input dtypes
@@ -233,7 +233,7 @@ def _check_allowed_dtypes(
233233

234234
return other
235235

236-
def _check_type_device(self, other: Array | bool | int | float | complex) -> None:
236+
def _check_type_device(self, other: Array | complex) -> None:
237237
"""Check that other is either a Python scalar or an array on a device
238238
compatible with the current array.
239239
"""
@@ -245,7 +245,7 @@ def _check_type_device(self, other: Array | bool | int | float | complex) -> Non
245245
raise TypeError(f"Expected Array or Python scalar; got {type(other)}")
246246

247247
# Helper function to match the type promotion rules in the spec
248-
def _promote_scalar(self, scalar: bool | int | float | complex) -> Array:
248+
def _promote_scalar(self, scalar: complex) -> Array:
249249
"""
250250
Returns a promoted version of a Python scalar appropriate for use with
251251
operations on self.
@@ -539,7 +539,7 @@ def __abs__(self) -> Array:
539539
res = self._array.__abs__()
540540
return self.__class__._new(res, device=self.device)
541541

542-
def __add__(self, other: Array | int | float | complex, /) -> Array:
542+
def __add__(self, other: Array | complex, /) -> Array:
543543
"""
544544
Performs the operation __add__.
545545
"""
@@ -551,7 +551,7 @@ def __add__(self, other: Array | int | float | complex, /) -> Array:
551551
res = self._array.__add__(other._array)
552552
return self.__class__._new(res, device=self.device)
553553

554-
def __and__(self, other: Array | bool | int, /) -> Array:
554+
def __and__(self, other: Array | int, /) -> Array:
555555
"""
556556
Performs the operation __and__.
557557
"""
@@ -648,7 +648,7 @@ def __dlpack_device__(self) -> tuple[IntEnum, int]:
648648
# Note: device support is required for this
649649
return self._array.__dlpack_device__()
650650

651-
def __eq__(self, other: Array | bool | int | float | complex, /) -> Array: # type: ignore[override]
651+
def __eq__(self, other: Array | complex, /) -> Array: # type: ignore[override]
652652
"""
653653
Performs the operation __eq__.
654654
"""
@@ -674,7 +674,7 @@ def __float__(self) -> float:
674674
res = self._array.__float__()
675675
return res
676676

677-
def __floordiv__(self, other: Array | int | float, /) -> Array:
677+
def __floordiv__(self, other: Array | float, /) -> Array:
678678
"""
679679
Performs the operation __floordiv__.
680680
"""
@@ -686,7 +686,7 @@ def __floordiv__(self, other: Array | int | float, /) -> Array:
686686
res = self._array.__floordiv__(other._array)
687687
return self.__class__._new(res, device=self.device)
688688

689-
def __ge__(self, other: Array | int | float, /) -> Array:
689+
def __ge__(self, other: Array | float, /) -> Array:
690690
"""
691691
Performs the operation __ge__.
692692
"""
@@ -738,7 +738,7 @@ def __getitem__(
738738
res = self._array.__getitem__(np_key)
739739
return self._new(res, device=self.device)
740740

741-
def __gt__(self, other: Array | int | float, /) -> Array:
741+
def __gt__(self, other: Array | float, /) -> Array:
742742
"""
743743
Performs the operation __gt__.
744744
"""
@@ -793,7 +793,7 @@ def __iter__(self) -> Iterator[Array]:
793793
# implemented, which implies iteration on 1-D arrays.
794794
return (Array._new(i, device=self.device) for i in self._array)
795795

796-
def __le__(self, other: Array | int | float, /) -> Array:
796+
def __le__(self, other: Array | float, /) -> Array:
797797
"""
798798
Performs the operation __le__.
799799
"""
@@ -817,7 +817,7 @@ def __lshift__(self, other: Array | int, /) -> Array:
817817
res = self._array.__lshift__(other._array)
818818
return self.__class__._new(res, device=self.device)
819819

820-
def __lt__(self, other: Array | int | float, /) -> Array:
820+
def __lt__(self, other: Array | float, /) -> Array:
821821
"""
822822
Performs the operation __lt__.
823823
"""
@@ -842,7 +842,7 @@ def __matmul__(self, other: Array, /) -> Array:
842842
res = self._array.__matmul__(other._array)
843843
return self.__class__._new(res, device=self.device)
844844

845-
def __mod__(self, other: Array | int | float, /) -> Array:
845+
def __mod__(self, other: Array | float, /) -> Array:
846846
"""
847847
Performs the operation __mod__.
848848
"""
@@ -854,7 +854,7 @@ def __mod__(self, other: Array | int | float, /) -> Array:
854854
res = self._array.__mod__(other._array)
855855
return self.__class__._new(res, device=self.device)
856856

857-
def __mul__(self, other: Array | int | float | complex, /) -> Array:
857+
def __mul__(self, other: Array | complex, /) -> Array:
858858
"""
859859
Performs the operation __mul__.
860860
"""
@@ -866,7 +866,7 @@ def __mul__(self, other: Array | int | float | complex, /) -> Array:
866866
res = self._array.__mul__(other._array)
867867
return self.__class__._new(res, device=self.device)
868868

869-
def __ne__(self, other: Array | bool | int | float | complex, /) -> Array: # type: ignore[override]
869+
def __ne__(self, other: Array | complex, /) -> Array: # type: ignore[override]
870870
"""
871871
Performs the operation __ne__.
872872
"""
@@ -887,7 +887,7 @@ def __neg__(self) -> Array:
887887
res = self._array.__neg__()
888888
return self.__class__._new(res, device=self.device)
889889

890-
def __or__(self, other: Array | bool | int, /) -> Array:
890+
def __or__(self, other: Array | int, /) -> Array:
891891
"""
892892
Performs the operation __or__.
893893
"""
@@ -908,7 +908,7 @@ def __pos__(self) -> Array:
908908
res = self._array.__pos__()
909909
return self.__class__._new(res, device=self.device)
910910

911-
def __pow__(self, other: Array | int | float | complex, /) -> Array:
911+
def __pow__(self, other: Array | complex, /) -> Array:
912912
"""
913913
Performs the operation __pow__.
914914
"""
@@ -945,7 +945,7 @@ def __setitem__(
945945
| Array
946946
| tuple[int | slice | EllipsisType, ...]
947947
),
948-
value: Array | bool | int | float | complex,
948+
value: Array | complex,
949949
/,
950950
) -> None:
951951
"""
@@ -958,7 +958,7 @@ def __setitem__(
958958
np_key = key._array if isinstance(key, Array) else key
959959
self._array.__setitem__(np_key, asarray(value)._array)
960960

961-
def __sub__(self, other: Array | int | float | complex, /) -> Array:
961+
def __sub__(self, other: Array | complex, /) -> Array:
962962
"""
963963
Performs the operation __sub__.
964964
"""
@@ -972,7 +972,7 @@ def __sub__(self, other: Array | int | float | complex, /) -> Array:
972972

973973
# PEP 484 requires int to be a subtype of float, but __truediv__ should
974974
# not accept int.
975-
def __truediv__(self, other: Array | int | float | complex, /) -> Array:
975+
def __truediv__(self, other: Array | complex, /) -> Array:
976976
"""
977977
Performs the operation __truediv__.
978978
"""
@@ -984,7 +984,7 @@ def __truediv__(self, other: Array | int | float | complex, /) -> Array:
984984
res = self._array.__truediv__(other._array)
985985
return self.__class__._new(res, device=self.device)
986986

987-
def __xor__(self, other: Array | bool | int, /) -> Array:
987+
def __xor__(self, other: Array | int, /) -> Array:
988988
"""
989989
Performs the operation __xor__.
990990
"""
@@ -996,7 +996,7 @@ def __xor__(self, other: Array | bool | int, /) -> Array:
996996
res = self._array.__xor__(other._array)
997997
return self.__class__._new(res, device=self.device)
998998

999-
def __iadd__(self, other: Array | int | float | complex, /) -> Array:
999+
def __iadd__(self, other: Array | complex, /) -> Array:
10001000
"""
10011001
Performs the operation __iadd__.
10021002
"""
@@ -1007,7 +1007,7 @@ def __iadd__(self, other: Array | int | float | complex, /) -> Array:
10071007
self._array.__iadd__(other._array)
10081008
return self
10091009

1010-
def __radd__(self, other: Array | int | float | complex, /) -> Array:
1010+
def __radd__(self, other: Array | complex, /) -> Array:
10111011
"""
10121012
Performs the operation __radd__.
10131013
"""
@@ -1019,7 +1019,7 @@ def __radd__(self, other: Array | int | float | complex, /) -> Array:
10191019
res = self._array.__radd__(other._array)
10201020
return self.__class__._new(res, device=self.device)
10211021

1022-
def __iand__(self, other: Array | bool | int, /) -> Array:
1022+
def __iand__(self, other: Array | int, /) -> Array:
10231023
"""
10241024
Performs the operation __iand__.
10251025
"""
@@ -1030,7 +1030,7 @@ def __iand__(self, other: Array | bool | int, /) -> Array:
10301030
self._array.__iand__(other._array)
10311031
return self
10321032

1033-
def __rand__(self, other: Array | bool | int, /) -> Array:
1033+
def __rand__(self, other: Array | int, /) -> Array:
10341034
"""
10351035
Performs the operation __rand__.
10361036
"""
@@ -1042,7 +1042,7 @@ def __rand__(self, other: Array | bool | int, /) -> Array:
10421042
res = self._array.__rand__(other._array)
10431043
return self.__class__._new(res, device=self.device)
10441044

1045-
def __ifloordiv__(self, other: Array | int | float, /) -> Array:
1045+
def __ifloordiv__(self, other: Array | float, /) -> Array:
10461046
"""
10471047
Performs the operation __ifloordiv__.
10481048
"""
@@ -1053,7 +1053,7 @@ def __ifloordiv__(self, other: Array | int | float, /) -> Array:
10531053
self._array.__ifloordiv__(other._array)
10541054
return self
10551055

1056-
def __rfloordiv__(self, other: Array | int | float, /) -> Array:
1056+
def __rfloordiv__(self, other: Array | float, /) -> Array:
10571057
"""
10581058
Performs the operation __rfloordiv__.
10591059
"""
@@ -1114,7 +1114,7 @@ def __rmatmul__(self, other: Array, /) -> Array:
11141114
res = self._array.__rmatmul__(other._array)
11151115
return self.__class__._new(res, device=self.device)
11161116

1117-
def __imod__(self, other: Array | int | float, /) -> Array:
1117+
def __imod__(self, other: Array | float, /) -> Array:
11181118
"""
11191119
Performs the operation __imod__.
11201120
"""
@@ -1124,7 +1124,7 @@ def __imod__(self, other: Array | int | float, /) -> Array:
11241124
self._array.__imod__(other._array)
11251125
return self
11261126

1127-
def __rmod__(self, other: Array | int | float, /) -> Array:
1127+
def __rmod__(self, other: Array | float, /) -> Array:
11281128
"""
11291129
Performs the operation __rmod__.
11301130
"""
@@ -1136,7 +1136,7 @@ def __rmod__(self, other: Array | int | float, /) -> Array:
11361136
res = self._array.__rmod__(other._array)
11371137
return self.__class__._new(res, device=self.device)
11381138

1139-
def __imul__(self, other: Array | int | float | complex, /) -> Array:
1139+
def __imul__(self, other: Array | complex, /) -> Array:
11401140
"""
11411141
Performs the operation __imul__.
11421142
"""
@@ -1146,7 +1146,7 @@ def __imul__(self, other: Array | int | float | complex, /) -> Array:
11461146
self._array.__imul__(other._array)
11471147
return self
11481148

1149-
def __rmul__(self, other: Array | int | float | complex, /) -> Array:
1149+
def __rmul__(self, other: Array | complex, /) -> Array:
11501150
"""
11511151
Performs the operation __rmul__.
11521152
"""
@@ -1158,7 +1158,7 @@ def __rmul__(self, other: Array | int | float | complex, /) -> Array:
11581158
res = self._array.__rmul__(other._array)
11591159
return self.__class__._new(res, device=self.device)
11601160

1161-
def __ior__(self, other: Array | bool | int, /) -> Array:
1161+
def __ior__(self, other: Array | int, /) -> Array:
11621162
"""
11631163
Performs the operation __ior__.
11641164
"""
@@ -1168,7 +1168,7 @@ def __ior__(self, other: Array | bool | int, /) -> Array:
11681168
self._array.__ior__(other._array)
11691169
return self
11701170

1171-
def __ror__(self, other: Array | bool | int, /) -> Array:
1171+
def __ror__(self, other: Array | int, /) -> Array:
11721172
"""
11731173
Performs the operation __ror__.
11741174
"""
@@ -1180,7 +1180,7 @@ def __ror__(self, other: Array | bool | int, /) -> Array:
11801180
res = self._array.__ror__(other._array)
11811181
return self.__class__._new(res, device=self.device)
11821182

1183-
def __ipow__(self, other: Array | int | float | complex, /) -> Array:
1183+
def __ipow__(self, other: Array | complex, /) -> Array:
11841184
"""
11851185
Performs the operation __ipow__.
11861186
"""
@@ -1190,7 +1190,7 @@ def __ipow__(self, other: Array | int | float | complex, /) -> Array:
11901190
self._array.__ipow__(other._array)
11911191
return self
11921192

1193-
def __rpow__(self, other: Array | int | float | complex, /) -> Array:
1193+
def __rpow__(self, other: Array | complex, /) -> Array:
11941194
"""
11951195
Performs the operation __rpow__.
11961196
"""
@@ -1225,7 +1225,7 @@ def __rrshift__(self, other: Array | int, /) -> Array:
12251225
res = self._array.__rrshift__(other._array)
12261226
return self.__class__._new(res, device=self.device)
12271227

1228-
def __isub__(self, other: Array | int | float | complex, /) -> Array:
1228+
def __isub__(self, other: Array | complex, /) -> Array:
12291229
"""
12301230
Performs the operation __isub__.
12311231
"""
@@ -1235,7 +1235,7 @@ def __isub__(self, other: Array | int | float | complex, /) -> Array:
12351235
self._array.__isub__(other._array)
12361236
return self
12371237

1238-
def __rsub__(self, other: Array | int | float | complex, /) -> Array:
1238+
def __rsub__(self, other: Array | complex, /) -> Array:
12391239
"""
12401240
Performs the operation __rsub__.
12411241
"""
@@ -1247,7 +1247,7 @@ def __rsub__(self, other: Array | int | float | complex, /) -> Array:
12471247
res = self._array.__rsub__(other._array)
12481248
return self.__class__._new(res, device=self.device)
12491249

1250-
def __itruediv__(self, other: Array | int | float | complex, /) -> Array:
1250+
def __itruediv__(self, other: Array | complex, /) -> Array:
12511251
"""
12521252
Performs the operation __itruediv__.
12531253
"""
@@ -1257,7 +1257,7 @@ def __itruediv__(self, other: Array | int | float | complex, /) -> Array:
12571257
self._array.__itruediv__(other._array)
12581258
return self
12591259

1260-
def __rtruediv__(self, other: Array | int | float | complex, /) -> Array:
1260+
def __rtruediv__(self, other: Array | complex, /) -> Array:
12611261
"""
12621262
Performs the operation __rtruediv__.
12631263
"""
@@ -1269,7 +1269,7 @@ def __rtruediv__(self, other: Array | int | float | complex, /) -> Array:
12691269
res = self._array.__rtruediv__(other._array)
12701270
return self.__class__._new(res, device=self.device)
12711271

1272-
def __ixor__(self, other: Array | bool | int, /) -> Array:
1272+
def __ixor__(self, other: Array | int, /) -> Array:
12731273
"""
12741274
Performs the operation __ixor__.
12751275
"""
@@ -1279,7 +1279,7 @@ def __ixor__(self, other: Array | bool | int, /) -> Array:
12791279
self._array.__ixor__(other._array)
12801280
return self
12811281

1282-
def __rxor__(self, other: Array | bool | int, /) -> Array:
1282+
def __rxor__(self, other: Array | int, /) -> Array:
12831283
"""
12841284
Performs the operation __rxor__.
12851285
"""

0 commit comments

Comments
 (0)