File tree Expand file tree Collapse file tree 2 files changed +8
-7
lines changed Expand file tree Collapse file tree 2 files changed +8
-7
lines changed Original file line number Diff line number Diff line change @@ -1020,8 +1020,8 @@ def local_Shape_of_SpecifyShape(fgraph, node):
1020
1020
@register_useless
1021
1021
@register_canonicalize
1022
1022
@node_rewriter ([Shape_i ])
1023
- def local_Shape_i_of_broadcastable (fgraph , node ):
1024
- """Replace ``shape_i(x, i)`` with ``1 `` when ``x.broadcastable [i]`` is ``True ``."""
1023
+ def local_Shape_i_ground (fgraph , node ):
1024
+ """Replace ``shape_i(x, i)`` with ``s `` when ``x.type.shape [i] == s ``."""
1025
1025
1026
1026
if not isinstance (node .op , Shape_i ):
1027
1027
return False
@@ -1031,8 +1031,9 @@ def local_Shape_i_of_broadcastable(fgraph, node):
1031
1031
if not isinstance (shape_arg .type , TensorType ):
1032
1032
return False
1033
1033
1034
- if shape_arg .broadcastable [node .op .i ]:
1035
- return [as_tensor_variable (1 , dtype = np .int64 )]
1034
+ s_val = shape_arg .type .shape [node .op .i ]
1035
+ if s_val is not None :
1036
+ return [as_tensor_variable (s_val , dtype = np .int64 )]
1036
1037
1037
1038
1038
1039
@register_specialize
Original file line number Diff line number Diff line change @@ -493,15 +493,15 @@ def test_local_Shape_of_SpecifyShape_partial(s1):
493
493
assert not any (isinstance (apply .op , SpecifyShape ) for apply in fgraph .apply_nodes )
494
494
495
495
496
- def test_local_Shape_i_of_broadcastable ():
497
- x = tensor (np .float64 , shape = (None , 1 ))
496
+ def test_local_Shape_i_ground ():
497
+ x = tensor (np .float64 , shape = (None , 2 ))
498
498
s = Shape_i (1 )(x )
499
499
500
500
fgraph = FunctionGraph (outputs = [s ], clone = False )
501
501
_ = rewrite_graph (fgraph , clone = False )
502
502
503
503
assert x not in fgraph .variables
504
- assert fgraph .outputs [0 ].data == 1
504
+ assert fgraph .outputs [0 ].data == 2
505
505
506
506
# A test for a non-`TensorType`
507
507
class MyType (Type ):
You can’t perform that action at this time.
0 commit comments