@@ -156,7 +156,7 @@ def __mul__(self, other):
156
156
157
157
def __rmul__ (self , other ):
158
158
other_tensor = asarray (other ).get ()
159
- return asarray (self ._tensor .__mul__ (other_tensor ))
159
+ return asarray (self ._tensor .__rmul__ (other_tensor ))
160
160
161
161
def __truediv__ (self , other ):
162
162
other_tensor = asarray (other ).get ()
@@ -179,23 +179,13 @@ def squeeze(self, axis=None):
179
179
@axis_out_keepdims_wrapper
180
180
def argmax (self , axis = None , out = None , * , keepdims = NoValue ):
181
181
axis = _helpers .allow_only_single_axis (axis )
182
-
183
- if axis is None :
184
- tensor = torch .argmax (self ._tensor )
185
- else :
186
- tensor = torch .argmax (self ._tensor , axis )
187
-
182
+ tensor = torch .argmax (self ._tensor , axis )
188
183
return tensor
189
184
190
185
@axis_out_keepdims_wrapper
191
186
def argmin (self , axis = None , out = None , * , keepdims = NoValue ):
192
187
axis = _helpers .allow_only_single_axis (axis )
193
-
194
- if axis is None :
195
- tensor = torch .argmin (self ._tensor )
196
- else :
197
- tensor = torch .argmin (self ._tensor , axis )
198
-
188
+ tensor = torch .argmin (self ._tensor , axis )
199
189
return tensor
200
190
201
191
def reshape (self , * shape , order = 'C' ):
0 commit comments