Skip to content

Commit 92eef5e

Browse files
committed
Allow running JAX functions with scalar inputs for RV shapes
1 parent 4cdd290 commit 92eef5e

File tree

2 files changed

+87
-10
lines changed

2 files changed

+87
-10
lines changed

pytensor/link/jax/linker.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,13 @@
99
class JAXLinker(JITLinker):
1010
"""A `Linker` that JIT-compiles NumPy-based operations using JAX."""
1111

12+
def __init__(self, *args, **kwargs):
13+
self.scalar_shape_inputs: tuple[int] = () # type: ignore[annotation-unchecked]
14+
super().__init__(*args, **kwargs)
15+
1216
def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
1317
from pytensor.link.jax.dispatch import jax_funcify
18+
from pytensor.link.jax.dispatch.shape import JAXShapeTuple
1419
from pytensor.tensor.random.type import RandomType
1520

1621
shared_rng_inputs = [
@@ -64,14 +69,46 @@ def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
6469
fgraph.inputs.remove(new_inp)
6570
fgraph.inputs.insert(old_inp_fgrap_index, new_inp)
6671

72+
fgraph_inputs = fgraph.inputs
73+
clients = fgraph.clients
74+
# Detect scalar shape inputs that are used only in JAXShapeTuple nodes
75+
scalar_shape_inputs = [
76+
inp
77+
for node in fgraph.apply_nodes
78+
if isinstance(node.op, JAXShapeTuple)
79+
for inp in node.inputs
80+
if inp in fgraph_inputs
81+
and all(
82+
isinstance(cl_node.op, JAXShapeTuple) for cl_node, _ in clients[inp]
83+
)
84+
]
85+
self.scalar_shape_inputs = tuple(
86+
fgraph_inputs.index(inp) for inp in scalar_shape_inputs
87+
)
88+
6789
return jax_funcify(
6890
fgraph, input_storage=input_storage, storage_map=storage_map, **kwargs
6991
)
7092

7193
def jit_compile(self, fn):
7294
import jax
7395

74-
return jax.jit(fn)
96+
jit_fn = jax.jit(fn, static_argnums=self.scalar_shape_inputs)
97+
98+
if not self.scalar_shape_inputs:
99+
return jit_fn
100+
101+
def convert_scalar_shape_inputs(
102+
*args, scalar_shape_inputs=set(self.scalar_shape_inputs)
103+
):
104+
return jit_fn(
105+
*(
106+
int(arg) if i in scalar_shape_inputs else arg
107+
for i, arg in enumerate(args)
108+
)
109+
)
110+
111+
return convert_scalar_shape_inputs
75112

76113
def create_thunk_inputs(self, storage_map):
77114
from pytensor.link.jax.dispatch import jax_typify

tests/link/jax/test_random.py

Lines changed: 49 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -894,15 +894,55 @@ def test_random_concrete_shape_subtensor_tuple(self):
894894
jax_fn = compile_random_function([x_pt], out)
895895
assert jax_fn(np.ones((2, 3))).shape == (2,)
896896

897+
def test_random_scalar_shape_input(self):
898+
dim0 = pt.scalar("dim0", dtype=int)
899+
dim1 = pt.scalar("dim1", dtype=int)
900+
901+
out = pt.random.normal(0, 1, size=dim0)
902+
jax_fn = compile_random_function([dim0], out)
903+
assert jax_fn(np.array(2)).shape == (2,)
904+
assert jax_fn(np.array(3)).shape == (3,)
905+
906+
out = pt.random.normal(0, 1, size=[dim0, dim1])
907+
jax_fn = compile_random_function([dim0, dim1], out)
908+
assert jax_fn(np.array(2), np.array(3)).shape == (2, 3)
909+
assert jax_fn(np.array(4), np.array(5)).shape == (4, 5)
910+
897911
@pytest.mark.xfail(
898-
reason="`size_pt` should be specified as a static argument", strict=True
912+
raises=TypeError, reason="Cannot convert scalar input to integer"
899913
)
900-
def test_random_concrete_shape_graph_input(self):
901-
rng = shared(np.random.default_rng(123))
902-
size_pt = pt.scalar()
903-
out = pt.random.normal(0, 1, size=size_pt, rng=rng)
904-
jax_fn = compile_random_function([size_pt], out)
905-
assert jax_fn(10).shape == (10,)
914+
def test_random_scalar_shape_input_not_supported(self):
915+
dim = pt.scalar("dim", dtype=int)
916+
out1 = pt.random.normal(0, 1, size=dim)
917+
# An operation that wouldn't work if we replaced 0d array by integer
918+
out2 = dim[...].set(1)
919+
jax_fn = compile_random_function([dim], [out1, out2])
920+
921+
res1, res2 = jax_fn(np.array(2))
922+
assert res1.shape == (2,)
923+
assert res2 == 1
924+
925+
@pytest.mark.xfail(
926+
raises=TypeError, reason="Cannot convert scalar input to integer"
927+
)
928+
def test_random_scalar_shape_input_not_supported2(self):
929+
dim = pt.scalar("dim", dtype=int)
930+
# This could theoretically be supported
931+
# but would require knowing that * 2 is a safe operation for a python integer
932+
out = pt.random.normal(0, 1, size=dim * 2)
933+
jax_fn = compile_random_function([dim], out)
934+
assert jax_fn(np.array(2)).shape == (4,)
935+
936+
@pytest.mark.xfail(
937+
raises=TypeError, reason="Cannot convert tensor input to shape tuple"
938+
)
939+
def test_random_vector_shape_graph_input(self):
940+
shape = pt.vector("shape", shape=(2,), dtype=int)
941+
out = pt.random.normal(0, 1, size=shape)
942+
943+
jax_fn = compile_random_function([shape], out)
944+
assert jax_fn(np.array([2, 3])).shape == (2, 3)
945+
assert jax_fn(np.array([4, 5])).shape == (4, 5)
906946

907947
def test_constant_shape_after_graph_rewriting(self):
908948
size = pt.vector("size", shape=(2,), dtype=int)
@@ -912,13 +952,13 @@ def test_constant_shape_after_graph_rewriting(self):
912952
with pytest.raises(TypeError):
913953
compile_random_function([size], x)([2, 5])
914954

915-
# Rebuild with strict=False so output type is not updated
955+
# Rebuild with strict=True so output type is not updated
916956
# This reflects cases where size is constant folded during rewrites but the RV node is not recreated
917957
new_x = clone_replace(x, {size: pt.constant([2, 5])}, rebuild_strict=True)
918958
assert new_x.type.shape == (None, None)
919959
assert compile_random_function([], new_x)().shape == (2, 5)
920960

921-
# Rebuild with strict=True, so output type is updated
961+
# Rebuild with strict=False, so output type is updated
922962
# This uses a different path in the dispatch implementation
923963
new_x = clone_replace(x, {size: pt.constant([2, 5])}, rebuild_strict=False)
924964
assert new_x.type.shape == (2, 5)

0 commit comments

Comments
 (0)