Skip to content

Commit 01189d3

Browse files
Create extract_obs_data function
1 parent 985d3cd commit 01189d3

File tree

2 files changed

+73
-2
lines changed

2 files changed

+73
-2
lines changed

pymc3/aesaraf.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,12 @@
1919
from aesara import scalar
2020
from aesara import tensor as aet
2121
from aesara.gradient import grad
22-
from aesara.graph.basic import Apply, graph_inputs
22+
from aesara.graph.basic import Apply, Constant, graph_inputs
2323
from aesara.graph.op import Op
2424
from aesara.sandbox.rng_mrg import MRG_RandomStream as RandomStream
2525
from aesara.tensor.elemwise import Elemwise
26+
from aesara.tensor.sharedvar import SharedVariable
27+
from aesara.tensor.subtensor import AdvancedIncSubtensor, AdvancedIncSubtensor1
2628
from aesara.tensor.var import TensorVariable
2729

2830
from pymc3.data import GeneratorAdapter
@@ -48,6 +50,28 @@
4850
]
4951

5052

53+
def extract_obs_data(x: TensorVariable) -> np.ndarray:
54+
"""Extract data observed symbolic variables.
55+
56+
Raises
57+
------
58+
TypeError
59+
60+
"""
61+
if isinstance(x, Constant):
62+
return x.data
63+
if isinstance(x, SharedVariable):
64+
return x.get_value()
65+
if x.owner and isinstance(x.owner.op, (AdvancedIncSubtensor, AdvancedIncSubtensor1)):
66+
array_data = extract_obs_data(x.owner.inputs[0])
67+
mask_idx = tuple(extract_obs_data(i) for i in x.owner.inputs[2:])
68+
mask = np.zeros_like(array_data)
69+
mask[mask_idx] = 1
70+
return np.ma.MaskedArray(array_data, mask)
71+
72+
raise TypeError(f"Data cannot be extracted from {x}")
73+
74+
5175
def inputvars(a):
5276
"""
5377
Get the inputs into a aesara variables

pymc3/tests/test_aesaraf.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,10 @@
1919
import numpy as np
2020
import pytest
2121

22+
from aesara.tensor.subtensor import AdvancedIncSubtensor, AdvancedIncSubtensor1
2223
from aesara.tensor.type import TensorType
2324

24-
from pymc3.aesaraf import _conversion_map, take_along_axis
25+
from pymc3.aesaraf import _conversion_map, extract_obs_data, take_along_axis
2526
from pymc3.vartypes import int_types
2627

2728
FLOATX = str(aesara.config.floatX)
@@ -225,3 +226,49 @@ def test_dtype_failure(self):
225226
indices.tag.test_value = np.zeros((1,) * indices.ndim, dtype=FLOATX)
226227
with pytest.raises(IndexError):
227228
take_along_axis(arr, indices)
229+
230+
231+
def test_extract_obs_data():
232+
233+
with pytest.raises(TypeError):
234+
extract_obs_data(aet.matrix())
235+
236+
data = np.random.normal(size=(2, 3))
237+
data_at = aet.as_tensor(data)
238+
mask = np.random.binomial(1, 0.5, size=(2, 3)).astype(bool)
239+
240+
for val_at in (data_at, aesara.shared(data)):
241+
res = extract_obs_data(val_at)
242+
243+
assert isinstance(res, np.ndarray)
244+
assert np.array_equal(res, data)
245+
246+
# AdvancedIncSubtensor check
247+
data_m = np.ma.MaskedArray(data, mask)
248+
missing_values = data_at.type()[mask]
249+
constant = aet.as_tensor(data_m.filled())
250+
z_at = aet.set_subtensor(constant[mask.nonzero()], missing_values)
251+
252+
assert isinstance(z_at.owner.op, AdvancedIncSubtensor)
253+
254+
res = extract_obs_data(z_at)
255+
256+
assert isinstance(res, np.ndarray)
257+
assert np.ma.allequal(res, data_m)
258+
259+
# AdvancedIncSubtensor1 check
260+
data = np.random.normal(size=(3,))
261+
data_at = aet.as_tensor(data)
262+
mask = np.random.binomial(1, 0.5, size=(3,)).astype(bool)
263+
264+
data_m = np.ma.MaskedArray(data, mask)
265+
missing_values = data_at.type()[mask]
266+
constant = aet.as_tensor(data_m.filled())
267+
z_at = aet.set_subtensor(constant[mask.nonzero()], missing_values)
268+
269+
assert isinstance(z_at.owner.op, AdvancedIncSubtensor1)
270+
271+
res = extract_obs_data(z_at)
272+
273+
assert isinstance(res, np.ndarray)
274+
assert np.ma.allequal(res, data_m)

0 commit comments

Comments
 (0)