Skip to content

Commit 5b61cd4

Browse files
Updated Unique inverse output shape
In numpy 2.0, if axis=None, then np.unique does not flatten the inverse indices returned if return_inverse=True The output shape has been changed to match numpy, dependent on the version of numpy.
1 parent 6a0885d commit 5b61cd4

File tree

1 file changed

+24
-3
lines changed

1 file changed

+24
-3
lines changed

pytensor/tensor/extra_ops.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import numpy as np
55
from numpy.exceptions import AxisError
6+
from numpy.lib import NumpyVersion
67

78

89
try:
@@ -1240,12 +1241,25 @@ def make_node(self, x):
12401241

12411242
outputs = [TensorType(dtype=x.dtype, shape=out_shape)()]
12421243
typ = TensorType(dtype="int64", shape=(None,))
1244+
12431245
if self.return_index:
12441246
outputs.append(typ())
1247+
12451248
if self.return_inverse:
1246-
outputs.append(typ())
1249+
if NumpyVersion(np.__version__) >= "2.0.0rc1":
1250+
if axis is None:
1251+
inverse_shape = TensorType(dtype="int64", shape=x.type.shape)
1252+
else:
1253+
inverse_shape = TensorType(
1254+
dtype="int64", shape=(x.type.shape[axis],)
1255+
)
1256+
outputs.append(inverse_shape())
1257+
else:
1258+
outputs.append(typ())
1259+
12471260
if self.return_counts:
12481261
outputs.append(typ())
1262+
12491263
return Apply(self, [x], outputs)
12501264

12511265
def perform(self, node, inputs, output_storage):
@@ -1275,9 +1289,16 @@ def infer_shape(self, fgraph, node, i0_shapes):
12751289
out_shapes[0] = tuple(shape)
12761290

12771291
if self.return_inverse:
1278-
shape = prod(x_shape) if self.axis is None else x_shape[axis]
12791292
return_index_out_idx = 2 if self.return_index else 1
1280-
out_shapes[return_index_out_idx] = (shape,)
1293+
1294+
if self.axis is not None:
1295+
shape = (x_shape[axis],)
1296+
elif NumpyVersion(np.__version__) >= "2.0.0rc1":
1297+
shape = x_shape
1298+
else:
1299+
shape = (prod(x_shape),)
1300+
1301+
out_shapes[return_index_out_idx] = shape
12811302

12821303
return out_shapes
12831304

0 commit comments

Comments
 (0)