Skip to content

Commit c2909c9

Browse files
Replace Shape_i with any static shape value and not just 1
1 parent 1a53072 commit c2909c9

File tree

2 files changed

+8
-7
lines changed

2 files changed

+8
-7
lines changed

aesara/tensor/rewriting/shape.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1020,8 +1020,8 @@ def local_Shape_of_SpecifyShape(fgraph, node):
10201020
@register_useless
10211021
@register_canonicalize
10221022
@node_rewriter([Shape_i])
1023-
def local_Shape_i_of_broadcastable(fgraph, node):
1024-
"""Replace ``shape_i(x, i)`` with ``1`` when ``x.broadcastable[i]`` is ``True``."""
1023+
def local_Shape_i_ground(fgraph, node):
1024+
"""Replace ``shape_i(x, i)`` with ``s`` when ``x.type.shape[i] == s``."""
10251025

10261026
if not isinstance(node.op, Shape_i):
10271027
return False
@@ -1031,8 +1031,9 @@ def local_Shape_i_of_broadcastable(fgraph, node):
10311031
if not isinstance(shape_arg.type, TensorType):
10321032
return False
10331033

1034-
if shape_arg.broadcastable[node.op.i]:
1035-
return [as_tensor_variable(1, dtype=np.int64)]
1034+
s_val = shape_arg.type.shape[node.op.i]
1035+
if s_val is not None:
1036+
return [as_tensor_variable(s_val, dtype=np.int64)]
10361037

10371038

10381039
@register_specialize

tests/tensor/rewriting/test_shape.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -493,15 +493,15 @@ def test_local_Shape_of_SpecifyShape_partial(s1):
493493
assert not any(isinstance(apply.op, SpecifyShape) for apply in fgraph.apply_nodes)
494494

495495

496-
def test_local_Shape_i_of_broadcastable():
497-
x = tensor(np.float64, shape=(None, 1))
496+
def test_local_Shape_i_ground():
497+
x = tensor(np.float64, shape=(None, 2))
498498
s = Shape_i(1)(x)
499499

500500
fgraph = FunctionGraph(outputs=[s], clone=False)
501501
_ = rewrite_graph(fgraph, clone=False)
502502

503503
assert x not in fgraph.variables
504-
assert fgraph.outputs[0].data == 1
504+
assert fgraph.outputs[0].data == 2
505505

506506
# A test for a non-`TensorType`
507507
class MyType(Type):

0 commit comments

Comments
 (0)