|
3 | 3 |
|
4 | 4 | import numpy as np
|
5 | 5 | from numpy.exceptions import AxisError
|
| 6 | +from numpy.lib import NumpyVersion |
6 | 7 |
|
7 | 8 |
|
8 | 9 | try:
|
@@ -1240,12 +1241,25 @@ def make_node(self, x):
|
1240 | 1241 |
|
1241 | 1242 | outputs = [TensorType(dtype=x.dtype, shape=out_shape)()]
|
1242 | 1243 | typ = TensorType(dtype="int64", shape=(None,))
|
| 1244 | + |
1243 | 1245 | if self.return_index:
|
1244 | 1246 | outputs.append(typ())
|
| 1247 | + |
1245 | 1248 | if self.return_inverse:
|
1246 |
| - outputs.append(typ()) |
| 1249 | + if NumpyVersion(np.__version__) >= "2.0.0rc1": |
| 1250 | + if axis is None: |
| 1251 | + inverse_shape = TensorType(dtype="int64", shape=x.type.shape) |
| 1252 | + else: |
| 1253 | + inverse_shape = TensorType( |
| 1254 | + dtype="int64", shape=(x.type.shape[axis],) |
| 1255 | + ) |
| 1256 | + outputs.append(inverse_shape()) |
| 1257 | + else: |
| 1258 | + outputs.append(typ()) |
| 1259 | + |
1247 | 1260 | if self.return_counts:
|
1248 | 1261 | outputs.append(typ())
|
| 1262 | + |
1249 | 1263 | return Apply(self, [x], outputs)
|
1250 | 1264 |
|
1251 | 1265 | def perform(self, node, inputs, output_storage):
|
@@ -1275,9 +1289,16 @@ def infer_shape(self, fgraph, node, i0_shapes):
|
1275 | 1289 | out_shapes[0] = tuple(shape)
|
1276 | 1290 |
|
1277 | 1291 | if self.return_inverse:
|
1278 |
| - shape = prod(x_shape) if self.axis is None else x_shape[axis] |
1279 | 1292 | return_index_out_idx = 2 if self.return_index else 1
|
1280 |
| - out_shapes[return_index_out_idx] = (shape,) |
| 1293 | + |
| 1294 | + if self.axis is not None: |
| 1295 | + shape = (x_shape[axis],) |
| 1296 | + elif NumpyVersion(np.__version__) >= "2.0.0rc1": |
| 1297 | + shape = x_shape |
| 1298 | + else: |
| 1299 | + shape = (prod(x_shape),) |
| 1300 | + |
| 1301 | + out_shapes[return_index_out_idx] = shape |
1281 | 1302 |
|
1282 | 1303 | return out_shapes
|
1283 | 1304 |
|
|
0 commit comments