Skip to content

Commit 9c76a8f

Browse files
Add test for dtype kwarg on xspace Ops
1 parent be6ed82 commit 9c76a8f

File tree

1 file changed

+13
-4
lines changed

1 file changed

+13
-4
lines changed

tests/tensor/test_extra_ops.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1282,6 +1282,7 @@ def test_broadcast_arrays():
12821282
["linspace", "logspace", "geomspace"],
12831283
ids=["linspace", "logspace", "geomspace"],
12841284
)
1285+
@pytest.mark.parametrize("dtype", [None, "int", "float"], ids=[None, "int", "float"])
12851286
@pytest.mark.parametrize(
12861287
"start, stop, num_samples, endpoint, axis",
12871288
[
@@ -1294,12 +1295,20 @@ def test_broadcast_arrays():
12941295
(1, np.array([5, 6]), 30, False, -1),
12951296
],
12961297
)
1297-
def test_space_ops(op, start, stop, num_samples, endpoint, axis):
1298+
def test_space_ops(op, dtype, start, stop, num_samples, endpoint, axis):
12981299
pt_func = getattr(pt, op)
12991300
np_func = getattr(np, op)
1300-
z = pt_func(start, stop, num_samples, endpoint=endpoint, axis=axis)
1301+
dtype = dtype + config.floatX[-2:] if dtype is not None else dtype
1302+
z = pt_func(start, stop, num_samples, endpoint=endpoint, axis=axis, dtype=dtype)
13011303

1302-
numpy_res = np_func(start, stop, num=num_samples, endpoint=endpoint, axis=axis)
1304+
numpy_res = np_func(
1305+
start, stop, num=num_samples, endpoint=endpoint, dtype=dtype, axis=axis
1306+
)
13031307
pytensor_res = function(inputs=[], outputs=z, mode="FAST_COMPILE")()
13041308

1305-
np.testing.assert_allclose(pytensor_res, numpy_res, atol=1e-6, rtol=1e-6)
1309+
np.testing.assert_allclose(
1310+
pytensor_res,
1311+
numpy_res,
1312+
atol=1e-6 if config.floatX.endswith("64") else 1e-4,
1313+
rtol=1e-6 if config.floatX.endswith("64") else 1e-4,
1314+
)

0 commit comments

Comments
 (0)