Skip to content

Commit 2e9d502

Browse files
authored
Fix get_vector_length incorrectly returning for shared variable without static shape (#1295)
1 parent 9e603cf commit 2e9d502

File tree

3 files changed

+10
-9
lines changed

3 files changed

+10
-9
lines changed

pytensor/tensor/sharedvar.py

-6
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import numpy as np
44

55
from pytensor.compile import SharedVariable, shared_constructor
6-
from pytensor.tensor import _get_vector_length
76
from pytensor.tensor.type import TensorType
87
from pytensor.tensor.variable import TensorVariable
98

@@ -51,11 +50,6 @@ def zero(self, borrow: bool = False):
5150
self.container.value = 0 * self.container.value
5251

5352

54-
@_get_vector_length.register(TensorSharedVariable)
55-
def _get_vector_length_TensorSharedVariable(var_inst, var):
56-
return len(var.get_value(borrow=True))
57-
58-
5953
@shared_constructor.register(np.ndarray)
6054
def tensor_constructor(
6155
value,

tests/tensor/test_extra_ops.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -965,8 +965,10 @@ def fn(i, d):
965965
f_array_array = fn(indices, shape_array)
966966
np.testing.assert_equal(ref, f_array_array())
967967

968-
# shape given as an PyTensor variable
969-
shape_symb = pytensor.shared(shape_array)
968+
# shape given as a shared PyTensor variable with static shape
969+
shape_symb = pytensor.shared(
970+
shape_array, shape=shape_array.shape, strict=True
971+
)
970972
f_array_symb = fn(indices, shape_symb)
971973
np.testing.assert_equal(ref, f_array_symb())
972974

tests/tensor/test_sharedvar.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -605,6 +605,7 @@ def test_specify_shape_inplace(self):
605605
def test_values_eq(self):
606606
# Test the type.values_eq[_approx] function
607607
dtype = self.dtype
608+
608609
if dtype is None:
609610
dtype = pytensor.config.floatX
610611

@@ -691,9 +692,13 @@ def test_scalar_shared_deprecated():
691692

692693

693694
def test_get_vector_length():
694-
x = pytensor.shared(np.array((2, 3, 4, 5)))
695+
arr = np.array((2, 3, 4, 5))
696+
x = pytensor.shared(arr, shape=arr.shape, strict=True)
695697
assert get_vector_length(x) == 4
696698

699+
with pytest.raises(ValueError):
700+
get_vector_length(pytensor.shared(arr))
701+
697702

698703
def test_shared_masked_array_not_implemented():
699704
x = np.ma.masked_greater(np.array([1, 2, 3, 4]), 3)

0 commit comments

Comments
 (0)