Skip to content

Commit c7bc290

Browse files
committed
Group JAX random shape input tests
1 parent 03cab87 commit c7bc290

File tree

1 file changed

+87
-91
lines changed

1 file changed

+87
-91
lines changed

tests/link/jax/test_random.py

Lines changed: 87 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -836,94 +836,90 @@ def sample_fn(rng, size, dtype, *parameters):
836836
compare_jax_and_py([], [out], [])
837837

838838

839-
def test_random_concrete_shape():
840-
"""JAX should compile when a `RandomVariable` is passed a concrete shape.
841-
842-
There are three quantities that JAX considers as concrete:
843-
1. Constants known at compile time;
844-
2. The shape of an array.
845-
3. `static_argnums` parameters
846-
This test makes sure that graphs with `RandomVariable`s compile when the
847-
`size` parameter satisfies either of these criteria.
848-
849-
"""
850-
rng = shared(np.random.default_rng(123))
851-
x_pt = pt.dmatrix()
852-
out = pt.random.normal(0, 1, size=x_pt.shape, rng=rng)
853-
jax_fn = compile_random_function([x_pt], out)
854-
assert jax_fn(np.ones((2, 3))).shape == (2, 3)
855-
856-
857-
def test_random_concrete_shape_from_param():
858-
rng = shared(np.random.default_rng(123))
859-
x_pt = pt.dmatrix()
860-
out = pt.random.normal(x_pt, 1, rng=rng)
861-
jax_fn = compile_random_function([x_pt], out)
862-
assert jax_fn(np.ones((2, 3))).shape == (2, 3)
863-
864-
865-
def test_random_concrete_shape_subtensor():
866-
"""JAX should compile when a concrete value is passed for the `size` parameter.
867-
868-
This test ensures that the `DimShuffle` `Op` used by PyTensor to turn scalar
869-
inputs into 1d vectors is replaced by an `Op` that turns concrete scalar
870-
inputs into tuples of concrete values using the `jax_size_parameter_as_tuple`
871-
rewrite.
872-
873-
JAX does not accept scalars as `size` or `shape` arguments, so this is a
874-
slight improvement over their API.
875-
876-
"""
877-
rng = shared(np.random.default_rng(123))
878-
x_pt = pt.dmatrix()
879-
out = pt.random.normal(0, 1, size=x_pt.shape[1], rng=rng)
880-
jax_fn = compile_random_function([x_pt], out)
881-
assert jax_fn(np.ones((2, 3))).shape == (3,)
882-
883-
884-
def test_random_concrete_shape_subtensor_tuple():
885-
"""JAX should compile when a tuple of concrete values is passed for the `size` parameter.
886-
887-
This test ensures that the `MakeVector` `Op` used by PyTensor to turn tuple
888-
inputs into 1d vectors is replaced by an `Op` that turns a tuple of concrete
889-
scalar inputs into tuples of concrete values using the
890-
`jax_size_parameter_as_tuple` rewrite.
891-
892-
"""
893-
rng = shared(np.random.default_rng(123))
894-
x_pt = pt.dmatrix()
895-
out = pt.random.normal(0, 1, size=(x_pt.shape[0],), rng=rng)
896-
jax_fn = compile_random_function([x_pt], out)
897-
assert jax_fn(np.ones((2, 3))).shape == (2,)
898-
899-
900-
@pytest.mark.xfail(
901-
reason="`size_pt` should be specified as a static argument", strict=True
902-
)
903-
def test_random_concrete_shape_graph_input():
904-
rng = shared(np.random.default_rng(123))
905-
size_pt = pt.scalar()
906-
out = pt.random.normal(0, 1, size=size_pt, rng=rng)
907-
jax_fn = compile_random_function([size_pt], out)
908-
assert jax_fn(10).shape == (10,)
909-
910-
911-
def test_constant_shape_after_graph_rewriting():
912-
size = pt.vector("size", shape=(2,), dtype=int)
913-
x = pt.random.normal(size=size)
914-
assert x.type.shape == (None, None)
915-
916-
with pytest.raises(TypeError):
917-
compile_random_function([size], x)([2, 5])
918-
919-
# Rebuild with strict=False so output type is not updated
920-
# This reflects cases where size is constant folded during rewrites but the RV node is not recreated
921-
new_x = clone_replace(x, {size: pt.constant([2, 5])}, rebuild_strict=True)
922-
assert new_x.type.shape == (None, None)
923-
assert compile_random_function([], new_x)().shape == (2, 5)
924-
925-
# Rebuild with strict=True, so output type is updated
926-
# This uses a different path in the dispatch implementation
927-
new_x = clone_replace(x, {size: pt.constant([2, 5])}, rebuild_strict=False)
928-
assert new_x.type.shape == (2, 5)
929-
assert compile_random_function([], new_x)().shape == (2, 5)
839+
class TestRandomShapeInputs:
840+
def test_random_concrete_shape(self):
841+
"""JAX should compile when a `RandomVariable` is passed a concrete shape.
842+
843+
There are three quantities that JAX considers as concrete:
844+
1. Constants known at compile time;
845+
2. The shape of an array.
846+
3. `static_argnums` parameters
847+
This test makes sure that graphs with `RandomVariable`s compile when the
848+
`size` parameter satisfies either of these criteria.
849+
850+
"""
851+
rng = shared(np.random.default_rng(123))
852+
x_pt = pt.dmatrix()
853+
out = pt.random.normal(0, 1, size=x_pt.shape, rng=rng)
854+
jax_fn = compile_random_function([x_pt], out)
855+
assert jax_fn(np.ones((2, 3))).shape == (2, 3)
856+
857+
def test_random_concrete_shape_from_param(self):
858+
rng = shared(np.random.default_rng(123))
859+
x_pt = pt.dmatrix()
860+
out = pt.random.normal(x_pt, 1, rng=rng)
861+
jax_fn = compile_random_function([x_pt], out)
862+
assert jax_fn(np.ones((2, 3))).shape == (2, 3)
863+
864+
def test_random_concrete_shape_subtensor(self):
865+
"""JAX should compile when a concrete value is passed for the `size` parameter.
866+
867+
This test ensures that the `DimShuffle` `Op` used by PyTensor to turn scalar
868+
inputs into 1d vectors is replaced by an `Op` that turns concrete scalar
869+
inputs into tuples of concrete values using the `jax_size_parameter_as_tuple`
870+
rewrite.
871+
872+
JAX does not accept scalars as `size` or `shape` arguments, so this is a
873+
slight improvement over their API.
874+
875+
"""
876+
rng = shared(np.random.default_rng(123))
877+
x_pt = pt.dmatrix()
878+
out = pt.random.normal(0, 1, size=x_pt.shape[1], rng=rng)
879+
jax_fn = compile_random_function([x_pt], out)
880+
assert jax_fn(np.ones((2, 3))).shape == (3,)
881+
882+
def test_random_concrete_shape_subtensor_tuple(self):
883+
"""JAX should compile when a tuple of concrete values is passed for the `size` parameter.
884+
885+
This test ensures that the `MakeVector` `Op` used by PyTensor to turn tuple
886+
inputs into 1d vectors is replaced by an `Op` that turns a tuple of concrete
887+
scalar inputs into tuples of concrete values using the
888+
`jax_size_parameter_as_tuple` rewrite.
889+
890+
"""
891+
rng = shared(np.random.default_rng(123))
892+
x_pt = pt.dmatrix()
893+
out = pt.random.normal(0, 1, size=(x_pt.shape[0],), rng=rng)
894+
jax_fn = compile_random_function([x_pt], out)
895+
assert jax_fn(np.ones((2, 3))).shape == (2,)
896+
897+
@pytest.mark.xfail(
898+
reason="`size_pt` should be specified as a static argument", strict=True
899+
)
900+
def test_random_concrete_shape_graph_input(self):
901+
rng = shared(np.random.default_rng(123))
902+
size_pt = pt.scalar()
903+
out = pt.random.normal(0, 1, size=size_pt, rng=rng)
904+
jax_fn = compile_random_function([size_pt], out)
905+
assert jax_fn(10).shape == (10,)
906+
907+
def test_constant_shape_after_graph_rewriting(self):
908+
size = pt.vector("size", shape=(2,), dtype=int)
909+
x = pt.random.normal(size=size)
910+
assert x.type.shape == (None, None)
911+
912+
with pytest.raises(TypeError):
913+
compile_random_function([size], x)([2, 5])
914+
915+
# Rebuild with strict=False so output type is not updated
916+
# This reflects cases where size is constant folded during rewrites but the RV node is not recreated
917+
new_x = clone_replace(x, {size: pt.constant([2, 5])}, rebuild_strict=True)
918+
assert new_x.type.shape == (None, None)
919+
assert compile_random_function([], new_x)().shape == (2, 5)
920+
921+
# Rebuild with strict=True, so output type is updated
922+
# This uses a different path in the dispatch implementation
923+
new_x = clone_replace(x, {size: pt.constant([2, 5])}, rebuild_strict=False)
924+
assert new_x.type.shape == (2, 5)
925+
assert compile_random_function([], new_x)().shape == (2, 5)

0 commit comments

Comments
 (0)