|
6 | 6 | import jax.numpy as jnp
|
7 | 7 | import jax.scipy as jsp
|
8 | 8 | import numpy as np
|
9 |
| -from numpy.random import RandomState |
| 9 | +from numpy.random import Generator, RandomState |
| 10 | +from numpy.random.bit_generator import _coerce_to_uint32_array |
10 | 11 |
|
11 | 12 | from aesara.compile.ops import DeepCopyOp, ViewOp
|
12 | 13 | from aesara.configdefaults import config
|
@@ -105,6 +106,33 @@ def jax_typify_ndarray(data, dtype=None, **kwargs):
|
105 | 106 | def jax_typify_RandomState(state, **kwargs):
|
106 | 107 | state = state.get_state(legacy=False)
|
107 | 108 | state["bit_generator"] = numpy_bit_gens[state["bit_generator"]]
|
| 109 | + # XXX: Is this a reasonable approach? |
| 110 | + state["jax_state"] = state["state"]["key"][0:2] |
| 111 | + return state |
| 112 | + |
| 113 | + |
| 114 | +@jax_typify.register(Generator) |
| 115 | +def jax_typify_Generator(rng, **kwargs): |
| 116 | + state = rng.__getstate__() |
| 117 | + state["bit_generator"] = numpy_bit_gens[state["bit_generator"]] |
| 118 | + |
| 119 | + # XXX: Is this a reasonable approach? |
| 120 | + state["jax_state"] = _coerce_to_uint32_array(state["state"]["state"])[0:2] |
| 121 | + |
| 122 | + # The "state" and "inc" values in a NumPy `Generator` are 128 bits, which |
| 123 | + # JAX can't handle, so we split these values into arrays of 32 bit integers |
| 124 | + # and then combine the first two into a single 64 bit integers. |
| 125 | + # |
| 126 | + # XXX: Depending on how we expect these values to be used, is this approach |
| 127 | + # reasonable? |
| 128 | + # |
| 129 | + # TODO: We might as well remove these altogether, since this conversion |
| 130 | + # should only occur once (e.g. when the graph is converted/JAX-compiled), |
| 131 | + # and, from then on, we use the custom "jax_state" value. |
| 132 | + inc_32 = _coerce_to_uint32_array(state["state"]["inc"]) |
| 133 | + state_32 = _coerce_to_uint32_array(state["state"]["state"]) |
| 134 | + state["state"]["inc"] = inc_32[0] << 32 | inc_32[1] |
| 135 | + state["state"]["state"] = state_32[0] << 32 | state_32[1] |
108 | 136 | return state
|
109 | 137 |
|
110 | 138 |
|
@@ -999,21 +1027,23 @@ def batched_dot(a, b):
|
999 | 1027 |
|
1000 | 1028 |
|
1001 | 1029 | @jax_funcify.register(RandomVariable)
|
1002 |
| -def jax_funcify_RandomVariable(op, **kwargs): |
| 1030 | +def jax_funcify_RandomVariable(op, node, **kwargs): |
1003 | 1031 | name = op.name
|
1004 | 1032 |
|
1005 | 1033 | if not hasattr(jax.random, name):
|
1006 | 1034 | raise NotImplementedError(
|
1007 | 1035 | f"No JAX conversion for the given distribution: {name}"
|
1008 | 1036 | )
|
1009 | 1037 |
|
1010 |
| - def random_variable(rng, size, dtype, *args): |
1011 |
| - prng = jax.random.PRNGKey(rng["state"]["key"][0]) |
1012 |
| - dtype = jnp.dtype(dtype) |
| 1038 | + dtype = node.outputs[1].dtype |
| 1039 | + |
| 1040 | + def random_variable(rng, size, dtype_num, *args): |
| 1041 | + if not op.inplace: |
| 1042 | + rng = rng.copy() |
| 1043 | + prng = rng["jax_state"] |
1013 | 1044 | data = getattr(jax.random, name)(key=prng, shape=size)
|
1014 | 1045 | smpl_value = jnp.array(data, dtype=dtype)
|
1015 |
| - prng = jax.random.split(prng, num=1)[0] |
1016 |
| - jax.ops.index_update(rng["state"]["key"], 0, prng[0]) |
| 1046 | + rng["jax_state"] = jax.random.split(prng, num=1)[0] |
1017 | 1047 | return (rng, smpl_value)
|
1018 | 1048 |
|
1019 | 1049 | return random_variable
|
0 commit comments