Skip to content

Commit fd4a4d9

Browse files
committed
TST: undo (some) test skips of test with scalars
1 parent 08cc8ed commit fd4a4d9

File tree

4 files changed

+11
-5
lines changed

4 files changed

+11
-5
lines changed

torch_np/_helpers.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,3 +125,9 @@ def allow_only_single_axis(axis):
125125
if len(axis) != 1:
126126
raise NotImplementedError("does not handle tuple axis")
127127
return axis[0]
128+
129+
130+
def to_tensors(*inputs):
131+
"""Convert all ndarrays from `inputs` to tensors."""
132+
return tuple([value.get() if isinstance(value, ndarray) else value
133+
for value in inputs])

torch_np/_wrapper.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,8 +193,9 @@ def arange(start=None, stop=None, step=1, dtype=None, *, like=None):
193193
if dtype is None:
194194
dtype = _dtypes.default_int_type()
195195
dtype = result_type(start, stop, step, dtype)
196-
197196
torch_dtype = _dtypes.torch_dtype_from(dtype)
197+
start, stop, step = _helpers.to_tensors(start, stop, step)
198+
198199
try:
199200
return asarray(torch.arange(start, stop, step, dtype=torch_dtype))
200201
except RuntimeError:

torch_np/tests/test_function_base.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,16 +96,15 @@ def test_arange_booleans(self):
9696
with pytest.raises(TypeError):
9797
np.arange(3, dtype="bool")
9898

99-
@pytest.mark.skip(reason='XXX: python scalars from array scalars')
10099
@pytest.mark.parametrize("which", [0, 1, 2])
101100
def test_error_paths_and_promotion(self, which):
102-
args = [0, 1, 2] # start, stop, and step
101+
args = [0, 10, 2] # start, stop, and step
103102
args[which] = np.float64(2.) # should ensure float64 output
104103
assert np.arange(*args).dtype == np.float64
105104

106105
# Cover stranger error path, test only to achieve code coverage!
107106
args[which] = [None, []]
108-
with pytest.raises(ValueError):
107+
with pytest.raises((ValueError, RuntimeError)):
109108
# Fails discovering start dtype
110109
np.arange(*args)
111110

torch_np/tests/test_shape_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -464,7 +464,7 @@ def test_stack():
464464

465465
# 0d input
466466
for input_ in [(1, 2, 3),
467-
### [np.int32(1), np.int32(2), np.int32(3)], # XXX: numpy scalars?
467+
[np.int32(1), np.int32(2), np.int32(3)],
468468
[np.array(1), np.array(2), np.array(3)]]:
469469
assert_array_equal(stack(input_), [1, 2, 3])
470470
# 1d input examples

0 commit comments

Comments
 (0)