Skip to content

Commit a96c497

Browse files
committed
More FFT multi-device
1 parent cca1785 commit a96c497

File tree

2 files changed

+23
-10
lines changed

2 files changed

+23
-10
lines changed

array_api_strict/_array_object.py

+21-8
Original file line numberDiff line numberDiff line change
@@ -859,12 +859,13 @@ def __rshift__(self: Array, other: Union[int, Array], /) -> Array:
859859
"""
860860
Performs the operation __rshift__.
861861
"""
862+
other = self._check_device(other)
862863
other = self._check_allowed_dtypes(other, "integer", "__rshift__")
863864
if other is NotImplemented:
864865
return other
865866
self, other = self._normalize_two_args(self, other)
866867
res = self._array.__rshift__(other._array)
867-
return self.__class__._new(res)
868+
return self.__class__._new(res, device=self.device)
868869

869870
def __setitem__(
870871
self,
@@ -889,41 +890,45 @@ def __sub__(self: Array, other: Union[int, float, Array], /) -> Array:
889890
"""
890891
Performs the operation __sub__.
891892
"""
893+
other = self._check_device(other)
892894
other = self._check_allowed_dtypes(other, "numeric", "__sub__")
893895
if other is NotImplemented:
894896
return other
895897
self, other = self._normalize_two_args(self, other)
896898
res = self._array.__sub__(other._array)
897-
return self.__class__._new(res)
899+
return self.__class__._new(res, device=self.device)
898900

899901
# PEP 484 requires int to be a subtype of float, but __truediv__ should
900902
# not accept int.
901903
def __truediv__(self: Array, other: Union[float, Array], /) -> Array:
902904
"""
903905
Performs the operation __truediv__.
904906
"""
907+
other = self._check_device(other)
905908
other = self._check_allowed_dtypes(other, "floating-point", "__truediv__")
906909
if other is NotImplemented:
907910
return other
908911
self, other = self._normalize_two_args(self, other)
909912
res = self._array.__truediv__(other._array)
910-
return self.__class__._new(res)
913+
return self.__class__._new(res, device=self.device)
911914

912915
def __xor__(self: Array, other: Union[int, bool, Array], /) -> Array:
913916
"""
914917
Performs the operation __xor__.
915918
"""
919+
other = self._check_device(other)
916920
other = self._check_allowed_dtypes(other, "integer or boolean", "__xor__")
917921
if other is NotImplemented:
918922
return other
919923
self, other = self._normalize_two_args(self, other)
920924
res = self._array.__xor__(other._array)
921-
return self.__class__._new(res)
925+
return self.__class__._new(res, device=self.device)
922926

923927
def __iadd__(self: Array, other: Union[int, float, Array], /) -> Array:
924928
"""
925929
Performs the operation __iadd__.
926930
"""
931+
other = self._check_device(other)
927932
other = self._check_allowed_dtypes(other, "numeric", "__iadd__")
928933
if other is NotImplemented:
929934
return other
@@ -934,17 +939,19 @@ def __radd__(self: Array, other: Union[int, float, Array], /) -> Array:
934939
"""
935940
Performs the operation __radd__.
936941
"""
942+
other = self._check_device(other)
937943
other = self._check_allowed_dtypes(other, "numeric", "__radd__")
938944
if other is NotImplemented:
939945
return other
940946
self, other = self._normalize_two_args(self, other)
941947
res = self._array.__radd__(other._array)
942-
return self.__class__._new(res)
948+
return self.__class__._new(res, device=self.device)
943949

944950
def __iand__(self: Array, other: Union[int, bool, Array], /) -> Array:
945951
"""
946952
Performs the operation __iand__.
947953
"""
954+
other = self._check_device(other)
948955
other = self._check_allowed_dtypes(other, "integer or boolean", "__iand__")
949956
if other is NotImplemented:
950957
return other
@@ -955,17 +962,19 @@ def __rand__(self: Array, other: Union[int, bool, Array], /) -> Array:
955962
"""
956963
Performs the operation __rand__.
957964
"""
965+
other = self._check_device(other)
958966
other = self._check_allowed_dtypes(other, "integer or boolean", "__rand__")
959967
if other is NotImplemented:
960968
return other
961969
self, other = self._normalize_two_args(self, other)
962970
res = self._array.__rand__(other._array)
963-
return self.__class__._new(res)
971+
return self.__class__._new(res, device=self.device)
964972

965973
def __ifloordiv__(self: Array, other: Union[int, float, Array], /) -> Array:
966974
"""
967975
Performs the operation __ifloordiv__.
968976
"""
977+
other = self._check_device(other)
969978
other = self._check_allowed_dtypes(other, "real numeric", "__ifloordiv__")
970979
if other is NotImplemented:
971980
return other
@@ -976,17 +985,19 @@ def __rfloordiv__(self: Array, other: Union[int, float, Array], /) -> Array:
976985
"""
977986
Performs the operation __rfloordiv__.
978987
"""
988+
other = self._check_device(other)
979989
other = self._check_allowed_dtypes(other, "real numeric", "__rfloordiv__")
980990
if other is NotImplemented:
981991
return other
982992
self, other = self._normalize_two_args(self, other)
983993
res = self._array.__rfloordiv__(other._array)
984-
return self.__class__._new(res)
994+
return self.__class__._new(res, device=self.device)
985995

986996
def __ilshift__(self: Array, other: Union[int, Array], /) -> Array:
987997
"""
988998
Performs the operation __ilshift__.
989999
"""
1000+
other = self._check_device(other)
9901001
other = self._check_allowed_dtypes(other, "integer", "__ilshift__")
9911002
if other is NotImplemented:
9921003
return other
@@ -997,17 +1008,19 @@ def __rlshift__(self: Array, other: Union[int, Array], /) -> Array:
9971008
"""
9981009
Performs the operation __rlshift__.
9991010
"""
1011+
other = self._check_device(other)
10001012
other = self._check_allowed_dtypes(other, "integer", "__rlshift__")
10011013
if other is NotImplemented:
10021014
return other
10031015
self, other = self._normalize_two_args(self, other)
10041016
res = self._array.__rlshift__(other._array)
1005-
return self.__class__._new(res)
1017+
return self.__class__._new(res, device=self.device)
10061018

10071019
def __imatmul__(self: Array, other: Array, /) -> Array:
10081020
"""
10091021
Performs the operation __imatmul__.
10101022
"""
1023+
other = self._check_device(other)
10111024
# matmul is not defined for scalars, but without this, we may get
10121025
# the wrong error message from asarray.
10131026
other = self._check_allowed_dtypes(other, "numeric", "__imatmul__")

array_api_strict/_fft.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ def fftfreq(n: int, /, *, d: float = 1.0, device: Optional[Device] = None) -> Ar
257257
258258
See its docstring for more information.
259259
"""
260-
if device not in ALL_DEVICES:
260+
if device is not None and device not in ALL_DEVICES:
261261
raise ValueError(f"Unsupported device {device!r}")
262262
return Array._new(np.fft.fftfreq(n, d=d), device=device)
263263

@@ -268,7 +268,7 @@ def rfftfreq(n: int, /, *, d: float = 1.0, device: Optional[Device] = None) -> A
268268
269269
See its docstring for more information.
270270
"""
271-
if device not in ALL_DEVICES:
271+
if device is not None and device not in ALL_DEVICES:
272272
raise ValueError(f"Unsupported device {device!r}")
273273
return Array._new(np.fft.rfftfreq(n, d=d), device=device)
274274

0 commit comments

Comments
 (0)