Skip to content

Commit 1c2bc8f

Browse files
committed
Remove rarely used shape_i helpers
1 parent d9b3924 commit 1c2bc8f

File tree

4 files changed

+7
-23
lines changed

4 files changed

+7
-23
lines changed

pytensor/tensor/rewriting/shape.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@
4242
Shape_i,
4343
SpecifyShape,
4444
Unbroadcast,
45-
shape_i,
4645
specify_shape,
4746
unbroadcast,
4847
)
@@ -1060,7 +1059,7 @@ def local_Shape_of_SpecifyShape(fgraph, node):
10601059
# Replace `NoneConst` by `shape_i`
10611060
for i, sh in enumerate(shape):
10621061
if NoneConst.equals(sh):
1063-
shape[i] = shape_i(x, i, fgraph)
1062+
shape[i] = x.shape[i]
10641063

10651064
return [stack(shape).astype(np.int64)]
10661065

pytensor/tensor/shape.py

-10
Original file line numberDiff line numberDiff line change
@@ -363,16 +363,6 @@ def recur(node):
363363
return shape(var)[i]
364364

365365

366-
def shape_i_op(i):
367-
key = i
368-
if key not in shape_i_op.cache:
369-
shape_i_op.cache[key] = Shape_i(i)
370-
return shape_i_op.cache[key]
371-
372-
373-
shape_i_op.cache = {} # type: ignore
374-
375-
376366
def register_shape_i_c_code(typ, code, check_input, version=()):
377367
"""
378368
Tell Shape_i how to generate C code for an PyTensor Type.

pytensor/tensor/subtensor.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
from pytensor.tensor.elemwise import DimShuffle
3939
from pytensor.tensor.exceptions import AdvancedIndexingError, NotScalarConstantError
4040
from pytensor.tensor.math import clip
41-
from pytensor.tensor.shape import Reshape, shape_i, specify_broadcastable
41+
from pytensor.tensor.shape import Reshape, Shape_i, specify_broadcastable
4242
from pytensor.tensor.type import (
4343
TensorType,
4444
bscalar,
@@ -2705,10 +2705,9 @@ def is_bool_index(idx):
27052705
index_shapes = []
27062706
for idx, ishape in zip(indices, ishapes[1:]):
27072707
# Mixed bool indexes are converted to nonzero entries
2708+
shape0_op = Shape_i(0)
27082709
if is_bool_index(idx):
2709-
index_shapes.extend(
2710-
(shape_i(nz_dim, 0, fgraph=fgraph),) for nz_dim in nonzero(idx)
2711-
)
2710+
index_shapes.extend((shape0_op(nz_dim),) for nz_dim in nonzero(idx))
27122711
# The `ishapes` entries for `SliceType`s will be None, and
27132712
# we need to give `indexed_result_shape` the actual slices.
27142713
elif isinstance(getattr(idx, "type", None), SliceType):

tests/tensor/test_shape.py

+3-7
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,13 @@
88
from pytensor.compile.ops import DeepCopyOp
99
from pytensor.configdefaults import config
1010
from pytensor.graph.basic import Variable, equal_computations
11-
from pytensor.graph.fg import FunctionGraph
1211
from pytensor.graph.replace import clone_replace, vectorize_node
1312
from pytensor.graph.type import Type
1413
from pytensor.misc.safe_asarray import _asarray
1514
from pytensor.scalar.basic import ScalarConstant
1615
from pytensor.tensor import as_tensor_variable, broadcast_to, get_vector_length, row
1716
from pytensor.tensor.basic import MakeVector, constant, stack
1817
from pytensor.tensor.elemwise import DimShuffle, Elemwise
19-
from pytensor.tensor.rewriting.shape import ShapeFeature
2018
from pytensor.tensor.shape import (
2119
Reshape,
2220
Shape,
@@ -26,7 +24,6 @@
2624
_specify_shape,
2725
reshape,
2826
shape,
29-
shape_i,
3027
shape_tuple,
3128
specify_broadcastable,
3229
specify_shape,
@@ -633,13 +630,12 @@ def test_nonstandard_shapes():
633630
tl_shape = shape(tl)
634631
assert np.array_equal(tl_shape.get_test_value(), (2, 2, 3, 4))
635632

636-
# There's no `FunctionGraph`, so it should return a `Subtensor`
637-
tl_shape_i = shape_i(tl, 0)
633+
# Test specific dim
634+
tl_shape_i = shape(tl)[0]
638635
assert isinstance(tl_shape_i.owner.op, Subtensor)
639636
assert tl_shape_i.get_test_value() == 2
640637

641-
tl_fg = FunctionGraph([a, b], [tl], features=[ShapeFeature()])
642-
tl_shape_i = shape_i(tl, 0, fgraph=tl_fg)
638+
tl_shape_i = Shape_i(0)(tl)
643639
assert not isinstance(tl_shape_i.owner.op, Subtensor)
644640
assert tl_shape_i.get_test_value() == 2
645641

0 commit comments

Comments
 (0)