|
3 | 3 |
|
4 | 4 | import numpy as np
|
5 | 5 | from numpy.exceptions import AxisError
|
| 6 | +from numpy.lib import NumpyVersion |
6 | 7 |
|
7 | 8 |
|
8 | 9 | try:
|
@@ -1241,12 +1242,25 @@ def make_node(self, x):
|
1241 | 1242 |
|
1242 | 1243 | outputs = [TensorType(dtype=x.dtype, shape=out_shape)()]
|
1243 | 1244 | typ = TensorType(dtype="int64", shape=(None,))
|
| 1245 | + |
1244 | 1246 | if self.return_index:
|
1245 | 1247 | outputs.append(typ())
|
| 1248 | + |
1246 | 1249 | if self.return_inverse:
|
1247 |
| - outputs.append(typ()) |
| 1250 | + if NumpyVersion(np.__version__) >= "2.0.0rc1": |
| 1251 | + if axis is None: |
| 1252 | + inverse_shape = TensorType(dtype="int64", shape=x.type.shape) |
| 1253 | + else: |
| 1254 | + inverse_shape = TensorType( |
| 1255 | + dtype="int64", shape=(x.type.shape[axis],) |
| 1256 | + ) |
| 1257 | + outputs.append(inverse_shape()) |
| 1258 | + else: |
| 1259 | + outputs.append(typ()) |
| 1260 | + |
1248 | 1261 | if self.return_counts:
|
1249 | 1262 | outputs.append(typ())
|
| 1263 | + |
1250 | 1264 | return Apply(self, [x], outputs)
|
1251 | 1265 |
|
1252 | 1266 | def perform(self, node, inputs, output_storage):
|
@@ -1276,9 +1290,16 @@ def infer_shape(self, fgraph, node, i0_shapes):
|
1276 | 1290 | out_shapes[0] = tuple(shape)
|
1277 | 1291 |
|
1278 | 1292 | if self.return_inverse:
|
1279 |
| - shape = prod(x_shape) if self.axis is None else x_shape[axis] |
1280 | 1293 | return_index_out_idx = 2 if self.return_index else 1
|
1281 |
| - out_shapes[return_index_out_idx] = (shape,) |
| 1294 | + |
| 1295 | + if self.axis is not None: |
| 1296 | + shape = (x_shape[axis],) |
| 1297 | + elif NumpyVersion(np.__version__) >= "2.0.0rc1": |
| 1298 | + shape = x_shape |
| 1299 | + else: |
| 1300 | + shape = (prod(x_shape),) |
| 1301 | + |
| 1302 | + out_shapes[return_index_out_idx] = shape |
1282 | 1303 |
|
1283 | 1304 | return out_shapes
|
1284 | 1305 |
|
|
0 commit comments