Commit af7ed24

Faster infer_static_shape

1 parent 189ba03 commit af7ed24

File tree: 7 files changed, +98 -23 lines changed

pytensor/tensor/basic.py

Lines changed: 50 additions & 8 deletions
@@ -22,10 +22,11 @@
 from pytensor import compile, config, printing
 from pytensor import scalar as aes
 from pytensor.gradient import DisconnectedType, grad_undefined
+from pytensor.graph import RewriteDatabaseQuery
 from pytensor.graph.basic import Apply, Constant, Variable
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.op import Op
-from pytensor.graph.rewriting.utils import rewrite_graph
+from pytensor.graph.rewriting.db import EquilibriumDB
 from pytensor.graph.type import HasShape, Type
 from pytensor.link.c.op import COp
 from pytensor.link.c.params_type import ParamsType
@@ -1356,6 +1357,45 @@ def identity_like(x, dtype: Optional[Union[str, np.generic, np.dtype]] = None):
     return eye(_x.shape[0], _x.shape[1], k=0, dtype=dtype)


+class CachedEquilibrimDB(EquilibriumDB):
+    """A subclass of EquilibriumDB that allows caching of a default query for faster reuse."""
+
+    def __init__(self, default_query):
+        super().__init__()
+        self._default_query = default_query
+        self._cached_default_query = None
+
+    def register(self, *args, **kwargs):
+        # If new rewrites are registered, the cached default query is stale
+        self._cached_default_query = None
+        super().register(*args, **kwargs)
+
+    @property
+    def default_query(self):
+        if self._cached_default_query is None:
+            self._cached_default_query = self.query(self._default_query)
+        return self._cached_default_query
+
+
+infer_shape_db = CachedEquilibrimDB(
+    default_query=RewriteDatabaseQuery(include=("infer_shape",))
+)
+
+
+def register_infer_shape(rewrite, *tags, **kwargs):
+    if isinstance(rewrite, str):
+
+        def register(inner_lopt):
+            return register_infer_shape(inner_lopt, rewrite, *tags, **kwargs)
+
+        return register
+    else:
+        name = kwargs.pop("name", None) or rewrite.__name__
+
+        infer_shape_db.register(name, rewrite, *tags, "infer_shape", **kwargs)
+        return rewrite
+
+
 def infer_static_shape(
     shape: Union[Variable, Sequence[Union[Variable, int]]]
 ) -> tuple[Sequence["TensorLike"], Sequence[Optional[int]]]:
@@ -1390,14 +1430,16 @@ def check_type(s):

         raise TypeError(f"Shapes must be scalar integers; got {s_as_str}")

-    sh = [check_type(as_tensor_variable(s, ndim=0)) for s in shape]
+    sh = folded_shape = [check_type(as_tensor_variable(s, ndim=0)) for s in shape]
+
+    if not all(isinstance(s, Constant) for s in folded_shape):
+        shape_fg = FunctionGraph(outputs=sh, features=[ShapeFeature()], clone=True)
+        with config.change_flags(optdb__max_use_ratio=10, cxx=""):
+            infer_shape_db.default_query.rewrite(shape_fg)
+            if not all(isinstance(s, Constant) for s in shape_fg.outputs):
+                topo_constant_folding.rewrite(shape_fg)
+        folded_shape = shape_fg.outputs

-    shape_fg = FunctionGraph(
-        outputs=sh,
-        features=[ShapeFeature()],
-        clone=True,
-    )
-    folded_shape = rewrite_graph(shape_fg, custom_rewrite=topo_constant_folding).outputs
     static_shape = tuple(
         s.data.item() if isinstance(s, Constant) else None for s in folded_shape
     )

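Taken together, the changes above let infer_static_shape run only the rewrites registered under the "infer_shape" tag (with constant folding as a fallback), instead of the full default rewrite pipeline. A minimal usage sketch, not part of this commit; the variable x and its shapes are purely illustrative:

import pytensor.tensor as pt
from pytensor.tensor.basic import infer_static_shape

x = pt.matrix("x", shape=(None, 3))
# One entry is symbolic (x.shape[0]); the other is a plain Python int.
_, static_shape = infer_static_shape((x.shape[0], 3))
print(static_shape)  # (None, 3): only the constant entry can be folded
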
pytensor/tensor/rewriting/basic.py

Lines changed: 9 additions & 0 deletions
@@ -58,6 +58,7 @@
     get_underlying_scalar_constant_value,
     join,
     ones_like,
+    register_infer_shape,
     switch,
     tensor_copy,
     zeros,
@@ -420,6 +421,7 @@ def local_fill_to_alloc(fgraph, node):
 )


+@register_infer_shape
 @register_canonicalize("fast_compile", "shape_unsafe")
 @register_useless("shape_unsafe")
 @node_rewriter([fill])
@@ -441,6 +443,7 @@ def local_useless_fill(fgraph, node):
     return [v]


+@register_infer_shape
 @register_specialize("shape_unsafe")
 @register_stabilize("shape_unsafe")
 @register_canonicalize("shape_unsafe")
@@ -530,6 +533,7 @@ def local_alloc_empty_to_zeros(fgraph, node):
 )


+@register_infer_shape
 @register_useless
 @register_canonicalize("fast_compile")
 @register_specialize
@@ -806,6 +810,7 @@ def local_remove_all_assert(fgraph, node):
 )


+@register_infer_shape
 @register_specialize
 @register_canonicalize
 @register_useless
@@ -826,6 +831,7 @@ def local_join_1(fgraph, node):


 # TODO: merge in local_useless_join
+@register_infer_shape
 @register_useless
 @register_specialize
 @register_canonicalize
@@ -1066,6 +1072,7 @@ def local_merge_switch_same_cond(fgraph, node):
 ]


+@register_infer_shape
 @register_useless
 @register_canonicalize
 @register_specialize
@@ -1149,6 +1156,7 @@ def constant_folding(fgraph, node):
 register_specialize(topo_constant_folding, "fast_compile", final_rewriter=True)


+@register_infer_shape
 @register_canonicalize("fast_compile")
 @register_useless("fast_compile")
 @node_rewriter(None)
@@ -1157,6 +1165,7 @@ def local_view_op(fgraph, node):
     return node.inputs


+@register_infer_shape
 @register_useless
 @register_canonicalize
 @register_stabilize

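Each hunk in this file follows the same pattern, repeated across the rewrite modules below: an existing node rewriter gains an extra @register_infer_shape decorator, filing it under the "infer_shape" tag in infer_shape_db. A hypothetical sketch of registering a custom rewrite the same way (the rewriter name and its no-op body are made up for illustration):

from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.tensor.basic import register_infer_shape
from pytensor.tensor.shape import Shape


@register_infer_shape
@node_rewriter([Shape])
def local_my_shape_rewrite(fgraph, node):
    # Returning None signals that this rewrite does not apply to `node`.
    return None
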
pytensor/tensor/rewriting/math.py

Lines changed: 2 additions & 0 deletions
@@ -32,6 +32,7 @@
     extract_constant,
     get_underlying_scalar_constant_value,
     ones_like,
+    register_infer_shape,
     switch,
     zeros_like,
 )
@@ -1745,6 +1746,7 @@ def local_reduce_join(fgraph, node):
     return [ret]


+@register_infer_shape
 @register_canonicalize("fast_compile", "local_cut_useless_reduce")
 @register_useless("local_cut_useless_reduce")
 @node_rewriter(ALL_REDUCE)

pytensor/tensor/rewriting/shape.py

Lines changed: 16 additions & 0 deletions
@@ -25,6 +25,7 @@
     constant,
     extract_constant,
     get_underlying_scalar_constant_value,
+    register_infer_shape,
     stack,
 )
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
@@ -964,6 +965,7 @@ def local_reshape_lift(fgraph, node):
     return [e]


+@register_infer_shape
 @register_useless
 @register_canonicalize
 @node_rewriter([SpecifyShape])
@@ -990,6 +992,17 @@ def local_merge_consecutive_specify_shape(fgraph, node):
     return [specify_shape(inner_obj, shape)]


+@register_infer_shape
+@node_rewriter([Shape])
+def local_shape_ground(fgraph, node):
+    """Rewrite shape(x) -> make_vector(x.type.shape) when this is constant."""
+    [x] = node.inputs
+    static_shape = x.type.shape
+    if not any(dim is None for dim in static_shape):
+        return [stack([constant(dim, dtype="int64") for dim in static_shape])]
+
+
+@register_infer_shape
 @register_useless
 @register_canonicalize
 @node_rewriter([Shape])
@@ -1014,6 +1027,7 @@ def local_Shape_of_SpecifyShape(fgraph, node):
     return [stack(shape).astype(np.int64)]


+@register_infer_shape
 @register_canonicalize
 @register_specialize
 @node_rewriter([SpecifyShape])
@@ -1060,6 +1074,7 @@ def local_specify_shape_lift(fgraph, node):
     return new_out


+@register_infer_shape
 @register_useless
 @register_canonicalize
 @node_rewriter([Shape_i])
@@ -1079,6 +1094,7 @@ def local_Shape_i_ground(fgraph, node):
     return [as_tensor_variable(s_val, dtype=np.int64)]


+@register_infer_shape
 @register_specialize
 @register_canonicalize
 @node_rewriter([Shape])

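local_shape_ground is the one genuinely new rewrite in this file: when a variable's type already fixes every dimension, shape(x) is replaced by a vector of constants, so nothing symbolic survives to fold. A rough sketch of the effect, assuming standard pytensor entry points (not shown in the commit):

import pytensor.tensor as pt
from pytensor.tensor.basic import infer_static_shape

x = pt.matrix("x", shape=(2, 3))  # every dimension statically known
# x.shape is symbolic, but the "infer_shape" rewrites can fold it to the
# constant vector [2, 3] via local_shape_ground.
_, static_shape = infer_static_shape(x.shape)
print(static_shape)  # (2, 3)
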
pytensor/tensor/rewriting/subtensor.py

Lines changed: 6 additions & 0 deletions
@@ -26,6 +26,7 @@
     concatenate,
     extract_constant,
     get_underlying_scalar_constant_value,
+    register_infer_shape,
     switch,
 )
 from pytensor.tensor.elemwise import Elemwise
@@ -328,6 +329,7 @@ def local_subtensor_of_dot(fgraph, node):
     return [r]


+@register_infer_shape
 @register_useless
 @register_canonicalize
 @register_specialize
@@ -599,6 +601,7 @@ def local_subtensor_remove_broadcastable_index(fgraph, node):
     return [node.inputs[0].dimshuffle(tuple(remain_dim))]


+@register_infer_shape
 @register_useless
 @register_canonicalize
 @register_specialize
@@ -707,6 +710,7 @@ def local_subtensor_inc_subtensor(fgraph, node):
         return


+@register_infer_shape
 @register_specialize
 @register_canonicalize("fast_compile")
 @register_useless
@@ -785,6 +789,7 @@ def local_subtensor_make_vector(fgraph, node):
         pass


+@register_infer_shape
 @register_useless
 @register_canonicalize
 @register_specialize
@@ -1461,6 +1466,7 @@ def local_adv_sub1_adv_inc_sub1(fgraph, node):
     return [r2]


+@register_infer_shape
 @register_specialize
 @register_stabilize
 @register_canonicalize

pytensor/tensor/shape.py

Lines changed: 15 additions & 14 deletions
@@ -1,12 +1,13 @@
 import warnings
 from numbers import Number
 from textwrap import dedent
-from typing import Union
+from typing import Union, cast

 import numpy as np

 import pytensor
 from pytensor.gradient import DisconnectedType
+from pytensor.graph import Op
 from pytensor.graph.basic import Apply, Variable
 from pytensor.graph.replace import _vectorize_node
 from pytensor.graph.type import HasShape
@@ -145,14 +146,14 @@ def c_code_cache_version(self):
 def shape(x: Union[np.ndarray, Number, Variable]) -> Variable:
     """Return the shape of `x`."""
     if not isinstance(x, Variable):
-        x = at.as_tensor_variable(x)
+        x = at.as_tensor_variable(x)  # type: ignore

-    return _shape(x)
+    return cast(Variable, _shape(x))


-@_get_vector_length.register(Shape)
-def _get_vector_length_Shape(op, var):
-    return var.owner.inputs[0].type.ndim
+@_get_vector_length.register(Shape)  # type: ignore
+def _get_vector_length_Shape(op: Op, var: TensorVariable) -> int:
+    return cast(int, var.owner.inputs[0].type.ndim)


 @_vectorize_node.register(Shape)
@@ -181,7 +182,7 @@ def shape_tuple(x: TensorVariable) -> tuple[Variable, ...]:
         # We assume/call it a scalar
         return ()

-    res = ()
+    res: tuple[Variable, ...] = ()
     symbolic_shape = shape(x)
     static_shape = x.type.shape
     for i in range(x.type.ndim):
@@ -191,7 +192,7 @@ def shape_tuple(x: TensorVariable) -> tuple[Variable, ...]:
             # TODO: Why not use uint64?
             res += (pytensor.scalar.ScalarConstant(pytensor.scalar.int64, shape_val),)
         else:
-            res += (symbolic_shape[i],)
+            res += (symbolic_shape[i],)  # type: ignore

     return res

@@ -366,7 +367,7 @@ def shape_i_op(i):
     return shape_i_op.cache[key]


-shape_i_op.cache = {}
+shape_i_op.cache = {}  # type: ignore


 def register_shape_i_c_code(typ, code, check_input, version=()):
@@ -578,7 +579,7 @@ def specify_shape(

     # If the specified shape is already encoded in the input static shape, do nothing
     # This ignores PyTensor constants in shape
-    x = at.as_tensor_variable(x)
+    x = at.as_tensor_variable(x)  # type: ignore
     new_shape_info = any(
         s != xts for (s, xts) in zip(shape, x.type.shape) if s is not None
     )
@@ -589,10 +590,10 @@ def specify_shape(
     return _specify_shape(x, *shape)


-@_get_vector_length.register(SpecifyShape)
-def _get_vector_length_SpecifyShape(op, var):
+@_get_vector_length.register(SpecifyShape)  # type: ignore
+def _get_vector_length_SpecifyShape(op: Op, var: TensorVariable) -> int:
     try:
-        return at.get_underlying_scalar_constant_value(var.owner.inputs[1]).item()
+        return int(at.get_underlying_scalar_constant_value(var.owner.inputs[1]).item())
     except NotScalarConstantError:
         raise ValueError(f"Length of {var} cannot be determined")

@@ -1104,4 +1105,4 @@ def _vectorize_unbroadcast(op: Unbroadcast, node: Apply, x: TensorVariable) -> Apply:
     batched_ndims = x.type.ndim - node.inputs[0].type.ndim
     old_axes = op.axes
     new_axes = (old_axis + batched_ndims for old_axis in old_axes)
-    return unbroadcast(x, *new_axes).owner
+    return cast(Apply, unbroadcast(x, *new_axes).owner)

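The annotations above are what allow pytensor/tensor/shape.py to be dropped from the mypy exclusion list below. A small hedged illustration of the now-typed vector-length helper, assuming the public get_vector_length entry point (not part of this diff):

from pytensor.tensor import vector
from pytensor.tensor.basic import get_vector_length
from pytensor.tensor.shape import specify_shape

v = specify_shape(vector("v"), (5,))
# Dispatches to _get_vector_length_SpecifyShape, which now returns a plain int.
assert get_vector_length(v) == 5
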
scripts/mypy-failing.txt

Lines changed: 0 additions & 1 deletion
@@ -27,7 +27,6 @@ pytensor/tensor/random/basic.py
 pytensor/tensor/random/op.py
 pytensor/tensor/random/utils.py
 pytensor/tensor/rewriting/basic.py
-pytensor/tensor/shape.py
 pytensor/tensor/slinalg.py
 pytensor/tensor/subtensor.py
 pytensor/tensor/type.py
