|
41 | 41 | from tests.test_rop import break_op
|
42 | 42 |
|
43 | 43 |
|
| 44 | +ATOL = RTOL = 1e-3 if config.floatX == "float32" else 1e-8 |
| 45 | + |
| 46 | + |
44 | 47 | def test_rop_lop():
|
45 | 48 | mx = matrix("mx")
|
46 | 49 | mv = matrix("mv")
|
@@ -568,3 +571,90 @@ def get_pt_function(x, op_name):
|
568 | 571 | op2 = get_pt_function(op1, inv_op_2)
|
569 | 572 | rewritten_out = rewrite_graph(op2)
|
570 | 573 | assert rewritten_out == x
|
| 574 | + |
| 575 | + |
| 576 | +def test_inv_eye_to_eye(): |
| 577 | + x = pt.eye(10) |
| 578 | + x_inv = pt.linalg.inv(x) |
| 579 | + f_rewritten = function([], x_inv, mode="FAST_RUN") |
| 580 | + nodes = f_rewritten.maker.fgraph.apply_nodes |
| 581 | + |
| 582 | + # Rewrite Test |
| 583 | + valid_inverses = (MatrixInverse, MatrixPinv) |
| 584 | + assert not any(isinstance(node.op, valid_inverses) for node in nodes) |
| 585 | + |
| 586 | + # Value Test |
| 587 | + x_test = np.eye(10) |
| 588 | + x_inv_val = np.linalg.inv(x_test) |
| 589 | + rewritten_val = f_rewritten() |
| 590 | + |
| 591 | + assert_allclose( |
| 592 | + x_inv_val, |
| 593 | + rewritten_val, |
| 594 | + atol=1e-3 if config.floatX == "float32" else 1e-8, |
| 595 | + rtol=1e-3 if config.floatX == "float32" else 1e-8, |
| 596 | + ) |
| 597 | + |
| 598 | + |
| 599 | +@pytest.mark.parametrize( |
| 600 | + "shape", |
| 601 | + [(), (7,), (7, 7)], |
| 602 | + ids=["scalar", "vector", "matrix"], |
| 603 | +) |
| 604 | +def test_inv_diag_from_eye_mul(shape): |
| 605 | + # Initializing x based on scalar/vector/matrix |
| 606 | + x = pt.tensor("x", shape=shape) |
| 607 | + x_diag = pt.eye(7) * x |
| 608 | + # Calculating inverse using pt.linalg.inv |
| 609 | + x_inv = pt.linalg.inv(x_diag) |
| 610 | + |
| 611 | + # REWRITE TEST |
| 612 | + f_rewritten = function([x], x_inv, mode="FAST_RUN") |
| 613 | + nodes = f_rewritten.maker.fgraph.apply_nodes |
| 614 | + |
| 615 | + valid_inverses = (MatrixInverse, MatrixPinv) |
| 616 | + assert not any(isinstance(node.op, valid_inverses) for node in nodes) |
| 617 | + |
| 618 | + # NUMERIC VALUE TEST |
| 619 | + if len(shape) == 0: |
| 620 | + x_test = np.array(np.random.rand()).astype(config.floatX) |
| 621 | + elif len(shape) == 1: |
| 622 | + x_test = np.random.rand(*shape).astype(config.floatX) |
| 623 | + else: |
| 624 | + x_test = np.random.rand(*shape).astype(config.floatX) |
| 625 | + x_test_matrix = np.eye(7) * x_test |
| 626 | + inverse_matrix = np.linalg.inv(x_test_matrix) |
| 627 | + rewritten_inverse = f_rewritten(x_test) |
| 628 | + |
| 629 | + assert_allclose( |
| 630 | + inverse_matrix, |
| 631 | + rewritten_inverse, |
| 632 | + atol=ATOL, |
| 633 | + rtol=RTOL, |
| 634 | + ) |
| 635 | + |
| 636 | + |
| 637 | +def test_inv_diag_from_diag(): |
| 638 | + x = pt.dvector("x") |
| 639 | + x_diag = pt.diag(x) |
| 640 | + x_inv = pt.linalg.inv(x_diag) |
| 641 | + |
| 642 | + # REWRITE TEST |
| 643 | + f_rewritten = function([x], x_inv, mode="FAST_RUN") |
| 644 | + nodes = f_rewritten.maker.fgraph.apply_nodes |
| 645 | + |
| 646 | + valid_inverses = (MatrixInverse, MatrixPinv) |
| 647 | + assert not any(isinstance(node.op, valid_inverses) for node in nodes) |
| 648 | + |
| 649 | + # NUMERIC VALUE TEST |
| 650 | + x_test = np.random.rand(10) |
| 651 | + x_test_matrix = np.eye(10) * x_test |
| 652 | + inverse_matrix = np.linalg.inv(x_test_matrix) |
| 653 | + rewritten_inverse = f_rewritten(x_test) |
| 654 | + |
| 655 | + assert_allclose( |
| 656 | + inverse_matrix, |
| 657 | + rewritten_inverse, |
| 658 | + atol=ATOL, |
| 659 | + rtol=RTOL, |
| 660 | + ) |
0 commit comments