Skip to content

Commit dbfe92c

Browse files
Fix bug in rewrite_det_diag_to_prod_diag where batch case was incorrectly passing
1 parent b0abe17 commit dbfe92c

File tree

2 files changed

+50
-11
lines changed

2 files changed

+50
-11
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -404,19 +404,44 @@ def _find_diag_from_eye_mul(potential_mul_input):
404404
eye_input = [
405405
mul_input
406406
for mul_input in inputs_to_mul
407-
if mul_input.owner and isinstance(mul_input.owner.op, Eye)
407+
if mul_input.owner
408+
and (
409+
isinstance(mul_input.owner.op, Eye)
410+
or
411+
# This whole condition checks if there is an Eye hiding inside a DimShuffle.
412+
# This arises from batched elementwise multiplication between a tensor and an eye, e.g.:
413+
# tensor(shape=(None, 3, 3) * eye(3). This is still potentially valid for diag rewrites.
414+
(
415+
isinstance(mul_input.owner.op, DimShuffle)
416+
and mul_input.owner.inputs[0].owner is not None
417+
and isinstance(mul_input.owner.inputs[0].owner.op, Eye)
418+
)
419+
)
408420
]
409-
410-
# Check if 1's are being put on the main diagonal only (k = 0)
411-
if eye_input and getattr(eye_input[0].owner.inputs[-1], "data", -1).item() != 0:
421+
if not eye_input:
412422
return None
413423

414-
# If the broadcast pattern of eye_input is not (False, False), we do not get a diagonal matrix and thus, dont need to apply the rewrite
415-
if eye_input and eye_input[0].broadcastable[-2:] != (False, False):
424+
eye_input = eye_input[0]
425+
426+
# If this multiplication came from a batched operation, it will be wrapped in a DimShuffle
427+
if isinstance(eye_input.owner.op, DimShuffle):
428+
inner_eye = eye_input.owner.inputs[0]
429+
if not isinstance(inner_eye.owner.op, Eye):
430+
return None
431+
# Check if 1's are being put on the main diagonal only (k = 0)
432+
# and if the identity matrix is degenerate (column or row matrix)
433+
if getattr(
434+
inner_eye.owner.inputs[-1], "data", -1
435+
).item() != 0 or inner_eye.broadcastable[-2:] != (False, False):
436+
return None
437+
438+
elif getattr(
439+
eye_input.owner.inputs[-1], "data", -1
440+
).item() != 0 or eye_input.broadcastable[-2:] != (False, False):
416441
return None
417442

418443
# Get all non Eye inputs (scalars/matrices/vectors)
419-
non_eye_inputs = list(set(inputs_to_mul) - set(eye_input))
444+
non_eye_inputs = list(set(inputs_to_mul) - {eye_input})
420445
return eye_input, non_eye_inputs
421446

422447

@@ -448,15 +473,22 @@ def rewrite_det_diag_to_prod_diag(fgraph, node):
448473
inputs = node.inputs[0]
449474

450475
# Check for use of pt.diag first
451-
if inputs.owner and isinstance(inputs.owner.op, AllocDiag2):
476+
if (
477+
inputs.owner
478+
and isinstance(inputs.owner.op, AllocDiag2)
479+
and inputs.owner.op.offset == 0
480+
):
452481
diag_input = inputs.owner.inputs[0]
482+
diag_input.dprint()
453483
det_val = diag_input.prod(axis=-1)
454484
return [det_val]
455485

456486
# Check if the input is an elemwise multiply with identity matrix -- this also results in a diagonal matrix
457487
inputs_or_none = _find_diag_from_eye_mul(inputs)
488+
458489
if inputs_or_none is None:
459490
return None
491+
460492
eye_input, non_eye_inputs = inputs_or_none
461493

462494
# Dealing with only one other input

tests/tensor/rewriting/test_linalg.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -396,20 +396,26 @@ def test_local_lift_through_linalg(constructor, f_op, f, g_op, g):
396396

397397
@pytest.mark.parametrize(
398398
"shape",
399-
[(), (7,), (1, 7), (7, 1), (7, 7), (3, 7, 7)],
399+
[(), (7,), (1, 7), (7, 1), (7, 7), pytest.param((3, 7, 7))],
400400
ids=["scalar", "vector", "row_vec", "col_vec", "matrix", "batched_input"],
401401
)
402402
def test_det_diag_from_eye_mul(shape):
403403
# Initializing x based on scalar/vector/matrix
404404
x = pt.tensor("x", shape=shape)
405405
y = pt.eye(7) * x
406+
406407
# Calculating determinant value using pt.linalg.det
407408
z_det = pt.linalg.det(y)
408409

409410
# REWRITE TEST
410-
f_rewritten = function([x], z_det, mode="FAST_RUN")
411+
with pytensor.config.change_flags(optimizer_verbose=True):
412+
f_rewritten = function([x], z_det, mode="FAST_RUN")
411413
nodes = f_rewritten.maker.fgraph.apply_nodes
412-
assert not any(isinstance(node.op, Det) for node in nodes)
414+
415+
assert not any(
416+
isinstance(node.op, Det) or isinstance(getattr(node.op, "core_op", None), Det)
417+
for node in nodes
418+
)
413419

414420
# NUMERIC VALUE TEST
415421
if len(shape) == 0:
@@ -418,6 +424,7 @@ def test_det_diag_from_eye_mul(shape):
418424
x_test = np.random.rand(*shape).astype(config.floatX)
419425
else:
420426
x_test = np.random.rand(*shape).astype(config.floatX)
427+
421428
x_test_matrix = np.eye(7) * x_test
422429
det_val = np.linalg.det(x_test_matrix)
423430
rewritten_val = f_rewritten(x_test)

0 commit comments

Comments
 (0)