Skip to content

Commit 2c91b5a

Browse files
kc611brandonwillard
authored andcommitted
Add support for NumPy Generator types in JAX backend
1 parent 5611cf7 commit 2c91b5a

File tree

3 files changed

+70
-23
lines changed

3 files changed

+70
-23
lines changed

aesara/link/jax/dispatch.py

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
import jax.numpy as jnp
77
import jax.scipy as jsp
88
import numpy as np
9-
from numpy.random import RandomState
9+
from numpy.random import Generator, RandomState
10+
from numpy.random.bit_generator import _coerce_to_uint32_array
1011

1112
from aesara.compile.ops import DeepCopyOp, ViewOp
1213
from aesara.configdefaults import config
@@ -105,6 +106,33 @@ def jax_typify_ndarray(data, dtype=None, **kwargs):
105106
def jax_typify_RandomState(state, **kwargs):
106107
state = state.get_state(legacy=False)
107108
state["bit_generator"] = numpy_bit_gens[state["bit_generator"]]
109+
# XXX: Is this a reasonable approach?
110+
state["jax_state"] = state["state"]["key"][0:2]
111+
return state
112+
113+
114+
@jax_typify.register(Generator)
115+
def jax_typify_Generator(rng, **kwargs):
116+
state = rng.__getstate__()
117+
state["bit_generator"] = numpy_bit_gens[state["bit_generator"]]
118+
119+
# XXX: Is this a reasonable approach?
120+
state["jax_state"] = _coerce_to_uint32_array(state["state"]["state"])[0:2]
121+
122+
# The "state" and "inc" values in a NumPy `Generator` are 128 bits, which
123+
# JAX can't handle, so we split these values into arrays of 32 bit integers
124+
# and then combine the first two into a single 64 bit integers.
125+
#
126+
# XXX: Depending on how we expect these values to be used, is this approach
127+
# reasonable?
128+
#
129+
# TODO: We might as well remove these altogether, since this conversion
130+
# should only occur once (e.g. when the graph is converted/JAX-compiled),
131+
# and, from then on, we use the custom "jax_state" value.
132+
inc_32 = _coerce_to_uint32_array(state["state"]["inc"])
133+
state_32 = _coerce_to_uint32_array(state["state"]["state"])
134+
state["state"]["inc"] = inc_32[0] << 32 | inc_32[1]
135+
state["state"]["state"] = state_32[0] << 32 | state_32[1]
108136
return state
109137

110138

@@ -999,21 +1027,23 @@ def batched_dot(a, b):
9991027

10001028

10011029
@jax_funcify.register(RandomVariable)
1002-
def jax_funcify_RandomVariable(op, **kwargs):
1030+
def jax_funcify_RandomVariable(op, node, **kwargs):
10031031
name = op.name
10041032

10051033
if not hasattr(jax.random, name):
10061034
raise NotImplementedError(
10071035
f"No JAX conversion for the given distribution: {name}"
10081036
)
10091037

1010-
def random_variable(rng, size, dtype, *args):
1011-
prng = jax.random.PRNGKey(rng["state"]["key"][0])
1012-
dtype = jnp.dtype(dtype)
1038+
dtype = node.outputs[1].dtype
1039+
1040+
def random_variable(rng, size, dtype_num, *args):
1041+
if not op.inplace:
1042+
rng = rng.copy()
1043+
prng = rng["jax_state"]
10131044
data = getattr(jax.random, name)(key=prng, shape=size)
10141045
smpl_value = jnp.array(data, dtype=dtype)
1015-
prng = jax.random.split(prng, num=1)[0]
1016-
jax.ops.index_update(rng["state"]["key"], 0, prng[0])
1046+
rng["jax_state"] = jax.random.split(prng, num=1)[0]
10171047
return (rng, smpl_value)
10181048

10191049
return random_variable

aesara/link/jax/linker.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from numpy.random import RandomState
1+
from numpy.random import Generator, RandomState
22

33
from aesara.graph.basic import Constant
44
from aesara.link.basic import JITLinker
@@ -28,7 +28,7 @@ def create_thunk_inputs(self, storage_map):
2828
thunk_inputs = []
2929
for n in self.fgraph.inputs:
3030
sinput = storage_map[n]
31-
if isinstance(sinput[0], RandomState):
31+
if isinstance(sinput[0], (RandomState, Generator)):
3232
new_value = jax_typify(
3333
sinput[0], dtype=getattr(sinput[0], "dtype", None)
3434
)

tests/link/test_jax.py

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1188,12 +1188,38 @@ def test_extra_ops_omni():
11881188
compare_jax_and_py(fgraph, [])
11891189

11901190

1191-
@pytest.mark.xfail(reason="The RNG states are not 1:1", raises=AssertionError)
1192-
def test_random():
1193-
rng = shared(np.random.RandomState(123))
1194-
out = normal(rng=rng)
1191+
@pytest.mark.parametrize(
1192+
"at_dist, dist_params, rng, size",
1193+
[
1194+
(
1195+
normal,
1196+
(),
1197+
shared(np.random.RandomState(123)),
1198+
10000,
1199+
),
1200+
(
1201+
normal,
1202+
(),
1203+
shared(np.random.default_rng(123)),
1204+
10000,
1205+
),
1206+
],
1207+
)
1208+
def test_random_stats(at_dist, dist_params, rng, size):
1209+
# The RNG states are not 1:1, so the best we can do is check some summary
1210+
# statistics of the samples
1211+
out = normal(*dist_params, rng=rng, size=size)
11951212
fgraph = FunctionGraph([out.owner.inputs[0]], [out], clone=False)
1196-
compare_jax_and_py(fgraph, [])
1213+
1214+
def assert_fn(x, y):
1215+
(x,) = x
1216+
(y,) = y
1217+
assert x.dtype.kind == y.dtype.kind
1218+
1219+
d = 2 if config.floatX == "float64" else 1
1220+
np.testing.assert_array_almost_equal(np.abs(x.mean()), np.abs(y.mean()), d)
1221+
1222+
compare_jax_and_py(fgraph, [], assert_fn=assert_fn)
11971223

11981224

11991225
def test_random_unimplemented():
@@ -1218,7 +1244,6 @@ def rng_fn(cls, rng, size):
12181244
compare_jax_and_py(fgraph, [])
12191245

12201246

1221-
@pytest.mark.xfail(reason="Generators not yet supported in JAX")
12221247
def test_RandomStream():
12231248
srng = RandomStream(seed=123)
12241249
out = srng.normal() - srng.normal()
@@ -1228,11 +1253,3 @@ def test_RandomStream():
12281253
jax_res_2 = fn()
12291254

12301255
assert np.array_equal(jax_res_1, jax_res_2)
1231-
1232-
1233-
@pytest.mark.xfail(reason="Generators not yet supported in JAX")
1234-
def test_random_generators():
1235-
rng = shared(np.random.default_rng(123))
1236-
out = normal(rng=rng)
1237-
fgraph = FunctionGraph([out.owner.inputs[0]], [out], clone=False)
1238-
compare_jax_and_py(fgraph, [])

0 commit comments

Comments
 (0)