diff --git a/pymc/aesaraf.py b/pymc/aesaraf.py index 0fdb53acfd..bcb1bfa00b 100644 --- a/pymc/aesaraf.py +++ b/pymc/aesaraf.py @@ -444,6 +444,9 @@ def intX(X): """ Convert a aesara tensor or numpy array to aesara.tensor.int32 type. """ + # check value is already int, do nothing in this case + if (hasattr(X, "dtype") and "int" in str(X.dtype)) or isinstance(X, int): + return X intX = _conversion_map[aesara.config.floatX] try: return X.astype(intX) diff --git a/pymc/distributions/discrete.py b/pymc/distributions/discrete.py index fc57448584..a6db6b70c8 100644 --- a/pymc/distributions/discrete.py +++ b/pymc/distributions/discrete.py @@ -878,7 +878,11 @@ def dist(cls, N, k, n, *args, **kwargs): return super().dist([good, bad, n], *args, **kwargs) def moment(rv, size, good, bad, n): - N, k = good + bad, good + # Cast to float because the intX can be int8 + # which could trigger an integer overflow below. + n = floatX(n) + k = floatX(good) + N = k + floatX(bad) mode = at.floor((n + 1) * (k + 1) / (N + 2)) if not rv_size_is_none(size): mode = at.full(size, mode) @@ -1014,6 +1018,8 @@ def dist(cls, lower, upper, *args, **kwargs): return super().dist([lower, upper], **kwargs) def moment(rv, size, lower, upper): + upper = floatX(upper) + lower = floatX(lower) mode = at.maximum(at.floor((upper + lower) / 2.0), lower) if not rv_size_is_none(size): mode = at.full(size, mode) diff --git a/pymc/tests/distributions/util.py b/pymc/tests/distributions/util.py index 0a501da4e8..a6adba6097 100644 --- a/pymc/tests/distributions/util.py +++ b/pymc/tests/distributions/util.py @@ -579,7 +579,7 @@ def assert_moment_is_expected(model, expected, check_finite_logp=True): assert moment.shape == expected.shape assert expected.shape == random_draw.shape - assert np.allclose(moment, expected) + np.testing.assert_allclose(moment, expected, atol=1e-10) if check_finite_logp: logp_moment = ( diff --git a/pymc/tests/test_aesaraf.py b/pymc/tests/test_aesaraf.py index f627d932fa..05a9137397 100644 --- a/pymc/tests/test_aesaraf.py +++ b/pymc/tests/test_aesaraf.py @@ -244,7 +244,7 @@ def test_convert_observed_data(input_dtype): assert isinstance(aesara_output, Variable) npt.assert_allclose(aesara_output.eval(), aesara_graph_input.eval()) intX = pm.aesaraf._conversion_map[aesara.config.floatX] - if dense_input.dtype == intX or dense_input.dtype == aesara.config.floatX: + if "int" in str(dense_input.dtype) or dense_input.dtype == aesara.config.floatX: assert aesara_output.owner is None # func should not have added new nodes assert aesara_output.name == input_name else: @@ -254,7 +254,8 @@ def test_convert_observed_data(input_dtype): if "float" in input_dtype: assert aesara_output.dtype == aesara.config.floatX else: - assert aesara_output.dtype == intX + # only cast floats, leave ints as is + assert aesara_output.dtype == input_dtype # Check function behavior with generator data generator_output = func(square_generator) diff --git a/pymc/tests/test_data.py b/pymc/tests/test_data.py index 52b18705ba..4536786e88 100644 --- a/pymc/tests/test_data.py +++ b/pymc/tests/test_data.py @@ -87,7 +87,7 @@ def test_sample_posterior_predictive_after_set_data(self): ) # Predict on new data. with model: - x_test = [5, 6, 9] + x_test = [5.0, 6.0, 9.0] pm.set_data(new_data={"x": x_test}) y_test = pm.sample_posterior_predictive(trace) @@ -111,7 +111,7 @@ def test_sample_posterior_predictive_after_set_data_with_coords(self): ) # Predict on new data. with model: - x_test = [5, 6] + x_test = [5.0, 6.0] pm.set_data(new_data={"x": x_test}, coords={"obs_id": ["a", "b"]}) pm.sample_posterior_predictive(idata, extend_inferencedata=True, predictions=True)