41
41
)
42
42
from pytensor .tensor .math import max as pt_max
43
43
from pytensor .tensor .math import sum as pt_sum
44
- from pytensor .tensor .shape import specify_broadcastable
44
+ from pytensor .tensor .shape import Shape_i , specify_broadcastable
45
45
from pytensor .tensor .subtensor import advanced_inc_subtensor1 , set_subtensor
46
46
from pytensor .tensor .type import TensorType , dvector , int_dtypes , integer_dtypes , vector
47
47
from pytensor .tensor .variable import TensorVariable
@@ -1194,23 +1194,22 @@ def __init__(
1194
1194
self .return_index = return_index
1195
1195
self .return_inverse = return_inverse
1196
1196
self .return_counts = return_counts
1197
+ if axis is not None and axis < 0 :
1198
+ raise ValueError ("Axis cannot be negative." )
1197
1199
self .axis = axis
1198
1200
1199
1201
def make_node (self , x ):
1200
1202
x = ptb .as_tensor_variable (x )
1201
- self_axis = self .axis
1202
- if self_axis is None :
1203
+ axis = self .axis
1204
+ if axis is None :
1203
1205
out_shape = (None ,)
1204
1206
else :
1205
- if self_axis < 0 :
1206
- self_axis += x .type .ndim
1207
- if self_axis < 0 or self_axis >= x .type .ndim :
1207
+ if axis >= x .type .ndim :
1208
1208
raise ValueError (
1209
- f"Unique axis { self . axis } is outside of input ndim = { x .type .ndim } "
1209
+ f"Axis { axis } out of range for input { x } with ndim= { x .type .ndim } . "
1210
1210
)
1211
1211
out_shape = tuple (
1212
- s if s == 1 and axis != self_axis else None
1213
- for axis , s in enumerate (x .type .shape )
1212
+ None if dim == axis else s for dim , s in enumerate (x .type .shape )
1214
1213
)
1215
1214
1216
1215
outputs = [TensorType (dtype = x .dtype , shape = out_shape )()]
@@ -1224,60 +1223,37 @@ def make_node(self, x):
1224
1223
return Apply (self , [x ], outputs )
1225
1224
1226
1225
def perform (self , node , inputs , output_storage ):
1227
- x = inputs [0 ]
1228
- z = output_storage
1229
- param = {}
1230
- if self .return_index :
1231
- param ["return_index" ] = True
1232
- if self .return_inverse :
1233
- param ["return_inverse" ] = True
1234
- if self .return_counts :
1235
- param ["return_counts" ] = True
1236
- if self .axis is not None :
1237
- param ["axis" ] = self .axis
1238
- outs = np .unique (x , ** param )
1239
- if (
1240
- (not self .return_inverse )
1241
- and (not self .return_index )
1242
- and (not self .return_counts )
1243
- ):
1244
- z [0 ][0 ] = outs
1245
- else :
1226
+ [x ] = inputs
1227
+ outs = np .unique (
1228
+ x ,
1229
+ return_index = self .return_index ,
1230
+ return_inverse = self .return_inverse ,
1231
+ return_counts = self .return_counts ,
1232
+ axis = self .axis ,
1233
+ )
1234
+ if isinstance (outs , tuple ):
1246
1235
for i in range (len (outs )):
1247
- z [i ][0 ] = outs [i ]
1236
+ output_storage [i ][0 ] = outs [i ]
1237
+ else :
1238
+ output_storage [0 ][0 ] = outs
1248
1239
1249
1240
def infer_shape (self , fgraph , node , i0_shapes ):
1250
- ret = fgraph .shape_feature .default_infer_shape (fgraph , node , i0_shapes )
1251
- if self .axis is not None :
1252
- self_axis = self .axis
1253
- ndim = len (i0_shapes [0 ])
1254
- if self_axis < 0 :
1255
- self_axis += ndim
1256
- if self_axis < 0 or self_axis >= ndim :
1257
- raise RuntimeError (
1258
- f"Unique axis `{ self .axis } ` is outside of input ndim = { ndim } ."
1259
- )
1260
- ret [0 ] = tuple (
1261
- fgraph .shape_feature .shape_ir (i , node .outputs [0 ]) for i in range (ndim )
1262
- )
1241
+ [x_shape ] = i0_shapes
1242
+ shape0_op = Shape_i (0 )
1243
+ out_shapes = [(shape0_op (out ),) for out in node .outputs ]
1244
+
1245
+ axis = self .axis
1246
+ if axis is not None :
1247
+ shape = list (x_shape )
1248
+ shape [axis ] = Shape_i (axis )(node .outputs [0 ])
1249
+ out_shapes [0 ] = tuple (shape )
1250
+
1263
1251
if self .return_inverse :
1264
- if self .axis is None :
1265
- shape = (prod (i0_shapes [0 ]),)
1266
- else :
1267
- shape = (i0_shapes [0 ][self_axis ],)
1268
- if self .return_index :
1269
- ret [2 ] = shape
1270
- return ret
1271
- ret [1 ] = shape
1272
- return ret
1273
- return ret
1274
-
1275
- def __setstate__ (self , state ):
1276
- self .__dict__ .update (state )
1277
- # For backwards compatibility with pickled instances of Unique that
1278
- # did not have the axis parameter specified
1279
- if "axis" not in state :
1280
- self .axis = None
1252
+ shape = prod (x_shape ) if self .axis is None else x_shape [axis ]
1253
+ return_index_out_idx = 2 if self .return_index else 1
1254
+ out_shapes [return_index_out_idx ] = (shape ,)
1255
+
1256
+ return out_shapes
1281
1257
1282
1258
1283
1259
def unique (
@@ -1293,6 +1269,9 @@ def unique(
1293
1269
* the number of times each unique value comes up in the input array
1294
1270
1295
1271
"""
1272
+ ar = as_tensor_variable (ar )
1273
+ if axis is not None :
1274
+ axis = normalize_axis_index (axis , ar .ndim )
1296
1275
return Unique (return_index , return_inverse , return_counts , axis )(ar )
1297
1276
1298
1277
0 commit comments