Skip to content

Commit 6a0885d

Browse files
Updated doctests
From numpy PR numpy/numpy#22449, the repr of scalar values has changed, e.g. from "1" to "np.int64(1)", which caused two doctests to fail. The doctest for tensor.extra_ops.Unique was failing because the output shape for the inverse indices has changed when axis is None: numpy/numpy#20638
1 parent ff104c0 commit 6a0885d

File tree

3 files changed

+6
-3
lines changed

3 files changed

+6
-3
lines changed

pytensor/tensor/einsum.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ def _general_dot(
269269
270270
.. testoutput::
271271
272-
(3, 4, 2)
272+
(np.int64(3), np.int64(4), np.int64(2))
273273
"""
274274
# Shortcut for non batched case
275275
if not batch_axes[0] and not batch_axes[1]:

pytensor/tensor/extra_ops.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1207,7 +1207,8 @@ class Unique(Op):
12071207
>>> y = pytensor.tensor.matrix()
12081208
>>> g = pytensor.function([y], Unique(True, True, False)(y))
12091209
>>> g([[1, 1, 1.0], (2, 3, 3.0)])
1210-
[array([1., 2., 3.]), array([0, 3, 4]), array([0, 0, 0, 1, 2, 2])]
1210+
[array([1., 2., 3.]), array([0, 3, 4]), array([[0, 0, 0],
1211+
[1, 2, 2]])]
12111212
12121213
"""
12131214

pytensor/tensor/subtensor.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -757,13 +757,15 @@ def get_constant_idx(
757757
Example usage where `v` and `a` are appropriately typed PyTensor variables :
758758
>>> from pytensor.scalar import int64
759759
>>> from pytensor.tensor import matrix
760+
>>> import numpy as np
761+
>>>
760762
>>> v = int64("v")
761763
>>> a = matrix("a")
762764
>>> b = a[v, 1:3]
763765
>>> b.owner.op.idx_list
764766
(ScalarType(int64), slice(ScalarType(int64), ScalarType(int64), None))
765767
>>> get_constant_idx(b.owner.op.idx_list, b.owner.inputs, allow_partial=True)
766-
[v, slice(1, 3, None)]
768+
[v, slice(np.int64(1), np.int64(3), None)]
767769
>>> get_constant_idx(b.owner.op.idx_list, b.owner.inputs)
768770
Traceback (most recent call last):
769771
pytensor.tensor.exceptions.NotScalarConstantError

0 commit comments

Comments
 (0)