Skip to content

Commit f1dc089

Browse files
Make shape_tuple return all static shape information
1 parent 456cce1 commit f1dc089

File tree

2 files changed

+39
-10
lines changed

2 files changed

+39
-10
lines changed

aesara/tensor/shape.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import aesara
99
from aesara.gradient import DisconnectedType
1010
from aesara.graph.basic import Apply, Variable
11+
from aesara.graph.type import HasShape
1112
from aesara.link.c.op import COp
1213
from aesara.link.c.params_type import ParamsType
1314
from aesara.misc.safe_asarray import _asarray
@@ -158,18 +159,28 @@ def _get_vector_length_Shape(op, var):
158159

159160

160161
def shape_tuple(x: TensorVariable) -> Tuple[Variable, ...]:
161-
"""Get a tuple of symbolic shape values.
162+
r"""Get a tuple of symbolic shape values.
163+
164+
This will return `ScalarConstant`\s for static shape values.
162165
163-
This will return a `ScalarConstant` with the value ``1`` wherever
164-
broadcastable is ``True``.
165166
"""
166-
one_at = aesara.scalar.ScalarConstant(aesara.scalar.int64, 1)
167-
return tuple(
168-
one_at if getattr(sh, "value", sh) == 1 or bcast else sh
169-
for sh, bcast in zip(
170-
shape(x), getattr(x, "broadcastable", (False,) * x.type.ndim)
171-
)
172-
)
167+
if not isinstance(x.type, HasShape):
168+
# We assume/call it a scalar
169+
return ()
170+
171+
res = ()
172+
symbolic_shape = shape(x)
173+
static_shape = x.type.shape
174+
for i in range(x.type.ndim):
175+
shape_val = static_shape[i]
176+
177+
if shape_val is not None:
178+
# TODO: Why not use uint64?
179+
res += (aesara.scalar.ScalarConstant(aesara.scalar.int64, shape_val),)
180+
else:
181+
res += (symbolic_shape[i],)
182+
183+
return res
173184

174185

175186
class Shape_i(COp):

tests/tensor/test_shape.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from aesara.graph.fg import FunctionGraph
1010
from aesara.graph.type import Type
1111
from aesara.misc.safe_asarray import _asarray
12+
from aesara.scalar.basic import ScalarConstant
1213
from aesara.tensor import as_tensor_variable, get_vector_length, row
1314
from aesara.tensor.basic import MakeVector, constant
1415
from aesara.tensor.elemwise import DimShuffle, Elemwise
@@ -22,6 +23,7 @@
2223
reshape,
2324
shape,
2425
shape_i,
26+
shape_tuple,
2527
specify_broadcastable,
2628
specify_shape,
2729
unbroadcast,
@@ -46,6 +48,7 @@
4648
from aesara.tensor.var import TensorVariable
4749
from aesara.typed_list import make_list
4850
from tests import unittest_tools as utt
51+
from tests.graph.utils import MyType2
4952
from tests.tensor.utils import eval_outputs, random
5053
from tests.test_rop import RopLopChecker
5154

@@ -657,3 +660,18 @@ def test_basic(self):
657660
Unbroadcast,
658661
warn=False,
659662
)
663+
664+
665+
def test_shape_tuple():
666+
667+
x = Variable(MyType2(), None, None)
668+
assert shape_tuple(x) == ()
669+
670+
x = tensor(np.float64, shape=(1, 2, None))
671+
res = shape_tuple(x)
672+
assert isinstance(res, tuple)
673+
assert isinstance(res[0], ScalarConstant)
674+
assert res[0].data == 1
675+
assert isinstance(res[1], ScalarConstant)
676+
assert res[1].data == 2
677+
assert not isinstance(res[2], ScalarConstant)

0 commit comments

Comments
 (0)