Skip to content

Commit 45c3a01

Browse files
Updated doctests
From numpy PR numpy/numpy#22449, the repr of scalar values has changed, e.g. from "1" to "np.int64(1)", which caused two doctests to fail.
1 parent bce3613 commit 45c3a01

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

pytensor/tensor/einsum.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ def _general_dot(
256256
257257
.. testoutput::
258258
259-
(3, 4, 2)
259+
(np.int64(3), np.int64(4), np.int64(2))
260260
"""
261261
# Shortcut for non batched case
262262
if not batch_axes[0] and not batch_axes[1]:

pytensor/tensor/subtensor.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -757,13 +757,15 @@ def get_constant_idx(
757757
Example usage where `v` and `a` are appropriately typed PyTensor variables :
758758
>>> from pytensor.scalar import int64
759759
>>> from pytensor.tensor import matrix
760+
>>> import numpy as np
761+
>>>
760762
>>> v = int64("v")
761763
>>> a = matrix("a")
762764
>>> b = a[v, 1:3]
763765
>>> b.owner.op.idx_list
764766
(ScalarType(int64), slice(ScalarType(int64), ScalarType(int64), None))
765767
>>> get_constant_idx(b.owner.op.idx_list, b.owner.inputs, allow_partial=True)
766-
[v, slice(1, 3, None)]
768+
[v, slice(np.int64(1), np.int64(3), None)]
767769
>>> get_constant_idx(b.owner.op.idx_list, b.owner.inputs)
768770
Traceback (most recent call last):
769771
pytensor.tensor.exceptions.NotScalarConstantError

0 commit comments

Comments
 (0)