Skip to content

Commit 08d424b

Browse files
Preserve numpy < 2.0 Unique inverse output shape
In numpy 2.0, if axis=None, `np.unique` no longer flattens the inverse indices it returns when return_inverse=True. A helper function has been added to npy_2_compat.py to mimic the output of `np.unique` from versions of numpy before 2.0.
1 parent 1a99fe8 commit 08d424b

File tree

3 files changed

+48
-11
lines changed

3 files changed

+48
-11
lines changed

pytensor/npy_2_compat.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,28 @@
6363
numpy_maxdims = 64 if using_numpy_2 else 32
6464

6565

66+
# function that replicates np.unique from numpy < 2.0
67+
def old_np_unique(
    arr, return_index=False, return_inverse=False, return_counts=False, axis=None
):
    """Replicate the output of ``np.unique`` from numpy versions < 2.0.

    The call is forwarded to ``np.unique`` unchanged. Only when running on
    numpy 2 (``using_numpy_2``) and ``return_inverse`` is requested is the
    inverse-indices output post-processed so that it is always 1-D.

    Parameters
    ----------
    arr : array_like
        Input passed to ``np.unique``.
    return_index, return_inverse, return_counts : bool
        Forwarded to ``np.unique``; they select which optional outputs are
        returned.
    axis : int or None
        Forwarded to ``np.unique``.

    Returns
    -------
    ndarray or tuple of ndarray
        Whatever ``np.unique`` returns. On numpy 2 with
        ``return_inverse=True`` the result is a tuple whose inverse entry has
        been made 1-D: flattened when ``axis`` is None, otherwise reshaped to
        ``(shape[axis],)`` of the input.
    """
    outs = np.unique(
        arr,
        return_index=return_index,
        return_inverse=return_inverse,
        return_counts=return_counts,
        axis=axis,
    )

    # Nothing to fix up when no inverse was requested or on numpy < 2.0.
    if not return_inverse or not using_numpy_2:
        return outs

    outs = list(outs)

    # Output order is (unique, [index], [inverse], [counts]); the inverse
    # comes right after the optional index output.
    inv_idx = 2 if return_index else 1

    if axis is None:
        # numpy 2 does not flatten the inverse in this case; flatten it here.
        outs[inv_idx] = np.ravel(outs[inv_idx])
    else:
        # np.shape (rather than arr.shape) so that array-like inputs such as
        # lists are supported, as they are by np.unique itself.
        inv_shape = (np.shape(arr)[axis],)
        outs[inv_idx] = outs[inv_idx].reshape(inv_shape)

    return tuple(outs)
85+
86+
87+
# compatibility header for C code
6688
def npy_2_compat_header() -> str:
6789
"""Compatibility header that Numpy suggests is vendored with code that uses Numpy < 2.0 and Numpy 2.x"""
6890
return dedent("""

pytensor/tensor/extra_ops.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
normalize_axis_index,
2121
npy_2_compat_header,
2222
numpy_axis_is_none_flag,
23+
old_np_unique,
2324
)
2425
from pytensor.raise_op import Assert
2526
from pytensor.scalar import int64 as int_t
@@ -1182,6 +1183,9 @@ class Unique(Op):
11821183
"""
11831184
Wraps `numpy.unique`.
11841185
1186+
The indices returned when `return_inverse` is True are ravelled
1187+
to match the behavior of `numpy.unique` from before numpy version 2.0.
1188+
11851189
Examples
11861190
--------
11871191
>>> import numpy as np
@@ -1227,17 +1231,21 @@ def make_node(self, x):
12271231

12281232
outputs = [TensorType(dtype=x.dtype, shape=out_shape)()]
12291233
typ = TensorType(dtype="int64", shape=(None,))
1234+
12301235
if self.return_index:
12311236
outputs.append(typ())
1237+
12321238
if self.return_inverse:
12331239
outputs.append(typ())
1240+
12341241
if self.return_counts:
12351242
outputs.append(typ())
1243+
12361244
return Apply(self, [x], outputs)
12371245

12381246
def perform(self, node, inputs, output_storage):
12391247
[x] = inputs
1240-
outs = np.unique(
1248+
outs = old_np_unique(
12411249
x,
12421250
return_index=self.return_index,
12431251
return_inverse=self.return_inverse,
@@ -1262,9 +1270,14 @@ def infer_shape(self, fgraph, node, i0_shapes):
12621270
out_shapes[0] = tuple(shape)
12631271

12641272
if self.return_inverse:
1265-
shape = prod(x_shape) if self.axis is None else x_shape[axis]
12661273
return_index_out_idx = 2 if self.return_index else 1
1267-
out_shapes[return_index_out_idx] = (shape,)
1274+
1275+
if self.axis is not None:
1276+
shape = (x_shape[axis],)
1277+
else:
1278+
shape = (prod(x_shape),)
1279+
1280+
out_shapes[return_index_out_idx] = shape
12681281

12691282
return out_shapes
12701283

tests/tensor/test_extra_ops.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from pytensor.compile.mode import Mode
1010
from pytensor.configdefaults import config
1111
from pytensor.graph.basic import Constant, applys_between, equal_computations
12+
from pytensor.npy_2_compat import old_np_unique
1213
from pytensor.raise_op import Assert
1314
from pytensor.tensor import alloc
1415
from pytensor.tensor.elemwise import DimShuffle
@@ -885,18 +886,19 @@ def setup_method(self):
885886
)
886887
def test_basic_vector(self, x, inp, axis):
887888
list_outs_expected = [
888-
np.unique(inp, axis=axis),
889-
np.unique(inp, True, axis=axis),
890-
np.unique(inp, False, True, axis=axis),
891-
np.unique(inp, True, True, axis=axis),
892-
np.unique(inp, False, False, True, axis=axis),
893-
np.unique(inp, True, False, True, axis=axis),
894-
np.unique(inp, False, True, True, axis=axis),
895-
np.unique(inp, True, True, True, axis=axis),
889+
old_np_unique(inp, axis=axis),
890+
old_np_unique(inp, True, axis=axis),
891+
old_np_unique(inp, False, True, axis=axis),
892+
old_np_unique(inp, True, True, axis=axis),
893+
old_np_unique(inp, False, False, True, axis=axis),
894+
old_np_unique(inp, True, False, True, axis=axis),
895+
old_np_unique(inp, False, True, True, axis=axis),
896+
old_np_unique(inp, True, True, True, axis=axis),
896897
]
897898
for params, outs_expected in zip(
898899
self.op_params, list_outs_expected, strict=True
899900
):
901+
print(params)
900902
out = pt.unique(x, *params, axis=axis)
901903
f = pytensor.function(inputs=[x], outputs=out)
902904
outs = f(inp)

0 commit comments

Comments
 (0)