Skip to content

Commit e42d35a

Browse files
michaelosthegericardoV94
authored andcommitted
Use correct dtypes for invalid parameters test
1 parent 255a0c8 commit e42d35a

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

pymc/testing.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -385,7 +385,9 @@ def scipy_logp_with_scipy_args(**args):
385385
continue
386386

387387
point = valid_params.copy() # Shallow copy should be okay
388-
point[invalid_param] = invalid_edge
388+
point[invalid_param] = np.asarray(
389+
invalid_edge, dtype=paramdomains[invalid_param].dtype
390+
)
389391
with pytest.raises(ParameterValueError):
390392
pymc_logp(**point)
391393
pytest.fail(f"test_params={point}")

0 commit comments

Comments
 (0)