Skip to content

Commit 684d73e

Browse files
Updated Unique inverse output shape
In numpy 2.0, if axis=None, then np.unique does not flatten the inverse indices returned if return_inverse=True The output shape has been changed to match numpy, dependent on the version of numpy.
1 parent 8229286 commit 684d73e

File tree

1 file changed

+24
-3
lines changed

1 file changed

+24
-3
lines changed

pytensor/tensor/extra_ops.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import numpy as np
55
from numpy.exceptions import AxisError
6+
from numpy.lib import NumpyVersion
67

78

89
try:
@@ -1241,12 +1242,25 @@ def make_node(self, x):
12411242

12421243
outputs = [TensorType(dtype=x.dtype, shape=out_shape)()]
12431244
typ = TensorType(dtype="int64", shape=(None,))
1245+
12441246
if self.return_index:
12451247
outputs.append(typ())
1248+
12461249
if self.return_inverse:
1247-
outputs.append(typ())
1250+
if NumpyVersion(np.__version__) >= "2.0.0rc1":
1251+
if axis is None:
1252+
inverse_shape = TensorType(dtype="int64", shape=x.type.shape)
1253+
else:
1254+
inverse_shape = TensorType(
1255+
dtype="int64", shape=(x.type.shape[axis],)
1256+
)
1257+
outputs.append(inverse_shape())
1258+
else:
1259+
outputs.append(typ())
1260+
12481261
if self.return_counts:
12491262
outputs.append(typ())
1263+
12501264
return Apply(self, [x], outputs)
12511265

12521266
def perform(self, node, inputs, output_storage):
@@ -1276,9 +1290,16 @@ def infer_shape(self, fgraph, node, i0_shapes):
12761290
out_shapes[0] = tuple(shape)
12771291

12781292
if self.return_inverse:
1279-
shape = prod(x_shape) if self.axis is None else x_shape[axis]
12801293
return_index_out_idx = 2 if self.return_index else 1
1281-
out_shapes[return_index_out_idx] = (shape,)
1294+
1295+
if self.axis is not None:
1296+
shape = (x_shape[axis],)
1297+
elif NumpyVersion(np.__version__) >= "2.0.0rc1":
1298+
shape = x_shape
1299+
else:
1300+
shape = (prod(x_shape),)
1301+
1302+
out_shapes[return_index_out_idx] = shape
12821303

12831304
return out_shapes
12841305

0 commit comments

Comments
 (0)