1
- import jax
2
-
3
1
from pytensor .link .jax .dispatch .basic import jax_funcify
4
2
from pytensor .tensor .subtensor import (
5
3
AdvancedIncSubtensor ,
33
31
"""
34
32
35
33
36
- def assert_indices_jax_compatible (node , idx_list ):
34
+ def subtensor_assert_indices_jax_compatible (node , idx_list ):
37
35
from pytensor .graph .basic import Constant
38
36
from pytensor .tensor .var import TensorVariable
39
37
@@ -55,7 +53,7 @@ def assert_indices_jax_compatible(node, idx_list):
55
53
def jax_funcify_Subtensor (op , node , ** kwargs ):
56
54
57
55
idx_list = getattr (op , "idx_list" , None )
58
- assert_indices_jax_compatible (node , idx_list )
56
+ subtensor_assert_indices_jax_compatible (node , idx_list )
59
57
60
58
def subtensor_constant (x , * ilists ):
61
59
indices = indices_from_subtensor (ilists , idx_list )
@@ -69,25 +67,19 @@ def subtensor_constant(x, *ilists):
69
67
70
68
@jax_funcify .register (IncSubtensor )
71
69
@jax_funcify .register (AdvancedIncSubtensor1 )
72
- def jax_funcify_IncSubtensor (op , ** kwargs ):
70
+ def jax_funcify_IncSubtensor (op , node , ** kwargs ):
73
71
74
72
idx_list = getattr (op , "idx_list" , None )
75
73
76
74
if getattr (op , "set_instead_of_inc" , False ):
77
- jax_fn = getattr (jax .ops , "index_update" , None )
78
-
79
- if jax_fn is None :
80
75
81
- def jax_fn (x , indices , y ):
82
- return x .at [indices ].set (y )
76
+ def jax_fn (x , indices , y ):
77
+ return x .at [indices ].set (y )
83
78
84
79
else :
85
- jax_fn = getattr (jax .ops , "index_add" , None )
86
-
87
- if jax_fn is None :
88
80
89
- def jax_fn (x , indices , y ):
90
- return x .at [indices ].add (y )
81
+ def jax_fn (x , indices , y ):
82
+ return x .at [indices ].add (y )
91
83
92
84
def incsubtensor (x , y , * ilist , jax_fn = jax_fn , idx_list = idx_list ):
93
85
indices = indices_from_subtensor (ilist , idx_list )
@@ -100,23 +92,17 @@ def incsubtensor(x, y, *ilist, jax_fn=jax_fn, idx_list=idx_list):
100
92
101
93
102
94
@jax_funcify .register (AdvancedIncSubtensor )
103
- def jax_funcify_AdvancedIncSubtensor (op , ** kwargs ):
95
+ def jax_funcify_AdvancedIncSubtensor (op , node , ** kwargs ):
104
96
105
97
if getattr (op , "set_instead_of_inc" , False ):
106
- jax_fn = getattr (jax .ops , "index_update" , None )
107
98
108
- if jax_fn is None :
109
-
110
- def jax_fn (x , indices , y ):
111
- return x .at [indices ].set (y )
99
+ def jax_fn (x , indices , y ):
100
+ return x .at [indices ].set (y )
112
101
113
102
else :
114
- jax_fn = getattr (jax .ops , "index_add" , None )
115
-
116
- if jax_fn is None :
117
103
118
- def jax_fn (x , indices , y ):
119
- return x .at [indices ].add (y )
104
+ def jax_fn (x , indices , y ):
105
+ return x .at [indices ].add (y )
120
106
121
107
def advancedincsubtensor (x , y , * ilist , jax_fn = jax_fn ):
122
108
return jax_fn (x , ilist , y )
0 commit comments