Skip to content

Rewrite determinant of diagonal matrix as product of diagonal #797

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 34 commits into from
Jul 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
82b7e98
Added det-diag rewrite
tanish1729 Jun 2, 2024
e661ba4
fixed pt.diagonal error
tanish1729 Jun 2, 2024
a1823e9
Added test for rewrite
tanish1729 Jun 5, 2024
ac21b9e
Added test for rewrite
tanish1729 Jun 5, 2024
01e47d7
fixed test
tanish1729 Jun 5, 2024
89f209f
added check for verifying rewrite
tanish1729 Jun 6, 2024
0bd2e26
fixed other failing test
tanish1729 Jun 6, 2024
170a860
added docstring
tanish1729 Jun 7, 2024
0191b97
updated docstring
tanish1729 Jun 7, 2024
7f7c803
fixed mypy error
tanish1729 Jun 7, 2024
8d7f9f7
added det_diag_from_diag and test
tanish1729 Jun 9, 2024
9ad0606
fixed node rewriter name
tanish1729 Jun 9, 2024
c897296
added row/col tests
tanish1729 Jun 11, 2024
6c3cdae
updated check for eye
tanish1729 Jun 11, 2024
6b58cfb
updated rewrite and tests
tanish1729 Jun 15, 2024
4316e75
added check for eye_input and new test for cases where not to apply r…
tanish1729 Jun 21, 2024
20c8505
Merge branch 'pymc-devs:main' into det-diag-rewrite
tanish1729 Jun 22, 2024
6743e0a
does not apply rewrite to specific cases
tanish1729 Jun 22, 2024
ebca339
typecasted test variable
tanish1729 Jun 22, 2024
58212b6
typecast variables
tanish1729 Jun 23, 2024
3ef186d
removed shape known check; fails for rectangle eye
tanish1729 Jun 24, 2024
2c96faf
added new tests for (1,1) eye and rectangle eye
tanish1729 Jun 24, 2024
9a48a72
added helper function for diag from eye_mul
tanish1729 Jun 25, 2024
ffc0f81
updated case for no rewrite which was failing tests
tanish1729 Jun 25, 2024
9f661e3
cleaned code; updated rectangle_eye test which is an invalid rewrite
tanish1729 Jun 26, 2024
b02818d
add check for k in pt.eye
tanish1729 Jun 26, 2024
0dbc28e
Update pytensor/tensor/rewriting/linalg.py
tanish1729 Jun 26, 2024
d4ffb7f
typecasted det_val
tanish1729 Jun 26, 2024
060863a
Merge branch 'det-diag-rewrite' of https://github.com/tanish1729/pyte…
tanish1729 Jun 26, 2024
557fea1
fixed final typecasting
tanish1729 Jun 26, 2024
12ffb8f
Merge branch 'main' into det-diag-rewrite
tanish1729 Jun 30, 2024
27a9864
fixed merge
tanish1729 Jul 1, 2024
6831069
fixed failing rectangle eye test
tanish1729 Jul 1, 2024
9811b88
fixed typo
tanish1729 Jul 3, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 103 additions & 2 deletions pytensor/tensor/rewriting/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@
from pytensor import Variable
from pytensor.graph import Apply, FunctionGraph
from pytensor.graph.rewriting.basic import (
PatternNodeRewriter,
copy_stack_trace,
node_rewriter,
)
from pytensor.tensor.basic import TensorVariable, diagonal
from pytensor.scalar.basic import Mul
from pytensor.tensor.basic import ARange, Eye, TensorVariable, alloc, diagonal
from pytensor.tensor.blas import Dot22
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, prod
from pytensor.tensor.nlinalg import (
SVD,
Expand Down Expand Up @@ -39,6 +41,7 @@
solve,
solve_triangular,
)
from pytensor.tensor.subtensor import advanced_set_subtensor


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -384,6 +387,104 @@
raise NotImplementedError # pragma: no cover


def _find_diag_from_eye_mul(potential_mul_input):
# Check if the op is Elemwise and mul
if not (
potential_mul_input.owner is not None
and isinstance(potential_mul_input.owner.op, Elemwise)
and isinstance(potential_mul_input.owner.op.scalar_op, Mul)
):
return None

