Skip to content

Commit ec793a9

Browse files
committed
Group JAX random shape input tests
1 parent 6ed5349 commit ec793a9

File tree

1 file changed

+87
-91
lines changed

1 file changed

+87
-91
lines changed

tests/link/jax/test_random.py

Lines changed: 87 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -805,94 +805,90 @@ def sample_fn(rng, size, dtype, *parameters):
805805
compare_jax_and_py(fgraph, [])
806806

807807

808-
def test_random_concrete_shape():
809-
"""JAX should compile when a `RandomVariable` is passed a concrete shape.
810-
811-
There are three quantities that JAX considers as concrete:
812-
1. Constants known at compile time;
813-
2. The shape of an array.
814-
3. `static_argnums` parameters
815-
This test makes sure that graphs with `RandomVariable`s compile when the
816-
`size` parameter satisfies either of these criteria.
817-
818-
"""
819-
rng = shared(np.random.default_rng(123))
820-
x_pt = pt.dmatrix()
821-
out = pt.random.normal(0, 1, size=x_pt.shape, rng=rng)
822-
jax_fn = compile_random_function([x_pt], out)
823-
assert jax_fn(np.ones((2, 3))).shape == (2, 3)
824-
825-
826-
def test_random_concrete_shape_from_param():
827-
rng = shared(np.random.default_rng(123))
828-
x_pt = pt.dmatrix()
829-
out = pt.random.normal(x_pt, 1, rng=rng)
830-
jax_fn = compile_random_function([x_pt], out)
831-
assert jax_fn(np.ones((2, 3))).shape == (2, 3)
832-
833-
834-
def test_random_concrete_shape_subtensor():
835-
"""JAX should compile when a concrete value is passed for the `size` parameter.
836-
837-
This test ensures that the `DimShuffle` `Op` used by PyTensor to turn scalar
838-
inputs into 1d vectors is replaced by an `Op` that turns concrete scalar
839-
inputs into tuples of concrete values using the `jax_size_parameter_as_tuple`
840-
rewrite.
841-
842-
JAX does not accept scalars as `size` or `shape` arguments, so this is a
843-
slight improvement over their API.
844-
845-
"""
846-
rng = shared(np.random.default_rng(123))
847-
x_pt = pt.dmatrix()
848-
out = pt.random.normal(0, 1, size=x_pt.shape[1], rng=rng)
849-
jax_fn = compile_random_function([x_pt], out)
850-
assert jax_fn(np.ones((2, 3))).shape == (3,)
851-
852-
853-
def test_random_concrete_shape_subtensor_tuple():
854-
"""JAX should compile when a tuple of concrete values is passed for the `size` parameter.
855-
856-
This test ensures that the `MakeVector` `Op` used by PyTensor to turn tuple
857-
inputs into 1d vectors is replaced by an `Op` that turns a tuple of concrete
858-
scalar inputs into tuples of concrete values using the
859-
`jax_size_parameter_as_tuple` rewrite.
860-
861-
"""
862-
rng = shared(np.random.default_rng(123))
863-
x_pt = pt.dmatrix()
864-
out = pt.random.normal(0, 1, size=(x_pt.shape[0],), rng=rng)
865-
jax_fn = compile_random_function([x_pt], out)
866-
assert jax_fn(np.ones((2, 3))).shape == (2,)
867-
868-
869-
@pytest.mark.xfail(
870-
reason="`size_pt` should be specified as a static argument", strict=True
871-
)
872-
def test_random_concrete_shape_graph_input():
873-
rng = shared(np.random.default_rng(123))
874-
size_pt = pt.scalar()
875-
out = pt.random.normal(0, 1, size=size_pt, rng=rng)
876-
jax_fn = compile_random_function([size_pt], out)
877-
assert jax_fn(10).shape == (10,)
878-
879-
880-
def test_constant_shape_after_graph_rewriting():
881-
size = pt.vector("size", shape=(2,), dtype=int)
882-
x = pt.random.normal(size=size)
883-
assert x.type.shape == (None, None)
884-
885-
with pytest.raises(TypeError):
886-
compile_random_function([size], x)([2, 5])
887-
888-
# Rebuild with strict=False so output type is not updated
889-
# This reflects cases where size is constant folded during rewrites but the RV node is not recreated
890-
new_x = clone_replace(x, {size: pt.constant([2, 5])}, rebuild_strict=True)
891-
assert new_x.type.shape == (None, None)
892-
assert compile_random_function([], new_x)().shape == (2, 5)
893-
894-
# Rebuild with strict=True, so output type is updated
895-
# This uses a different path in the dispatch implementation
896-
new_x = clone_replace(x, {size: pt.constant([2, 5])}, rebuild_strict=False)
897-
assert new_x.type.shape == (2, 5)
898-
assert compile_random_function([], new_x)().shape == (2, 5)
808+
class TestRandomShapeInputs:
809+
def test_random_concrete_shape(self):
810+
"""JAX should compile when a `RandomVariable` is passed a concrete shape.
811+
812+
There are three quantities that JAX considers as concrete:
813+
1. Constants known at compile time;
814+
2. The shape of an array.
815+
3. `static_argnums` parameters
816+
This test makes sure that graphs with `RandomVariable`s compile when the
817+
`size` parameter satisfies either of these criteria.
818+
819+
"""
820+
rng = shared(np.random.default_rng(123))
821+
x_pt = pt.dmatrix()
822+
out = pt.random.normal(0, 1, size=x_pt.shape, rng=rng)
823+
jax_fn = compile_random_function([x_pt], out)
824+
assert jax_fn(np.ones((2, 3))).shape == (2, 3)
825+
826+
def test_random_concrete_shape_from_param(self):
827+
rng = shared(np.random.default_rng(123))
828+
x_pt = pt.dmatrix()
829+
out = pt.random.normal(x_pt, 1, rng=rng)
830+
jax_fn = compile_random_function([x_pt], out)
831+
assert jax_fn(np.ones((2, 3))).shape == (2, 3)
832+
833+
def test_random_concrete_shape_subtensor(self):
834+
"""JAX should compile when a concrete value is passed for the `size` parameter.
835+
836+
This test ensures that the `DimShuffle` `Op` used by PyTensor to turn scalar
837+
inputs into 1d vectors is replaced by an `Op` that turns concrete scalar
838+
inputs into tuples of concrete values using the `jax_size_parameter_as_tuple`
839+
rewrite.
840+
841+
JAX does not accept scalars as `size` or `shape` arguments, so this is a
842+
slight improvement over their API.
843+
844+
"""
845+
rng = shared(np.random.default_rng(123))
846+
x_pt = pt.dmatrix()
847+
out = pt.random.normal(0, 1, size=x_pt.shape[1], rng=rng)
848+
jax_fn = compile_random_function([x_pt], out)
849+
assert jax_fn(np.ones((2, 3))).shape == (3,)
850+
851+
def test_random_concrete_shape_subtensor_tuple(self):
852+
"""JAX should compile when a tuple of concrete values is passed for the `size` parameter.
853+
854+
This test ensures that the `MakeVector` `Op` used by PyTensor to turn tuple
855+
inputs into 1d vectors is replaced by an `Op` that turns a tuple of concrete
856+
scalar inputs into tuples of concrete values using the
857+
`jax_size_parameter_as_tuple` rewrite.
858+
859+
"""
860+
rng = shared(np.random.default_rng(123))
861+
x_pt = pt.dmatrix()
862+
out = pt.random.normal(0, 1, size=(x_pt.shape[0],), rng=rng)
863+
jax_fn = compile_random_function([x_pt], out)
864+
assert jax_fn(np.ones((2, 3))).shape == (2,)
865+
866+
@pytest.mark.xfail(
867+
reason="`size_pt` should be specified as a static argument", strict=True
868+
)
869+
def test_random_concrete_shape_graph_input(self):
870+
rng = shared(np.random.default_rng(123))
871+
size_pt = pt.scalar()
872+
out = pt.random.normal(0, 1, size=size_pt, rng=rng)
873+
jax_fn = compile_random_function([size_pt], out)
874+
assert jax_fn(10).shape == (10,)
875+
876+
def test_constant_shape_after_graph_rewriting(self):
877+
size = pt.vector("size", shape=(2,), dtype=int)
878+
x = pt.random.normal(size=size)
879+
assert x.type.shape == (None, None)
880+
881+
with pytest.raises(TypeError):
882+
compile_random_function([size], x)([2, 5])
883+
884+
# Rebuild with strict=False so output type is not updated
885+
# This reflects cases where size is constant folded during rewrites but the RV node is not recreated
886+
new_x = clone_replace(x, {size: pt.constant([2, 5])}, rebuild_strict=True)
887+
assert new_x.type.shape == (None, None)
888+
assert compile_random_function([], new_x)().shape == (2, 5)
889+
890+
# Rebuild with strict=True, so output type is updated
891+
# This uses a different path in the dispatch implementation
892+
new_x = clone_replace(x, {size: pt.constant([2, 5])}, rebuild_strict=False)
893+
assert new_x.type.shape == (2, 5)
894+
assert compile_random_function([], new_x)().shape == (2, 5)

0 commit comments

Comments
 (0)