Skip to content

Commit d9b3924

Browse files
committed
Simplify Unique Op
1 parent fa0ab9d commit d9b3924

File tree

1 file changed

+38
-59
lines changed

1 file changed

+38
-59
lines changed

pytensor/tensor/extra_ops.py

+38-59
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
)
4242
from pytensor.tensor.math import max as pt_max
4343
from pytensor.tensor.math import sum as pt_sum
44-
from pytensor.tensor.shape import specify_broadcastable
44+
from pytensor.tensor.shape import Shape_i, specify_broadcastable
4545
from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor
4646
from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes, vector
4747
from pytensor.tensor.variable import TensorVariable
@@ -1194,23 +1194,22 @@ def __init__(
11941194
self.return_index = return_index
11951195
self.return_inverse = return_inverse
11961196
self.return_counts = return_counts
1197+
if axis is not None and axis < 0:
1198+
raise ValueError("Axis cannot be negative.")
11971199
self.axis = axis
11981200

11991201
def make_node(self, x):
12001202
x = ptb.as_tensor_variable(x)
1201-
self_axis = self.axis
1202-
if self_axis is None:
1203+
axis = self.axis
1204+
if axis is None:
12031205
out_shape = (None,)
12041206
else:
1205-
if self_axis < 0:
1206-
self_axis += x.type.ndim
1207-
if self_axis < 0 or self_axis >= x.type.ndim:
1207+
if axis >= x.type.ndim:
12081208
raise ValueError(
1209-
f"Unique axis {self.axis} is outside of input ndim = {x.type.ndim}"
1209+
f"Axis {axis} out of range for input {x} with ndim={x.type.ndim}."
12101210
)
12111211
out_shape = tuple(
1212-
s if s == 1 and axis != self_axis else None
1213-
for axis, s in enumerate(x.type.shape)
1212+
None if dim == axis else s for dim, s in enumerate(x.type.shape)
12141213
)
12151214

12161215
outputs = [TensorType(dtype=x.dtype, shape=out_shape)()]
@@ -1224,60 +1223,37 @@ def make_node(self, x):
12241223
return Apply(self, [x], outputs)
12251224

12261225
def perform(self, node, inputs, output_storage):
1227-
x = inputs[0]
1228-
z = output_storage
1229-
param = {}
1230-
if self.return_index:
1231-
param["return_index"] = True
1232-
if self.return_inverse:
1233-
param["return_inverse"] = True
1234-
if self.return_counts:
1235-
param["return_counts"] = True
1236-
if self.axis is not None:
1237-
param["axis"] = self.axis
1238-
outs = np.unique(x, **param)
1239-
if (
1240-
(not self.return_inverse)
1241-
and (not self.return_index)
1242-
and (not self.return_counts)
1243-
):
1244-
z[0][0] = outs
1245-
else:
1226+
[x] = inputs
1227+
outs = np.unique(
1228+
x,
1229+
return_index=self.return_index,
1230+
return_inverse=self.return_inverse,
1231+
return_counts=self.return_counts,
1232+
axis=self.axis,
1233+
)
1234+
if isinstance(outs, tuple):
12461235
for i in range(len(outs)):
1247-
z[i][0] = outs[i]
1236+
output_storage[i][0] = outs[i]
1237+
else:
1238+
output_storage[0][0] = outs
12481239

12491240
def infer_shape(self, fgraph, node, i0_shapes):
1250-
ret = fgraph.shape_feature.default_infer_shape(fgraph, node, i0_shapes)
1251-
if self.axis is not None:
1252-
self_axis = self.axis
1253-
ndim = len(i0_shapes[0])
1254-
if self_axis < 0:
1255-
self_axis += ndim
1256-
if self_axis < 0 or self_axis >= ndim:
1257-
raise RuntimeError(
1258-
f"Unique axis `{self.axis}` is outside of input ndim = {ndim}."
1259-
)
1260-
ret[0] = tuple(
1261-
fgraph.shape_feature.shape_ir(i, node.outputs[0]) for i in range(ndim)
1262-
)
1241+
[x_shape] = i0_shapes
1242+
shape0_op = Shape_i(0)
1243+
out_shapes = [(shape0_op(out),) for out in node.outputs]
1244+
1245+
axis = self.axis
1246+
if axis is not None:
1247+
shape = list(x_shape)
1248+
shape[axis] = Shape_i(axis)(node.outputs[0])
1249+
out_shapes[0] = tuple(shape)
1250+
12631251
if self.return_inverse:
1264-
if self.axis is None:
1265-
shape = (prod(i0_shapes[0]),)
1266-
else:
1267-
shape = (i0_shapes[0][self_axis],)
1268-
if self.return_index:
1269-
ret[2] = shape
1270-
return ret
1271-
ret[1] = shape
1272-
return ret
1273-
return ret
1274-
1275-
def __setstate__(self, state):
1276-
self.__dict__.update(state)
1277-
# For backwards compatibility with pickled instances of Unique that
1278-
# did not have the axis parameter specified
1279-
if "axis" not in state:
1280-
self.axis = None
1252+
shape = prod(x_shape) if self.axis is None else x_shape[axis]
1253+
return_index_out_idx = 2 if self.return_index else 1
1254+
out_shapes[return_index_out_idx] = (shape,)
1255+
1256+
return out_shapes
12811257

12821258

12831259
def unique(
@@ -1293,6 +1269,9 @@ def unique(
12931269
* the number of times each unique value comes up in the input array
12941270
12951271
"""
1272+
ar = as_tensor_variable(ar)
1273+
if axis is not None:
1274+
axis = normalize_axis_index(axis, ar.ndim)
12961275
return Unique(return_index, return_inverse, return_counts, axis)(ar)
12971276

12981277

0 commit comments

Comments
 (0)