Skip to content

Commit 456cce1

Browse files
Use static shape values in get_vector_length
1 parent c2909c9 commit 456cce1

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

aesara/tensor/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,9 @@ def get_vector_length(v: TensorLike) -> int:
8080
if v.type.ndim != 1:
8181
raise TypeError(f"Argument must be a vector; got {v.type}")
8282

83-
if v.type.broadcastable[0]:
84-
return 1
83+
static_shape: Optional[int] = v.type.shape[0]
84+
if static_shape is not None:
85+
return static_shape
8586

8687
return _get_vector_length(getattr(v.owner, "op", v), v)
8788

tests/tensor/test_basic.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1177,6 +1177,8 @@ def test_get_vector_length():
11771177
# Test `Alloc`s
11781178
assert 3 == get_vector_length(alloc(0, 3))
11791179

1180+
assert 5 == get_vector_length(tensor(np.float64, shape=(5,)))
1181+
11801182

11811183
class TestJoinAndSplit:
11821184
# Split is tested by each verify_grad method.

0 commit comments

Comments
 (0)