# Find whether any of the inputs to mul is Eye
inputs_to_mul = potential_mul_input.owner.inputs
eye_input = [
mul_input
for mul_input in inputs_to_mul
if mul_input.owner and isinstance(mul_input.owner.op, Eye)
]

# Check if 1's are being put on the main diagonal only (k = 0)
if eye_input and getattr(eye_input[0].owner.inputs[-1], "data", -1).item() != 0:
return None

Check warning on line 409 in pytensor/tensor/rewriting/linalg.py

View check run for this annotation

Codecov / codecov/patch

pytensor/tensor/rewriting/linalg.py#L409

Added line #L409 was not covered by tests

# If the broadcast pattern of eye_input is not (False, False), we do not get a diagonal matrix and thus, dont need to apply the rewrite
if eye_input and eye_input[0].broadcastable[-2:] != (False, False):
return None

# Get all non Eye inputs (scalars/matrices/vectors)
non_eye_inputs = list(set(inputs_to_mul) - set(eye_input))
return eye_input, non_eye_inputs


@register_canonicalize("shape_unsafe")
@register_stabilize("shape_unsafe")
@node_rewriter([det])
def rewrite_det_diag_from_eye_mul(fgraph, node):
    """Rewrite the determinant of ``eye * x`` as a product of diagonal entries.

    This rewrite takes advantage of the fact that for a diagonal matrix, the
    determinant value is the product of its diagonal elements.

    The presence of a diagonal matrix is detected by inspecting the graph. This
    rewrite can identify diagonal matrices that arise as the result of
    elementwise multiplication with an identity matrix. Specialized computation
    is used to make this rewrite as efficient as possible, depending on whether
    the multiplication was with a scalar, vector or a matrix.

    Parameters
    ----------
    fgraph: FunctionGraph
        Function graph being optimized
    node: Apply
        Node of the function graph to be optimized

    Returns
    -------
    list of Variable, optional
        List of optimized variables, or None if no optimization was performed
    """
    eye_non_eye_inputs = _find_diag_from_eye_mul(node.inputs[0])
    if eye_non_eye_inputs is None:
        return None
    eye_input, non_eye_inputs = eye_non_eye_inputs

    # The input must actually contain an identity matrix; indexing an empty
    # list below would otherwise raise instead of skipping the rewrite.
    if not eye_input:
        return None

    # Dealing with only one other input
    if len(non_eye_inputs) != 1:
        return None

    useful_eye, useful_non_eye = eye_input[0], non_eye_inputs[0]

    # Checking if original x was scalar/vector/matrix, based on the static
    # broadcast pattern of its last two dimensions
    core_broadcastable = useful_non_eye.type.broadcastable[-2:]
    if core_broadcastable == (True, True):
        # For scalar: every diagonal entry equals x, so det = x ** n
        det_val = useful_non_eye.squeeze(axis=(-1, -2)) ** (useful_eye.shape[0])
    elif core_broadcastable == (False, False):
        # For matrix: only the diagonal entries survive the multiplication
        det_val = useful_non_eye.diagonal(axis1=-1, axis2=-2).prod(axis=-1)
    else:
        # For vector (row or column): its entries are the diagonal
        det_val = useful_non_eye.prod(axis=(-1, -2))

    # Keep the dtype of the original det output
    det_val = det_val.astype(node.outputs[0].type.dtype)
    return [det_val]


arange = ARange("int64")

# Input pattern: det of a zero-filled matrix in which "x" has been written at
# positions (i, i) for i in arange(0, stop, 1), i.e. onto the main diagonal.
# The test for this rewrite builds such a graph with `pt.diag(x)`.
_det_of_set_diagonal_pattern = (
    det,
    (
        advanced_set_subtensor,
        (alloc, 0, "sh1", "sh2"),
        "x",
        (arange, 0, "stop", 1),
        (arange, 0, "stop", 1),
    ),
)
# Output pattern: the product of the values that were placed on the diagonal
_prod_of_diagonal_pattern = (prod, "x")

det_diag_from_diag = PatternNodeRewriter(
    _det_of_set_diagonal_pattern,
    _prod_of_diagonal_pattern,
    name="det_diag_from_diag",
    allow_multiple_clients=True,
)

# Register the rewrite in the same three passes, in the same order
for _register in (register_canonicalize, register_stabilize, register_specialize):
    _register(det_diag_from_diag)


