@@ -859,12 +859,13 @@ def __rshift__(self: Array, other: Union[int, Array], /) -> Array:
859
859
"""
860
860
Performs the operation __rshift__.
861
861
"""
862
+ other = self ._check_device (other )
862
863
other = self ._check_allowed_dtypes (other , "integer" , "__rshift__" )
863
864
if other is NotImplemented :
864
865
return other
865
866
self , other = self ._normalize_two_args (self , other )
866
867
res = self ._array .__rshift__ (other ._array )
867
- return self .__class__ ._new (res )
868
+ return self .__class__ ._new (res , device = self . device )
868
869
869
870
def __setitem__ (
870
871
self ,
@@ -889,41 +890,45 @@ def __sub__(self: Array, other: Union[int, float, Array], /) -> Array:
889
890
"""
890
891
Performs the operation __sub__.
891
892
"""
893
+ other = self ._check_device (other )
892
894
other = self ._check_allowed_dtypes (other , "numeric" , "__sub__" )
893
895
if other is NotImplemented :
894
896
return other
895
897
self , other = self ._normalize_two_args (self , other )
896
898
res = self ._array .__sub__ (other ._array )
897
- return self .__class__ ._new (res )
899
+ return self .__class__ ._new (res , device = self . device )
898
900
899
901
# PEP 484 requires int to be a subtype of float, but __truediv__ should
900
902
# not accept int.
901
903
def __truediv__ (self : Array , other : Union [float , Array ], / ) -> Array :
902
904
"""
903
905
Performs the operation __truediv__.
904
906
"""
907
+ other = self ._check_device (other )
905
908
other = self ._check_allowed_dtypes (other , "floating-point" , "__truediv__" )
906
909
if other is NotImplemented :
907
910
return other
908
911
self , other = self ._normalize_two_args (self , other )
909
912
res = self ._array .__truediv__ (other ._array )
910
- return self .__class__ ._new (res )
913
+ return self .__class__ ._new (res , device = self . device )
911
914
912
915
def __xor__ (self : Array , other : Union [int , bool , Array ], / ) -> Array :
913
916
"""
914
917
Performs the operation __xor__.
915
918
"""
919
+ other = self ._check_device (other )
916
920
other = self ._check_allowed_dtypes (other , "integer or boolean" , "__xor__" )
917
921
if other is NotImplemented :
918
922
return other
919
923
self , other = self ._normalize_two_args (self , other )
920
924
res = self ._array .__xor__ (other ._array )
921
- return self .__class__ ._new (res )
925
+ return self .__class__ ._new (res , device = self . device )
922
926
923
927
def __iadd__ (self : Array , other : Union [int , float , Array ], / ) -> Array :
924
928
"""
925
929
Performs the operation __iadd__.
926
930
"""
931
+ other = self ._check_device (other )
927
932
other = self ._check_allowed_dtypes (other , "numeric" , "__iadd__" )
928
933
if other is NotImplemented :
929
934
return other
@@ -934,17 +939,19 @@ def __radd__(self: Array, other: Union[int, float, Array], /) -> Array:
934
939
"""
935
940
Performs the operation __radd__.
936
941
"""
942
+ other = self ._check_device (other )
937
943
other = self ._check_allowed_dtypes (other , "numeric" , "__radd__" )
938
944
if other is NotImplemented :
939
945
return other
940
946
self , other = self ._normalize_two_args (self , other )
941
947
res = self ._array .__radd__ (other ._array )
942
- return self .__class__ ._new (res )
948
+ return self .__class__ ._new (res , device = self . device )
943
949
944
950
def __iand__ (self : Array , other : Union [int , bool , Array ], / ) -> Array :
945
951
"""
946
952
Performs the operation __iand__.
947
953
"""
954
+ other = self ._check_device (other )
948
955
other = self ._check_allowed_dtypes (other , "integer or boolean" , "__iand__" )
949
956
if other is NotImplemented :
950
957
return other
@@ -955,17 +962,19 @@ def __rand__(self: Array, other: Union[int, bool, Array], /) -> Array:
955
962
"""
956
963
Performs the operation __rand__.
957
964
"""
965
+ other = self ._check_device (other )
958
966
other = self ._check_allowed_dtypes (other , "integer or boolean" , "__rand__" )
959
967
if other is NotImplemented :
960
968
return other
961
969
self , other = self ._normalize_two_args (self , other )
962
970
res = self ._array .__rand__ (other ._array )
963
- return self .__class__ ._new (res )
971
+ return self .__class__ ._new (res , device = self . device )
964
972
965
973
def __ifloordiv__ (self : Array , other : Union [int , float , Array ], / ) -> Array :
966
974
"""
967
975
Performs the operation __ifloordiv__.
968
976
"""
977
+ other = self ._check_device (other )
969
978
other = self ._check_allowed_dtypes (other , "real numeric" , "__ifloordiv__" )
970
979
if other is NotImplemented :
971
980
return other
@@ -976,17 +985,19 @@ def __rfloordiv__(self: Array, other: Union[int, float, Array], /) -> Array:
976
985
"""
977
986
Performs the operation __rfloordiv__.
978
987
"""
988
+ other = self ._check_device (other )
979
989
other = self ._check_allowed_dtypes (other , "real numeric" , "__rfloordiv__" )
980
990
if other is NotImplemented :
981
991
return other
982
992
self , other = self ._normalize_two_args (self , other )
983
993
res = self ._array .__rfloordiv__ (other ._array )
984
- return self .__class__ ._new (res )
994
+ return self .__class__ ._new (res , device = self . device )
985
995
986
996
def __ilshift__ (self : Array , other : Union [int , Array ], / ) -> Array :
987
997
"""
988
998
Performs the operation __ilshift__.
989
999
"""
1000
+ other = self ._check_device (other )
990
1001
other = self ._check_allowed_dtypes (other , "integer" , "__ilshift__" )
991
1002
if other is NotImplemented :
992
1003
return other
@@ -997,17 +1008,19 @@ def __rlshift__(self: Array, other: Union[int, Array], /) -> Array:
997
1008
"""
998
1009
Performs the operation __rlshift__.
999
1010
"""
1011
+ other = self ._check_device (other )
1000
1012
other = self ._check_allowed_dtypes (other , "integer" , "__rlshift__" )
1001
1013
if other is NotImplemented :
1002
1014
return other
1003
1015
self , other = self ._normalize_two_args (self , other )
1004
1016
res = self ._array .__rlshift__ (other ._array )
1005
- return self .__class__ ._new (res )
1017
+ return self .__class__ ._new (res , device = self . device )
1006
1018
1007
1019
def __imatmul__ (self : Array , other : Array , / ) -> Array :
1008
1020
"""
1009
1021
Performs the operation __imatmul__.
1010
1022
"""
1023
+ other = self ._check_device (other )
1011
1024
# matmul is not defined for scalars, but without this, we may get
1012
1025
# the wrong error message from asarray.
1013
1026
other = self ._check_allowed_dtypes (other , "numeric" , "__imatmul__" )
0 commit comments