
Commit cf78a4c

Split convert_observed_data and apply intX to int generators
Previously, the `GeneratorAdapter` applied `floatX` to float data, but kept the original integer dtypes. `floatX` was then applied to everything by `convert_observed_data`. This refactor changes the handling of integer-valued generator data, such that `intX` is applied, and no `floatX` conversion takes place.
1 parent 914e10f commit cf78a4c
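
For illustration, here is a minimal sketch (not part of the commit) of the intended behavior, using the helpers introduced in the diffs below. The exact dtypes depend on PyTensor's configured `intX`/`floatX` precision:

```python
import numpy as np
from pymc.pytensorf import convert_observed_data

# Integer-valued generator data: intX is applied, so the result keeps an
# integer dtype instead of being cast to floatX as before.
int_gen = (np.array([i], dtype="int64") for i in range(10))
int_var = convert_observed_data(int_gen)  # dispatches to convert_generator_data
print(int_var.dtype)  # an integer dtype (intX), not a float dtype

# Float-valued generator data: still converted to floatX precision.
float_gen = (np.array([i], dtype="float64") for i in range(10))
float_var = convert_observed_data(float_gen)
print(float_var.dtype)  # a float dtype (floatX)
```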

4 files changed: +60 −35 lines

docs/source/api/pytensorf.rst

Lines changed: 2 additions & 1 deletion
```diff
@@ -22,4 +22,5 @@ PyTensor utils
     join_nonshared_inputs
     make_shared_replacements
     generator
-    convert_observed_data
+    convert_generator_data
+    convert_data
```

pymc/data.py

Lines changed: 13 additions & 6 deletions
```diff
@@ -37,7 +37,8 @@
 
 import pymc as pm
 
-from pymc.pytensorf import convert_observed_data
+from pymc.pytensorf import convert_data
+from pymc.vartypes import isgenerator
 
 __all__ = [
     "get_data",
@@ -98,7 +99,7 @@ def make_variable(self, gop, name=None):
     def __init__(self, generator):
         if not pm.vartypes.isgenerator(generator):
             raise TypeError("Object should be generator like")
-        self.test_value = pm.smartfloatX(copy(next(generator)))
+        self.test_value = pm.smarttypeX(copy(next(generator)))
         # make pickling potentially possible
         self._yielded_test_value = False
         self.gen = generator
@@ -110,7 +111,7 @@ def __next__(self):
             self._yielded_test_value = True
             return self.test_value
         else:
-            return pm.smartfloatX(copy(next(self.gen)))
+            return pm.smarttypeX(copy(next(self.gen)))
 
     # python2 generator
     next = __next__
@@ -403,9 +404,15 @@ def Data(
         )
     name = model.name_for(name)
 
-    # `convert_observed_data` takes care of parameter `value` and
-    # transforms it to something digestible for PyTensor.
-    arr = convert_observed_data(value)
+    # Transform `value` it to something digestible for PyTensor.
+    if isgenerator(value):
+        raise NotImplementedError(
+            "Generator type data is no longer supported with pm.Data.",
+            # It messes up InferenceData and can't be the input to a SharedVariable.
+        )
+    else:
+        arr = convert_data(value)
+
     if isinstance(arr, np.ma.MaskedArray):
         raise NotImplementedError(
             "Masked arrays or arrays with `nan` entries are not supported. "
```

pymc/pytensorf.py

Lines changed: 16 additions & 11 deletions
```diff
@@ -74,17 +74,26 @@
     "join_nonshared_inputs",
     "make_shared_replacements",
     "generator",
+    "convert_data",
+    "convert_generator_data",
     "convert_observed_data",
     "compile_pymc",
 ]
 
 
 def convert_observed_data(data) -> np.ndarray | Variable:
     """Convert user provided dataset to accepted formats."""
-
     if isgenerator(data):
-        return floatX(generator(data))
+        return convert_generator_data(data)
+    return convert_data(data)
+
+
+def convert_generator_data(data) -> TensorVariable:
+    return generator(data)
 
+
+def convert_data(data) -> np.ndarray | Variable:
+    ret: np.ndarray | Variable
     if hasattr(data, "to_numpy") and hasattr(data, "isnull"):
         # typically, but not limited to pandas objects
         vals = data.to_numpy()
@@ -123,16 +132,12 @@ def convert_observed_data(data) -> np.ndarray | Variable:
     else:
         ret = np.asarray(data)
 
-    # type handling to enable index variables when data is int:
-    if hasattr(data, "dtype"):
-        if "int" in str(data.dtype):
-            return intX(ret)
-        # otherwise, assume float:
-        else:
-            return floatX(ret)
-    # needed for uses of this function other than with pm.Data:
-    else:
+    # Data without dtype info is converted to float arrays by default.
+    # This is the most common case for simple examples.
+    if not hasattr(data, "dtype"):
         return floatX(ret)
+    # Otherwise we only convert the precision.
+    return smarttypeX(ret)
 
 
 @_as_tensor_variable.register(pd.Series)
```
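
The practical effect of the new branch in `convert_data` can be summarized with a small, hedged sketch (exact dtypes depend on the configured `intX`/`floatX` precision):

```python
import numpy as np
from pymc.pytensorf import convert_data

# Inputs that carry a dtype only get their precision adjusted:
ints = convert_data(np.arange(3, dtype="int64"))
floats = convert_data(np.ones(3, dtype="float64"))
print(ints.dtype)    # stays an integer dtype (intX)
print(floats.dtype)  # a float dtype (floatX)

# Inputs without dtype information (e.g. plain lists) default to floatX:
listy = convert_data([1, 2, 3])
print(listy.dtype)   # a float dtype (floatX)
```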

tests/test_pytensorf.py

Lines changed: 29 additions & 17 deletions
```diff
@@ -38,10 +38,12 @@
 from pymc.exceptions import NotConstantValueError
 from pymc.logprob.utils import ParameterValueError
 from pymc.pytensorf import (
+    GeneratorOp,
     collect_default_updates,
     compile_pymc,
     constant_fold,
-    convert_observed_data,
+    convert_data,
+    convert_generator_data,
     extract_obs_data,
     hessian,
     hessian_diag,
@@ -188,9 +190,9 @@ def test_extract_obs_data():
 
 
 @pytest.mark.parametrize("input_dtype", ["int32", "int64", "float32", "float64"])
-def test_convert_observed_data(input_dtype):
+def test_convert_data(input_dtype):
     """
-    Ensure that convert_observed_data returns the dense array, masked array,
+    Ensure that convert_data returns the dense array, masked array,
     graph variable, TensorVariable, or sparse matrix as appropriate.
     """
     # Create the various inputs to the function
@@ -206,12 +208,8 @@ def test_convert_observed_data(input_dtype):
     missing_pandas_input = pd.DataFrame(missing_numpy_input)
     masked_array_input = ma.array(dense_input, mask=(np.mod(dense_input, 2) == 0))
 
-    # Create a generator object. Apparently the generator object needs to
-    # yield numpy arrays.
-    square_generator = (np.array([i**2], dtype=int) for i in range(100))
-
     # Alias the function to be tested
-    func = convert_observed_data
+    func = convert_data
 
     #####
     # Perform the various tests
@@ -255,21 +253,35 @@
     else:
         assert pytensor_output.dtype == intX
 
-    # Check function behavior with generator data
-    generator_output = func(square_generator)
 
-    # Output is wrapped with `pm.floatX`, and this unwraps
-    wrapped = generator_output.owner.inputs[0]
-    # Make sure the returned object has .set_gen and .set_default methods
-    assert hasattr(wrapped, "set_gen")
-    assert hasattr(wrapped, "set_default")
+@pytest.mark.parametrize("input_dtype", ["int32", "int64", "float32", "float64"])
+def test_convert_generator_data(input_dtype):
+    # Create a generator object producing NumPy arrays with the intended dtype.
+    # This is required to infer the correct dtype.
+    square_generator = (np.array([i**2], dtype=input_dtype) for i in range(100))
+
+    # Output is NOT wrapped with `pm.floatX`/`intX`,
+    # but produced from calling a special Op.
+    result = convert_generator_data(square_generator)
+    apply = result.owner
+    op = apply.op
     # Make sure the returned object is an PyTensor TensorVariable
-    assert isinstance(wrapped, TensorVariable)
+    assert isinstance(result, TensorVariable)
+    assert isinstance(op, GeneratorOp), f"It's a {type(apply)}"
+    # There are no inputs - because it generates...
+    assert apply.inputs == []
+
+    # Evaluation results should have the correct* dtype!
+    # (*intX/floatX will be enforced!)
+    evaled = result.eval()
+    expected_dtype = pm.smarttypeX(np.array(1, dtype=input_dtype)).dtype
+    assert result.type.dtype == expected_dtype
+    assert evaled.dtype == np.dtype(expected_dtype)
 
 
 def test_pandas_to_array_pandas_index():
     data = pd.Index([1, 2, 3])
-    result = convert_observed_data(data)
+    result = convert_data(data)
     expected = np.array([1, 2, 3])
     np.testing.assert_array_equal(result, expected)
```
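
For context, a brief sketch (outside the diff) of what the returned generator variable does at run time, inferred from the `GeneratorAdapter` changes above; printed values are illustrative:

```python
import numpy as np
from pymc.pytensorf import convert_generator_data

squares = (np.array([i**2], dtype="int64") for i in range(100))
var = convert_generator_data(squares)

# Each evaluation pulls the next array from the generator; the first value
# is cached as the test value by GeneratorAdapter and yielded first.
print(var.eval())  # e.g. [0]
print(var.eval())  # e.g. [1]
print(var.dtype)   # an integer dtype (intX), via smarttypeX
```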