@register_canonicalize
@register_stabilize
@register_specialize
Expand Down
89 changes: 89 additions & 0 deletions tests/tensor/rewriting/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,95 @@ def test_local_lift_through_linalg(constructor, f_op, f, g_op, g):
np.testing.assert_allclose(f1(*test_vals), f2(*test_vals), atol=1e-8)


@pytest.mark.parametrize(
    "shape",
    [(), (7,), (1, 7), (7, 1), (7, 7), (3, 7, 7)],
    ids=["scalar", "vector", "row_vec", "col_vec", "matrix", "batched_input"],
)
def test_det_diag_from_eye_mul(shape):
    """Check that ``det(eye(7) * x)`` compiles without a ``Det`` op and stays numerically correct."""
    # Symbolic graph: x may be a scalar, vector, matrix or batch of matrices
    x = pt.tensor("x", shape=shape)
    z_det = pt.linalg.det(pt.eye(7) * x)

    # Rewrite check: no Det op may remain in the compiled graph
    f_rewritten = function([x], z_det, mode="FAST_RUN")
    compiled_nodes = f_rewritten.maker.fgraph.apply_nodes
    assert not any(isinstance(apply_node.op, Det) for apply_node in compiled_nodes)

    # Numeric check against NumPy. `np.random.rand()` without arguments returns
    # a plain float, so the scalar case is wrapped in an array first.
    if shape:
        raw_values = np.random.rand(*shape)
    else:
        raw_values = np.array(np.random.rand())
    x_test = raw_values.astype(config.floatX)

    det_val = np.linalg.det(np.eye(7) * x_test)
    rewritten_val = f_rewritten(x_test)

    tol = 1e-3 if config.floatX == "float32" else 1e-8
    assert_allclose(det_val, rewritten_val, atol=tol, rtol=tol)


def test_det_diag_from_diag():
    """Check that ``det(diag(x))`` compiles without a ``Det`` op and matches NumPy."""
    x = pt.tensor("x", shape=(None,))
    y = pt.linalg.det(pt.diag(x))

    # Rewrite check: no Det op may remain in the compiled graph
    f_rewritten = function([x], y, mode="FAST_RUN")
    compiled_nodes = f_rewritten.maker.fgraph.apply_nodes
    assert not any(isinstance(apply_node.op, Det) for apply_node in compiled_nodes)

    # Numeric check against the determinant of the explicit diagonal matrix
    x_test = np.random.rand(7).astype(config.floatX)
    det_val = np.linalg.det(np.eye(7) * x_test)
    rewritten_val = f_rewritten(x_test)

    tol = 1e-3 if config.floatX == "float32" else 1e-8
    assert_allclose(det_val, rewritten_val, atol=tol, rtol=tol)


def test_dont_apply_det_diag_rewrite_for_1_1():
    """Check that ``det(eye(1, 1) * x)`` keeps its ``Det`` op and still matches NumPy."""
    x = pt.matrix("x")
    y = pt.linalg.det(pt.eye(1, 1) * x)
    f_rewritten = function([x], y, mode="FAST_RUN")

    # The rewrite must NOT fire here: a Det op is expected in the compiled graph
    compiled_nodes = f_rewritten.maker.fgraph.apply_nodes
    assert any(isinstance(apply_node.op, Det) for apply_node in compiled_nodes)

    # Numeric check: the (1, 1) eye broadcasts against the (3, 3) input
    x_test = np.random.normal(size=(3, 3)).astype(config.floatX)
    det_val = np.linalg.det(np.eye(1, 1) * x_test)
    rewritten_val = f_rewritten(x_test)

    tol = 1e-3 if config.floatX == "float32" else 1e-8
    assert_allclose(det_val, rewritten_val, atol=tol, rtol=tol)


def test_det_diag_incorrect_for_rectangle_eye():
    """Check that building ``det`` of a product with a rectangular eye raises ``ValueError``."""
    x = pt.matrix("x")
    rectangular_product = pt.eye(7, 5) * x
    # The error is expected while the graph is built, before any compilation
    with pytest.raises(ValueError, match="Determinant not defined"):
        pt.linalg.det(rectangular_product)


def test_svd_uv_merge():
a = matrix("a")
s_1 = svd(a, full_matrices=False, compute_uv=False)
Expand Down
Loading