Skip to content

Commit b525753

Browse files
committed
Fix transpose numpy compatibility pymc-devs#1142
1 parent 2a7f3e1 commit b525753

File tree

2 files changed

+16
-1
lines changed

2 files changed

+16
-1
lines changed

pytensor/tensor/variable.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,7 @@ def dimshuffle(self, *pattern):
346346
DimShuffle
347347
348348
"""
349-
if (len(pattern) == 1) and (isinstance(pattern[0], list | tuple)):
349+
if (len(pattern) == 1) and (isinstance(pattern[0], list | tuple | np.ndarray)):
350350
pattern = pattern[0]
351351
ds_op = pt.elemwise.DimShuffle(input_ndim=self.type.ndim, new_order=pattern)
352352
return ds_op(self)

tests/tensor/test_variable.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -451,6 +451,21 @@ def test_set_item_error(self):
451451
with pytest.raises(TypeError, match=msg):
452452
x[0] += 5
453453

454+
def test_transpose(self):
455+
X, _ = self.vars
456+
x, _ = self.vals
457+
458+
# Turn (2,2) -> (1,2)
459+
X, x = X[1:, :], x[1:, :]
460+
461+
assert_array_equal(X.transpose(0, 1).eval({X: x}), x.transpose(0, 1))
462+
assert_array_equal(X.transpose(1, 0).eval({X: x}), x.transpose(1, 0))
463+
464+
# Test handing in tuples, lists and np.arrays
465+
equal_computations([X.transpose((1, 0))], [X.transpose(1, 0)])
466+
equal_computations([X.transpose([1, 0])], [X.transpose(1, 0)])
467+
equal_computations([X.transpose(np.array([1, 0]))], [X.transpose(1, 0)])
468+
454469

455470
def test_deprecated_import():
456471
with pytest.warns(

0 commit comments

Comments
 (0)