Commit 9341906

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 378dbe4 commit 9341906
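
For context: pre-commit.ci runs the repository's configured pre-commit hooks and commits any automatic fixes they produce, which is what happened in this commit. As a rough illustration only (the real hooks and revisions live in this repository's .pre-commit-config.yaml and are not shown here), a config that yields import-sorting and line-wrapping fixes like the ones in the diffs below could look roughly like this:

    repos:
      - repo: https://github.com/astral-sh/ruff-pre-commit
        rev: v0.4.4   # illustrative revision, not taken from this repository
        hooks:
          - id: ruff          # linting with autofixes (e.g. import sorting, dropping a stray f-prefix)
            args: [--fix]
          - id: ruff-format   # code formatting (the line re-wrapping seen below)

Running "pre-commit run --all-files" locally applies whatever hooks the project actually configures, producing the same kind of auto-fix commit.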

6 files changed: +70, -42 lines

pymc_experimental/model/marginal/distributions.py (+8, -11)

@@ -1,21 +1,18 @@
-from typing import Sequence
+from collections.abc import Sequence
 
 import numpy as np
 import pytensor.tensor as pt
-from pymc.distributions import (
-    Bernoulli,
-    Categorical,
-    DiscreteUniform,
-    SymbolicRandomVariable
-)
-from pymc.logprob.basic import conditional_logp, logp
+
+from pymc.distributions import Bernoulli, Categorical, DiscreteUniform, SymbolicRandomVariable
 from pymc.logprob.abstract import _logprob
+from pymc.logprob.basic import conditional_logp, logp
 from pymc.pytensorf import constant_fold
-from pytensor.graph.replace import clone_replace, graph_replace
-from pytensor.scan import scan, map as scan_map
 from pytensor.compile.mode import Mode
 from pytensor.graph import vectorize_graph
-from pytensor.tensor import TensorVariable, TensorType
+from pytensor.graph.replace import clone_replace, graph_replace
+from pytensor.scan import map as scan_map
+from pytensor.scan import scan
+from pytensor.tensor import TensorType, TensorVariable
 
 from pymc_experimental.distributions import DiscreteMarkovChain
 

pymc_experimental/model/marginal/graph_analysis.py (+25, -19)

@@ -1,17 +1,17 @@
-from itertools import zip_longest, chain
-from typing import Sequence
+from collections.abc import Sequence
+from itertools import chain, zip_longest
 
 from pymc import SymbolicRandomVariable
 from pytensor.compile import SharedVariable
-from pytensor.graph import ancestors, Constant, graph_inputs, Variable
+from pytensor.graph import Constant, Variable, ancestors, graph_inputs
 from pytensor.graph.basic import io_toposort
-from pytensor.tensor import TensorVariable, TensorType
+from pytensor.tensor import TensorType, TensorVariable
 from pytensor.tensor.blockwise import Blockwise
-from pytensor.tensor.elemwise import DimShuffle, Elemwise, CAReduce
+from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
 from pytensor.tensor.random.op import RandomVariable
 from pytensor.tensor.rewriting.subtensor import is_full_slice
 from pytensor.tensor.shape import Shape
-from pytensor.tensor.subtensor import Subtensor, get_idx_list, AdvancedSubtensor
+from pytensor.tensor.subtensor import AdvancedSubtensor, Subtensor, get_idx_list
 from pytensor.tensor.type_other import NoneTypeT
 
 
@@ -58,7 +58,6 @@ def find_conditional_dependent_rvs(dependable_rv, all_rvs):
     ]
 
 
-
 def collect_shared_vars(outputs, blockers):
     return [
         inp for inp in graph_inputs(outputs, blockers=blockers) if isinstance(inp, SharedVariable)
@@ -86,18 +85,22 @@ def _advanced_indexing_axis_and_ndim(idxs) -> tuple[int, int]:
     return adv_group_axis, adv_group_ndim
 
 
-def _broadcast_dims(inputs_dims: Sequence[tuple[tuple[int, ...], ...]]) -> tuple[tuple[int, ...], ...]:
+def _broadcast_dims(
+    inputs_dims: Sequence[tuple[tuple[int, ...], ...]],
+) -> tuple[tuple[int, ...], ...]:
     output_ndim = max((len(input_dim) for input_dim in inputs_dims), default=0)
     # Add missing dims
-    inputs_dims = [
-        ((),) * (output_ndim - len(input_dim)) + input_dim for input_dim in inputs_dims
-    ]
+    inputs_dims = [((),) * (output_ndim - len(input_dim)) + input_dim for input_dim in inputs_dims]
     # Combine aligned dims
-    output_dims = tuple(tuple(sorted(set(chain.from_iterable(inputs_dim)))) for inputs_dim in zip(*inputs_dims))
+    output_dims = tuple(
+        tuple(sorted(set(chain.from_iterable(inputs_dim)))) for inputs_dim in zip(*inputs_dims)
+    )
     return output_dims
 
 
-def subgraph_dim_connection(input_var, other_inputs, output_vars) -> list[tuple[tuple[int, ...], ...]]:
+def subgraph_dim_connection(
+    input_var, other_inputs, output_vars
+) -> list[tuple[tuple[int, ...], ...]]:
     """Identify how the dims of rv_to_marginalize are consumed by the dims of the output_rvs.
 
     Raises
@@ -135,13 +138,16 @@ def subgraph_dim_connection(input_var, other_inputs, output_vars) -> list[tuple[
             op_batch_ndim = node.op.batch_ndim(node)
 
             # Collapse all core_dims
-            core_dims = tuple(sorted(chain.from_iterable([i for input_dim in inputs_dims for i in input_dim[op_batch_ndim:]])))
-            batch_dims = _broadcast_dims(
-                tuple(
-                    input_dims[:op_batch_ndim]
-                    for input_dims in inputs_dims
+            core_dims = tuple(
+                sorted(
+                    chain.from_iterable(
+                        [i for input_dim in inputs_dims for i in input_dim[op_batch_ndim:]]
+                    )
                 )
             )
+            batch_dims = _broadcast_dims(
+                tuple(input_dims[:op_batch_ndim] for input_dims in inputs_dims)
+            )
             # Add batch dims to each output_dims
             batch_dims = tuple(batch_dim + core_dims for batch_dim in batch_dims)
             for out in node.outputs:
@@ -221,7 +227,7 @@ def subgraph_dim_connection(input_var, other_inputs, output_vars) -> list[tuple[
                 elif value_dim:
                     # We are trying to partially slice or index a known dimension
                     raise NotImplementedError(
-                        f"Partial slicing or advanced integer indexing of known dimensions not supported"
+                        "Partial slicing or advanced integer indexing of known dimensions not supported"
                     )
                 elif isinstance(idx, slice):
                     # Unknown dimensions kept by partial slice.

pymc_experimental/model/marginal/marginal_model.py (+13, -5)

@@ -23,10 +23,19 @@
 __all__ = ["MarginalModel", "marginalize"]
 
 from pymc_experimental.distributions import DiscreteMarkovChain
-from pymc_experimental.model.marginal.distributions import FiniteDiscreteMarginalRV, DiscreteMarginalMarkovChainRV, \
-    get_domain_of_finite_discrete_rv, _add_reduce_batch_dependent_logps
-from pymc_experimental.model.marginal.graph_analysis import find_conditional_input_rvs, is_conditional_dependent, \
-    find_conditional_dependent_rvs, subgraph_dim_connection, collect_shared_vars
+from pymc_experimental.model.marginal.distributions import (
+    DiscreteMarginalMarkovChainRV,
+    FiniteDiscreteMarginalRV,
+    _add_reduce_batch_dependent_logps,
+    get_domain_of_finite_discrete_rv,
+)
+from pymc_experimental.model.marginal.graph_analysis import (
+    collect_shared_vars,
+    find_conditional_dependent_rvs,
+    find_conditional_input_rvs,
+    is_conditional_dependent,
+    subgraph_dim_connection,
+)
 
 ModelRVs = TensorVariable | Sequence[TensorVariable] | str | Sequence[str]
 
@@ -613,4 +622,3 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
     marginalized_rvs = marginalization_op(*inputs)
     fgraph.replace_all(tuple(zip(rvs_to_marginalize, marginalized_rvs)))
     return rvs_to_marginalize, marginalized_rvs
-

tests/model/marginal/test_distributions.py (+1, -1)

@@ -1,13 +1,13 @@
 import numpy as np
 import pymc as pm
 import pytest
+
 from pymc.logprob.abstract import _logprob
 from pytensor import tensor as pt
 from scipy.stats import norm
 
 from pymc_experimental import MarginalModel
 from pymc_experimental.distributions import DiscreteMarkovChain
-
 from pymc_experimental.model.marginal.distributions import FiniteDiscreteMarginalRV
 
 

tests/model/marginal/test_graph_analysis.py (+22, -5)

@@ -1,12 +1,12 @@
 import pytensor.tensor as pt
 import pytest
+
 from pymc.distributions import CustomDist
 
 from pymc_experimental.model.marginal.graph_analysis import subgraph_dim_connection
 
 
 class TestSubgraphDimConnection:
-
     def test_dimshuffle(self):
         inp = pt.zeros(shape=(5, 1, 4, 3))
         out1 = pt.matrix_transpose(inp)
@@ -31,11 +31,16 @@ def test_subtensor(self):
         assert dims == ((1,),)
 
         invalid_out = inp[0, :1]
-        with pytest.raises(NotImplementedError, match="Partial slicing of known dimensions not supported"):
+        with pytest.raises(
+            NotImplementedError, match="Partial slicing of known dimensions not supported"
+        ):
             subgraph_dim_connection(inp, [], [invalid_out])
 
         # If we are slicing a dummy / unknown dimension that's fine
-        valid_out = pt.expand_dims(inp[:, 0], 1)[0, :1,]
+        valid_out = pt.expand_dims(inp[:, 0], 1)[
+            0,
+            :1,
+        ]
         [dims] = subgraph_dim_connection(inp, [], [valid_out])
         assert dims == ((), (2,))
 
@@ -53,11 +58,23 @@ def test_elemwise(self):
         # By removing the last dimension, we align the first and the last in the addition
         out = inp + inp[:, 0]
         [dims] = subgraph_dim_connection(inp, [], [out])
-        assert dims == ((0,), (0, 1,))
+        assert dims == (
+            (0,),
+            (
+                0,
+                1,
+            ),
+        )
 
         out = inp + inp.T
         [dims] = subgraph_dim_connection(inp, [], [out])
-        assert dims == ((0, 1), (0, 1,))
+        assert dims == (
+            (0, 1),
+            (
+                0,
+                1,
+            ),
+        )
 
     def test_blockwise(self):
         inp = pt.zeros(shape=(5, 4, 3, 2))

tests/model/marginal/test_marginal_model.py (+1, -1)

@@ -17,11 +17,11 @@
 from scipy.special import log_softmax, logsumexp
 from scipy.stats import halfnorm, norm
 
+from pymc_experimental.model.marginal.graph_analysis import is_conditional_dependent
 from pymc_experimental.model.marginal.marginal_model import (
     MarginalModel,
     marginalize,
 )
-from pymc_experimental.model.marginal.graph_analysis import is_conditional_dependent
 from tests.utils import equal_computations_up_to_root
 
 

0 commit comments