File tree 3 files changed +8
-4
lines changed 3 files changed +8
-4
lines changed Original file line number Diff line number Diff line change @@ -444,6 +444,9 @@ def intX(X):
444
444
"""
445
445
Convert a aesara tensor or numpy array to aesara.tensor.int32 type.
446
446
"""
447
+ # check value is already int, do nothing in this case
448
+ if (hasattr (X , "dtype" ) and "int" in str (X .dtype )) or isinstance (X , int ):
449
+ return X
447
450
intX = _conversion_map [aesara .config .floatX ]
448
451
try :
449
452
return X .astype (intX )
Original file line number Diff line number Diff line change @@ -244,7 +244,7 @@ def test_convert_observed_data(input_dtype):
244
244
assert isinstance (aesara_output , Variable )
245
245
npt .assert_allclose (aesara_output .eval (), aesara_graph_input .eval ())
246
246
intX = pm .aesaraf ._conversion_map [aesara .config .floatX ]
247
- if dense_input .dtype == intX or dense_input .dtype == aesara .config .floatX :
247
+ if "int" in str ( dense_input .dtype ) or dense_input .dtype == aesara .config .floatX :
248
248
assert aesara_output .owner is None # func should not have added new nodes
249
249
assert aesara_output .name == input_name
250
250
else :
@@ -254,7 +254,8 @@ def test_convert_observed_data(input_dtype):
254
254
if "float" in input_dtype :
255
255
assert aesara_output .dtype == aesara .config .floatX
256
256
else :
257
- assert aesara_output .dtype == intX
257
+ # only cast floats, leave ints as is
258
+ assert aesara_output .dtype == input_dtype
258
259
259
260
# Check function behavior with generator data
260
261
generator_output = func (square_generator )
Original file line number Diff line number Diff line change @@ -87,7 +87,7 @@ def test_sample_posterior_predictive_after_set_data(self):
87
87
)
88
88
# Predict on new data.
89
89
with model :
90
- x_test = [5 , 6 , 9 ]
90
+ x_test = [5.0 , 6.0 , 9.0 ]
91
91
pm .set_data (new_data = {"x" : x_test })
92
92
y_test = pm .sample_posterior_predictive (trace )
93
93
@@ -111,7 +111,7 @@ def test_sample_posterior_predictive_after_set_data_with_coords(self):
111
111
)
112
112
# Predict on new data.
113
113
with model :
114
- x_test = [5 , 6 ]
114
+ x_test = [5.0 , 6.0 ]
115
115
pm .set_data (new_data = {"x" : x_test }, coords = {"obs_id" : ["a" , "b" ]})
116
116
pm .sample_posterior_predictive (idata , extend_inferencedata = True , predictions = True )
117
117
You can’t perform that action at this time.
0 commit comments