Skip to content

Commit e468381

Browse files
committed
Don't import Test for helper
This triggers the remaining tests in pytest
1 parent 907cd1a commit e468381

File tree

3 files changed

+31
-31
lines changed

3 files changed

+31
-31
lines changed

tests/link/jax/test_tensor_basic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from pytensor.graph.op import get_test_value
1515
from pytensor.tensor.type import iscalar, matrix, scalar, vector
1616
from tests.link.jax.test_basic import compare_jax_and_py
17-
from tests.tensor.test_basic import TestAlloc
17+
from tests.tensor.test_basic import check_alloc_runtime_broadcast
1818

1919

2020
def test_jax_Alloc():
@@ -54,7 +54,7 @@ def compare_shape_dtype(x, y):
5454

5555

5656
def test_alloc_runtime_broadcast():
57-
TestAlloc.check_runtime_broadcast(get_mode("JAX"))
57+
check_alloc_runtime_broadcast(get_mode("JAX"))
5858

5959

6060
def test_jax_MakeVector():

tests/link/numba/test_tensor_basic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
compare_shape_dtype,
1717
set_test_value,
1818
)
19-
from tests.tensor.test_basic import TestAlloc
19+
from tests.tensor.test_basic import check_alloc_runtime_broadcast
2020

2121

2222
pytest.importorskip("numba")
@@ -52,7 +52,7 @@ def test_Alloc(v, shape):
5252

5353

5454
def test_alloc_runtime_broadcast():
55-
TestAlloc.check_runtime_broadcast(get_mode("NUMBA"))
55+
check_alloc_runtime_broadcast(get_mode("NUMBA"))
5656

5757

5858
def test_AllocEmpty():

tests/tensor/test_basic.py

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -716,6 +716,32 @@ def test_masked_array_not_implemented(
716716
ptb.as_tensor(x)
717717

718718

719+
def check_alloc_runtime_broadcast(mode):
720+
"""Check we emmit a clear error when runtime broadcasting would occur according to Numpy rules."""
721+
floatX = config.floatX
722+
x_v = vector("x", shape=(None,))
723+
724+
out = alloc(x_v, 5, 3)
725+
f = pytensor.function([x_v], out, mode=mode)
726+
TestAlloc.check_allocs_in_fgraph(f.maker.fgraph, 1)
727+
728+
np.testing.assert_array_equal(
729+
f(x=np.zeros((3,), dtype=floatX)),
730+
np.zeros((5, 3), dtype=floatX),
731+
)
732+
with pytest.raises(ValueError, match="Runtime broadcasting not allowed"):
733+
f(x=np.zeros((1,), dtype=floatX))
734+
735+
out = alloc(specify_shape(x_v, (1,)), 5, 3)
736+
f = pytensor.function([x_v], out, mode=mode)
737+
TestAlloc.check_allocs_in_fgraph(f.maker.fgraph, 1)
738+
739+
np.testing.assert_array_equal(
740+
f(x=np.zeros((1,), dtype=floatX)),
741+
np.zeros((5, 3), dtype=floatX),
742+
)
743+
744+
719745
class TestAlloc:
720746
dtype = config.floatX
721747
mode = mode_opt
@@ -729,32 +755,6 @@ def check_allocs_in_fgraph(fgraph, n):
729755
== n
730756
)
731757

732-
@staticmethod
733-
def check_runtime_broadcast(mode):
734-
"""Check we emmit a clear error when runtime broadcasting would occur according to Numpy rules."""
735-
floatX = config.floatX
736-
x_v = vector("x", shape=(None,))
737-
738-
out = alloc(x_v, 5, 3)
739-
f = pytensor.function([x_v], out, mode=mode)
740-
TestAlloc.check_allocs_in_fgraph(f.maker.fgraph, 1)
741-
742-
np.testing.assert_array_equal(
743-
f(x=np.zeros((3,), dtype=floatX)),
744-
np.zeros((5, 3), dtype=floatX),
745-
)
746-
with pytest.raises(ValueError, match="Runtime broadcasting not allowed"):
747-
f(x=np.zeros((1,), dtype=floatX))
748-
749-
out = alloc(specify_shape(x_v, (1,)), 5, 3)
750-
f = pytensor.function([x_v], out, mode=mode)
751-
TestAlloc.check_allocs_in_fgraph(f.maker.fgraph, 1)
752-
753-
np.testing.assert_array_equal(
754-
f(x=np.zeros((1,), dtype=floatX)),
755-
np.zeros((5, 3), dtype=floatX),
756-
)
757-
758758
def setup_method(self):
759759
self.rng = np.random.default_rng(seed=utt.fetch_seed())
760760

@@ -911,7 +911,7 @@ def test_alloc_of_view_linker(self):
911911

912912
@pytest.mark.parametrize("mode", (Mode("py"), Mode("c")))
913913
def test_runtime_broadcast(self, mode):
914-
self.check_runtime_broadcast(mode)
914+
check_alloc_runtime_broadcast(mode)
915915

916916

917917
def test_infer_static_shape():

0 commit comments

Comments
 (0)