Skip to content

Commit a5355b6

Browse files
Remove depreciated AllocDiag Op, rename AllocDiag2 -> AllocDiag
1 parent e810df0 commit a5355b6

File tree

3 files changed

+7
-112
lines changed

3 files changed

+7
-112
lines changed

pytensor/link/jax/dispatch/tensor_basic.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from pytensor.tensor import get_vector_length
1111
from pytensor.tensor.basic import (
1212
Alloc,
13-
AllocDiag2,
13+
AllocDiag,
1414
AllocEmpty,
1515
ARange,
1616
ExtractDiag,
@@ -212,10 +212,10 @@ def tri(*args):
212212

213213

214214
@register_specialize
215-
@node_rewriter([AllocDiag2])
215+
@node_rewriter([AllocDiag])
216216
def eagerly_inline_alloc_diag(fgraph, node):
217217
"""
218-
Inline `AllocDiag2` OpFromGraph into the graph so the component Ops can themselves be jaxified
218+
Inline `AllocDiag` OpFromGraph into the graph so the component Ops can themselves be jaxified
219219
220220
Parameters
221221
----------

pytensor/tensor/basic.py

Lines changed: 2 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -3727,112 +3727,7 @@ def trace(a, offset=0, axis1=0, axis2=1):
37273727
return diagonal(a, offset=offset, axis1=axis1, axis2=axis2).sum(-1)
37283728

37293729

3730-
class AllocDiag(Op):
3731-
"""An `Op` that copies a vector to the diagonal of a zero-ed matrix."""
3732-
3733-
__props__ = ("offset", "axis1", "axis2")
3734-
3735-
def __init__(self, offset=0, axis1=0, axis2=1):
3736-
"""
3737-
Parameters
3738-
----------
3739-
offset: int
3740-
Offset of the diagonal from the main diagonal defined by `axis1`
3741-
and `axis2`. Can be positive or negative. Defaults to main
3742-
diagonal (i.e. 0).
3743-
axis1: int
3744-
Axis to be used as the first axis of the 2-D sub-arrays to which
3745-
the diagonals will be allocated. Defaults to first axis (i.e. 0).
3746-
axis2: int
3747-
Axis to be used as the second axis of the 2-D sub-arrays to which
3748-
the diagonals will be allocated. Defaults to second axis (i.e. 1).
3749-
"""
3750-
warnings.warn(
3751-
"AllocDiag is deprecated. Use `alloc_diag` instead",
3752-
FutureWarning,
3753-
)
3754-
self.offset = offset
3755-
if axis1 < 0 or axis2 < 0:
3756-
raise NotImplementedError("AllocDiag does not support negative axis")
3757-
if axis1 == axis2:
3758-
raise ValueError("axis1 and axis2 cannot be the same")
3759-
self.axis1 = axis1
3760-
self.axis2 = axis2
3761-
3762-
def make_node(self, diag):
3763-
diag = as_tensor_variable(diag)
3764-
if diag.type.ndim < 1:
3765-
raise ValueError(
3766-
"AllocDiag needs an input with 1 or more dimensions", diag.type
3767-
)
3768-
return Apply(
3769-
self,
3770-
[diag],
3771-
[diag.type.clone(shape=(None,) * (diag.ndim + 1))()],
3772-
)
3773-
3774-
def perform(self, node, inputs, outputs):
3775-
(x,) = inputs
3776-
(z,) = outputs
3777-
3778-
axis1 = np.minimum(self.axis1, self.axis2)
3779-
axis2 = np.maximum(self.axis1, self.axis2)
3780-
offset = self.offset
3781-
3782-
# Create array with one extra dimension for resulting matrix
3783-
result_shape = x.shape[:-1] + (x.shape[-1] + abs(offset),) * 2
3784-
result = np.zeros(result_shape, dtype=x.dtype)
3785-
3786-
# Create slice for diagonal in final 2 axes
3787-
idxs = np.arange(x.shape[-1])
3788-
diagonal_slice = (len(result_shape) - 2) * [slice(None)] + [
3789-
idxs + np.maximum(0, -offset),
3790-
idxs + np.maximum(0, offset),
3791-
]
3792-
3793-
# Fill in final 2 axes with x
3794-
result[tuple(diagonal_slice)] = x
3795-
3796-
if len(x.shape) > 1:
3797-
# Re-order axes so they correspond to diagonals at axis1, axis2
3798-
axes = list(range(len(x.shape[:-1])))
3799-
last_idx = axes[-1]
3800-
axes = axes[:axis1] + [last_idx + 1] + axes[axis1:]
3801-
axes = axes[:axis2] + [last_idx + 2] + axes[axis2:]
3802-
result = result.transpose(axes)
3803-
3804-
z[0] = result
3805-
3806-
def grad(self, inputs, gout):
3807-
(gz,) = gout
3808-
return [diagonal(gz, offset=self.offset, axis1=self.axis1, axis2=self.axis2)]
3809-
3810-
def infer_shape(self, fgraph, nodes, shapes):
3811-
(x_shape,) = shapes
3812-
axis1 = np.minimum(self.axis1, self.axis2)
3813-
axis2 = np.maximum(self.axis1, self.axis2)
3814-
3815-
result_shape = list(x_shape[:-1])
3816-
diag_shape = x_shape[-1] + abs(self.offset)
3817-
result_shape = result_shape[:axis1] + [diag_shape] + result_shape[axis1:]
3818-
result_shape = result_shape[:axis2] + [diag_shape] + result_shape[axis2:]
3819-
return [tuple(result_shape)]
3820-
3821-
def __setstate__(self, state):
3822-
if "view_map" in state:
3823-
del state["view_map"]
3824-
3825-
self.__dict__.update(state)
3826-
3827-
if "offset" not in state:
3828-
self.offset = 0
3829-
if "axis1" not in state:
3830-
self.axis1 = 0
3831-
if "axis2" not in state:
3832-
self.axis2 = 1
3833-
3834-
3835-
class AllocDiag2(OpFromGraph):
3730+
class AllocDiag(OpFromGraph):
38363731
"""
38373732
Wrapper Op for alloc_diag graphs
38383733
"""
@@ -3883,7 +3778,7 @@ def alloc_diag(diag, offset=0, axis1=0, axis2=1):
38833778
axes = axes[:axis2] + [last_idx + 2] + axes[axis2:]
38843779
result = result.transpose(axes)
38853780

3886-
return AllocDiag2(
3781+
return AllocDiag(
38873782
inputs=[diag], outputs=[result], offset=offset, axis1=axis1, axis2=axis2
38883783
)(diag)
38893784

pytensor/tensor/rewriting/linalg.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
)
1111
from pytensor.scalar.basic import Mul
1212
from pytensor.tensor.basic import (
13-
AllocDiag2,
13+
AllocDiag,
1414
Eye,
1515
TensorVariable,
1616
diagonal,
@@ -475,7 +475,7 @@ def rewrite_det_diag_to_prod_diag(fgraph, node):
475475
# Check for use of pt.diag first
476476
if (
477477
inputs.owner
478-
and isinstance(inputs.owner.op, AllocDiag2)
478+
and isinstance(inputs.owner.op, AllocDiag)
479479
and inputs.owner.op.offset == 0
480480
):
481481
diag_input = inputs.owner.inputs[0]

0 commit comments

Comments
 (0)