Skip to content

Commit 4235ccc

Browse files
rloufricardoV94
authored andcommitted
Typify 0-dim arrays to corresponding number
1 parent 0087e56 commit 4235ccc

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

pytensor/link/jax/dispatch/basic.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ def jax_typify(data, dtype=None, **kwargs):
3030

3131
@jax_typify.register(np.ndarray)
3232
def jax_typify_ndarray(data, dtype=None, **kwargs):
33+
if len(data.shape) == 0:
34+
return data.item()
3335
return jnp.array(data, dtype=dtype)
3436

3537

0 commit comments

Comments
 (0)