diff --git a/pytensor/tensor/variable.py b/pytensor/tensor/variable.py index bad700de8b..7e5d22528a 100644 --- a/pytensor/tensor/variable.py +++ b/pytensor/tensor/variable.py @@ -346,7 +346,7 @@ def dimshuffle(self, *pattern): DimShuffle """ - if (len(pattern) == 1) and (isinstance(pattern[0], list | tuple)): + if (len(pattern) == 1) and (isinstance(pattern[0], list | tuple | np.ndarray)): pattern = pattern[0] ds_op = pt.elemwise.DimShuffle(input_ndim=self.type.ndim, new_order=pattern) return ds_op(self) diff --git a/tests/tensor/test_variable.py b/tests/tensor/test_variable.py index 57e47ce064..2c6f818c30 100644 --- a/tests/tensor/test_variable.py +++ b/tests/tensor/test_variable.py @@ -451,6 +451,21 @@ def test_set_item_error(self): with pytest.raises(TypeError, match=msg): x[0] += 5 + def test_transpose(self): + X, _ = self.vars + x, _ = self.vals + + # Turn (2,2) -> (1,2) + X, x = X[1:, :], x[1:, :] + + assert_array_equal(X.transpose(0, 1).eval({X: x}), x.transpose(0, 1)) + assert_array_equal(X.transpose(1, 0).eval({X: x}), x.transpose(1, 0)) + + # Test handing in tuples, lists and np.arrays + equal_computations([X.transpose((1, 0))], [X.transpose(1, 0)]) + equal_computations([X.transpose([1, 0])], [X.transpose(1, 0)]) + equal_computations([X.transpose(np.array([1, 0]))], [X.transpose(1, 0)]) + def test_deprecated_import(): with pytest.warns(