Skip to content

Commit d92b127

Browse files
committed
MAINT: delegate richcompe to ufuncs
1 parent cc5b586 commit d92b127

File tree

1 file changed

+13
-9
lines changed

1 file changed

+13
-9
lines changed

torch_np/_ndarray.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -119,28 +119,31 @@ def __str__(self):
119119
### comparisons ###
120120
def __eq__(self, other):
121121
try:
122-
t_other = asarray(other).get
122+
return _ufunc_impl.equal(self, asarray(other))
123123
except RuntimeError:
124124
# Failed to convert other to array: definitely not equal.
125-
# TODO: generalize, delegate to ufuncs
126125
falsy = torch.full(self.shape, fill_value=False, dtype=bool)
127126
return asarray(falsy)
128-
return asarray(self._tensor == asarray(other).get())
129127

130128
def __neq__(self, other):
131-
return asarray(self._tensor != asarray(other).get())
129+
try:
130+
return _ufunc_impl.not_equal(self, asarray(other))
131+
except RuntimeError:
132+
# Failed to convert other to array: definitely not equal.
133+
falsy = torch.full(self.shape, fill_value=True, dtype=bool)
134+
return asarray(falsy)
132135

133136
def __gt__(self, other):
134-
return asarray(self._tensor > asarray(other).get())
137+
return _ufunc_impl.greater(self, asarray(other))
135138

136139
def __lt__(self, other):
137-
return asarray(self._tensor < asarray(other).get())
140+
return _ufunc_impl.less(self, asarray(other))
138141

139142
def __ge__(self, other):
140-
return asarray(self._tensor >= asarray(other).get())
143+
return _ufunc_impl.greater_equal(self, asarray(other))
141144

142145
def __le__(self, other):
143-
return asarray(self._tensor <= asarray(other).get())
146+
return _ufunc_impl.less_equal(self, asarray(other))
144147

145148
def __bool__(self):
146149
try:
@@ -270,6 +273,7 @@ def __rand__(self, other):
270273
def __iand__(self, other):
271274
return _ufunc_impl.bitwise_and(self, asarray(other), out=self)
272275

276+
273277
# or, self | other
274278
def __or__(self, other):
275279
return _ufunc_impl.bitwise_or(self, asarray(other))
@@ -280,6 +284,7 @@ def __ror__(self, other):
280284
def __ior__(self, other):
281285
return _ufunc_impl.bitwise_or(self, asarray(other), out=self)
282286

287+
283288
# xor, self ^ other
284289
def __xor__(self, other):
285290
return _ufunc_impl.bitwise_xor(self, asarray(other))
@@ -305,7 +310,6 @@ def __neg__(self):
305310
return _ufunc_impl.negative(self)
306311

307312

308-
309313
### methods to match namespace functions
310314

311315
def squeeze(self, axis=None):

0 commit comments

Comments
 (0)