Skip to content

Commit 625e98c

Browse files
committed
updated tests
1 parent b65d08c commit 625e98c

File tree

1 file changed

+90
-0
lines changed

1 file changed

+90
-0
lines changed

tests/tensor/rewriting/test_linalg.py

+90
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@
4141
from tests.test_rop import break_op
4242

4343

44+
ATOL = RTOL = 1e-3 if config.floatX == "float32" else 1e-8
45+
46+
4447
def test_rop_lop():
4548
mx = matrix("mx")
4649
mv = matrix("mv")
@@ -568,3 +571,90 @@ def get_pt_function(x, op_name):
568571
op2 = get_pt_function(op1, inv_op_2)
569572
rewritten_out = rewrite_graph(op2)
570573
assert rewritten_out == x
574+
575+
576+
def test_inv_eye_to_eye():
577+
x = pt.eye(10)
578+
x_inv = pt.linalg.inv(x)
579+
f_rewritten = function([], x_inv, mode="FAST_RUN")
580+
nodes = f_rewritten.maker.fgraph.apply_nodes
581+
582+
# Rewrite Test
583+
valid_inverses = (MatrixInverse, MatrixPinv)
584+
assert not any(isinstance(node.op, valid_inverses) for node in nodes)
585+
586+
# Value Test
587+
x_test = np.eye(10)
588+
x_inv_val = np.linalg.inv(x_test)
589+
rewritten_val = f_rewritten()
590+
591+
assert_allclose(
592+
x_inv_val,
593+
rewritten_val,
594+
atol=1e-3 if config.floatX == "float32" else 1e-8,
595+
rtol=1e-3 if config.floatX == "float32" else 1e-8,
596+
)
597+
598+
599+
@pytest.mark.parametrize(
600+
"shape",
601+
[(), (7,), (7, 7)],
602+
ids=["scalar", "vector", "matrix"],
603+
)
604+
def test_inv_diag_from_eye_mul(shape):
605+
# Initializing x based on scalar/vector/matrix
606+
x = pt.tensor("x", shape=shape)
607+
x_diag = pt.eye(7) * x
608+
# Calculating inverse using pt.linalg.inv
609+
x_inv = pt.linalg.inv(x_diag)
610+
611+
# REWRITE TEST
612+
f_rewritten = function([x], x_inv, mode="FAST_RUN")
613+
nodes = f_rewritten.maker.fgraph.apply_nodes
614+
615+
valid_inverses = (MatrixInverse, MatrixPinv)
616+
assert not any(isinstance(node.op, valid_inverses) for node in nodes)
617+
618+
# NUMERIC VALUE TEST
619+
if len(shape) == 0:
620+
x_test = np.array(np.random.rand()).astype(config.floatX)
621+
elif len(shape) == 1:
622+
x_test = np.random.rand(*shape).astype(config.floatX)
623+
else:
624+
x_test = np.random.rand(*shape).astype(config.floatX)
625+
x_test_matrix = np.eye(7) * x_test
626+
inverse_matrix = np.linalg.inv(x_test_matrix)
627+
rewritten_inverse = f_rewritten(x_test)
628+
629+
assert_allclose(
630+
inverse_matrix,
631+
rewritten_inverse,
632+
atol=ATOL,
633+
rtol=RTOL,
634+
)
635+
636+
637+
def test_inv_diag_from_diag():
638+
x = pt.dvector("x")
639+
x_diag = pt.diag(x)
640+
x_inv = pt.linalg.inv(x_diag)
641+
642+
# REWRITE TEST
643+
f_rewritten = function([x], x_inv, mode="FAST_RUN")
644+
nodes = f_rewritten.maker.fgraph.apply_nodes
645+
646+
valid_inverses = (MatrixInverse, MatrixPinv)
647+
assert not any(isinstance(node.op, valid_inverses) for node in nodes)
648+
649+
# NUMERIC VALUE TEST
650+
x_test = np.random.rand(10)
651+
x_test_matrix = np.eye(10) * x_test
652+
inverse_matrix = np.linalg.inv(x_test_matrix)
653+
rewritten_inverse = f_rewritten(x_test)
654+
655+
assert_allclose(
656+
inverse_matrix,
657+
rewritten_inverse,
658+
atol=ATOL,
659+
rtol=RTOL,
660+
)

0 commit comments

Comments
 (0)