Skip to content

Commit d1e868f

Browse files
ferrinepymc-devs
authored and
pymc-devs
committed
Add a type guard for intX (pymc-devs#4569)
* add type guard for inX * fix test for pandas * fix posterior test, ints passed for float data Closes pymc-devs#4279
1 parent c4a79ef commit d1e868f

File tree

3 files changed

+8
-4
lines changed

3 files changed

+8
-4
lines changed

pymc/aesaraf.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -444,6 +444,9 @@ def intX(X):
444444
"""
445445
Convert a aesara tensor or numpy array to aesara.tensor.int32 type.
446446
"""
447+
# check value is already int, do nothing in this case
448+
if (hasattr(X, "dtype") and "int" in str(X.dtype)) or isinstance(X, int):
449+
return X
447450
intX = _conversion_map[aesara.config.floatX]
448451
try:
449452
return X.astype(intX)

pymc/tests/test_aesaraf.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ def test_convert_observed_data(input_dtype):
244244
assert isinstance(aesara_output, Variable)
245245
npt.assert_allclose(aesara_output.eval(), aesara_graph_input.eval())
246246
intX = pm.aesaraf._conversion_map[aesara.config.floatX]
247-
if dense_input.dtype == intX or dense_input.dtype == aesara.config.floatX:
247+
if "int" in str(dense_input.dtype) or dense_input.dtype == aesara.config.floatX:
248248
assert aesara_output.owner is None # func should not have added new nodes
249249
assert aesara_output.name == input_name
250250
else:
@@ -254,7 +254,8 @@ def test_convert_observed_data(input_dtype):
254254
if "float" in input_dtype:
255255
assert aesara_output.dtype == aesara.config.floatX
256256
else:
257-
assert aesara_output.dtype == intX
257+
# only cast floats, leave ints as is
258+
assert aesara_output.dtype == input_dtype
258259

259260
# Check function behavior with generator data
260261
generator_output = func(square_generator)

pymc/tests/test_data.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def test_sample_posterior_predictive_after_set_data(self):
8787
)
8888
# Predict on new data.
8989
with model:
90-
x_test = [5, 6, 9]
90+
x_test = [5.0, 6.0, 9.0]
9191
pm.set_data(new_data={"x": x_test})
9292
y_test = pm.sample_posterior_predictive(trace)
9393

@@ -111,7 +111,7 @@ def test_sample_posterior_predictive_after_set_data_with_coords(self):
111111
)
112112
# Predict on new data.
113113
with model:
114-
x_test = [5, 6]
114+
x_test = [5.0, 6.0]
115115
pm.set_data(new_data={"x": x_test}, coords={"obs_id": ["a", "b"]})
116116
pm.sample_posterior_predictive(idata, extend_inferencedata=True, predictions=True)
117117

0 commit comments

Comments
 (0)