@@ -99,26 +99,27 @@ def __init__(self, vars, shared, blocked=True, rng: RandomGenerator = None):
99
99
:py:func:`pymc.util.get_random_generator` for more information.
100
100
"""
101
101
self .vars = vars
102
+ self .var_names = tuple (cast (str , var .name ) for var in vars )
102
103
self .shared = {get_var_name (var ): shared for var , shared in shared .items ()}
103
104
self .blocked = blocked
104
105
self .rng = get_random_generator (rng )
105
106
106
107
def step (self , point : PointType ) -> tuple [PointType , StatsType ]:
107
- for name , shared_var in self .shared .items ():
108
- shared_var .set_value (point [name ])
109
-
110
- var_dict = {cast (str , v .name ): point [cast (str , v .name )] for v in self .vars }
111
- q = DictToArrayBijection .map (var_dict )
112
-
108
+ full_point = None
109
+ if self .shared :
110
+ for name , shared_var in self .shared .items ():
111
+ shared_var .set_value (point [name ], borrow = True )
112
+ full_point = point
113
+ point = {name : point [name ] for name in self .var_names }
114
+
115
+ q = DictToArrayBijection .map (point )
113
116
apoint , stats = self .astep (q )
114
117
115
118
if not isinstance (apoint , RaveledVars ):
116
119
# We assume that the mapping has stayed the same
117
120
apoint = RaveledVars (apoint , q .point_map_info )
118
121
119
- new_point = DictToArrayBijection .rmap (apoint , start_point = point )
120
-
121
- return new_point , stats
122
+ return DictToArrayBijection .rmap (apoint , start_point = full_point ), stats
122
123
123
124
@abstractmethod
124
125
def astep (self , q0 : RaveledVars ) -> tuple [RaveledVars , StatsType ]:
0 commit comments