@@ -1,17 +1,17 @@
-from itertools import zip_longest, chain
-from typing import Sequence
+from collections.abc import Sequence
+from itertools import chain, zip_longest
 
 from pymc import SymbolicRandomVariable
 from pytensor.compile import SharedVariable
-from pytensor.graph import ancestors, Constant, graph_inputs, Variable
+from pytensor.graph import Constant, Variable, ancestors, graph_inputs
 from pytensor.graph.basic import io_toposort
-from pytensor.tensor import TensorVariable, TensorType
+from pytensor.tensor import TensorType, TensorVariable
 from pytensor.tensor.blockwise import Blockwise
-from pytensor.tensor.elemwise import DimShuffle, Elemwise, CAReduce
+from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
 from pytensor.tensor.random.op import RandomVariable
 from pytensor.tensor.rewriting.subtensor import is_full_slice
 from pytensor.tensor.shape import Shape
-from pytensor.tensor.subtensor import Subtensor, get_idx_list, AdvancedSubtensor
+from pytensor.tensor.subtensor import AdvancedSubtensor, Subtensor, get_idx_list
 from pytensor.tensor.type_other import NoneTypeT
 
 
@@ -58,7 +58,6 @@ def find_conditional_dependent_rvs(dependable_rv, all_rvs):
     ]
 
 
-
 def collect_shared_vars(outputs, blockers):
     return [
        inp for inp in graph_inputs(outputs, blockers=blockers) if isinstance(inp, SharedVariable)
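
For context, `collect_shared_vars` simply filters the graph inputs for shared variables. A minimal sketch of its behavior (the variables here are illustrative, not from the PR):

import pytensor.tensor as pt
from pytensor import shared

a = shared(1.0, name="a")  # a SharedVariable
x = pt.scalar("x")         # a plain symbolic input
y = x * a

# graph_inputs([y]) yields both x and a; only a is a SharedVariable,
# so collect_shared_vars([y], blockers=[]) should return [a].
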
@@ -86,18 +85,22 @@ def _advanced_indexing_axis_and_ndim(idxs) -> tuple[int, int]:
     return adv_group_axis, adv_group_ndim
 
 
-def _broadcast_dims(inputs_dims: Sequence[tuple[tuple[int, ...], ...]]) -> tuple[tuple[int, ...], ...]:
+def _broadcast_dims(
+    inputs_dims: Sequence[tuple[tuple[int, ...], ...]],
+) -> tuple[tuple[int, ...], ...]:
     output_ndim = max((len(input_dim) for input_dim in inputs_dims), default=0)
     # Add missing dims
-    inputs_dims = [
-        ((),) * (output_ndim - len(input_dim)) + input_dim for input_dim in inputs_dims
-    ]
+    inputs_dims = [((),) * (output_ndim - len(input_dim)) + input_dim for input_dim in inputs_dims]
     # Combine aligned dims
-    output_dims = tuple(tuple(sorted(set(chain.from_iterable(inputs_dim)))) for inputs_dim in zip(*inputs_dims))
+    output_dims = tuple(
+        tuple(sorted(set(chain.from_iterable(inputs_dim)))) for inputs_dim in zip(*inputs_dims)
+    )
     return output_dims
 
 
-def subgraph_dim_connection(input_var, other_inputs, output_vars) -> list[tuple[tuple[int, ...], ...]]:
+def subgraph_dim_connection(
+    input_var, other_inputs, output_vars
+) -> list[tuple[tuple[int, ...], ...]]:
     """Identify how the dims of rv_to_marginalize are consumed by the dims of the output_rvs.
 
     Raises
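
To see what `_broadcast_dims` computes, here is a small worked call (a hypothetical example; the dim indices refer to dims of the variable being marginalized):

dims_a = ((0,), ())  # 2 axes: axis 0 carries dim 0, axis 1 carries nothing
dims_b = ((1,),)     # 1 axis: carries dim 1

_broadcast_dims([dims_a, dims_b])
# -> ((0,), (1,))
# dims_b is left-padded with () to ((), (1,)) so both inputs have 2 axes,
# then aligned entries are unioned and sorted, mimicking broadcasting rules.
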
@@ -135,13 +138,16 @@ def subgraph_dim_connection(input_var, other_inputs, output_vars) -> list[tuple[
         op_batch_ndim = node.op.batch_ndim(node)
 
         # Collapse all core_dims
-        core_dims = tuple(sorted(chain.from_iterable([i for input_dim in inputs_dims for i in input_dim[op_batch_ndim:]])))
-        batch_dims = _broadcast_dims(
-            tuple(
-                input_dims[:op_batch_ndim]
-                for input_dims in inputs_dims
+        core_dims = tuple(
+            sorted(
+                chain.from_iterable(
+                    [i for input_dim in inputs_dims for i in input_dim[op_batch_ndim:]]
+                )
             )
         )
+        batch_dims = _broadcast_dims(
+            tuple(input_dims[:op_batch_ndim] for input_dims in inputs_dims)
+        )
         # Add batch dims to each output_dims
         batch_dims = tuple(batch_dim + core_dims for batch_dim in batch_dims)
         for out in node.outputs:
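
A hypothetical trace of the reformatted block above, to make the collapse concrete (the values are invented for illustration):

# A Blockwise op with op_batch_ndim = 1 and two inputs:
inputs_dims = [
    ((0,), (1,)),  # batch axis depends on dim 0, core axis on dim 1
    ((), (2,)),    # batch axis depends on nothing, core axis on dim 2
]
# core_dims collects everything past the batch axes:  (1, 2)
# batch_dims broadcasts the batch parts:              ((0,),)
# appending core_dims to each batch axis gives        ((0, 1, 2),)
# i.e. after a core op, any core-dim dependency is conservatively
# attributed to every batch axis of the outputs.
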
@@ -221,7 +227,7 @@ def subgraph_dim_connection(input_var, other_inputs, output_vars) -> list[tuple[
         elif value_dim:
             # We are trying to partially slice or index a known dimension
             raise NotImplementedError(
-                f"Partial slicing or advanced integer indexing of known dimensions not supported"
+                "Partial slicing or advanced integer indexing of known dimensions not supported"
             )
         elif isinstance(idx, slice):
             # Unknown dimensions kept by partial slice.
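
The string change in this hunk is a lint fix rather than a behavior change: an f-string with no placeholders is identical to a plain string literal, and linters flag the redundant prefix (ruff/pyflakes rule F541):

msg1 = f"nothing interpolated"  # flagged: the f-prefix does nothing here
msg2 = "nothing interpolated"   # equivalent plain literal
assert msg1 == msg2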