@@ -1282,6 +1282,7 @@ def test_broadcast_arrays():
1282
1282
["linspace" , "logspace" , "geomspace" ],
1283
1283
ids = ["linspace" , "logspace" , "geomspace" ],
1284
1284
)
1285
+ @pytest .mark .parametrize ("dtype" , [None , "int" , "float" ], ids = [None , "int" , "float" ])
1285
1286
@pytest .mark .parametrize (
1286
1287
"start, stop, num_samples, endpoint, axis" ,
1287
1288
[
@@ -1294,12 +1295,20 @@ def test_broadcast_arrays():
1294
1295
(1 , np .array ([5 , 6 ]), 30 , False , - 1 ),
1295
1296
],
1296
1297
)
1297
- def test_space_ops (op , start , stop , num_samples , endpoint , axis ):
1298
+ def test_space_ops (op , dtype , start , stop , num_samples , endpoint , axis ):
1298
1299
pt_func = getattr (pt , op )
1299
1300
np_func = getattr (np , op )
1300
- z = pt_func (start , stop , num_samples , endpoint = endpoint , axis = axis )
1301
+ dtype = dtype + config .floatX [- 2 :] if dtype is not None else dtype
1302
+ z = pt_func (start , stop , num_samples , endpoint = endpoint , axis = axis , dtype = dtype )
1301
1303
1302
- numpy_res = np_func (start , stop , num = num_samples , endpoint = endpoint , axis = axis )
1304
+ numpy_res = np_func (
1305
+ start , stop , num = num_samples , endpoint = endpoint , dtype = dtype , axis = axis
1306
+ )
1303
1307
pytensor_res = function (inputs = [], outputs = z , mode = "FAST_COMPILE" )()
1304
1308
1305
- np .testing .assert_allclose (pytensor_res , numpy_res , atol = 1e-6 , rtol = 1e-6 )
1309
+ np .testing .assert_allclose (
1310
+ pytensor_res ,
1311
+ numpy_res ,
1312
+ atol = 1e-6 if config .floatX .endswith ("64" ) else 1e-4 ,
1313
+ rtol = 1e-6 if config .floatX .endswith ("64" ) else 1e-4 ,
1314
+ )
0 commit comments