Skip to content

Commit 8708f21

Browse files
committed
Optimize ArrayStepShared.step
1 parent 812d985 commit 8708f21

File tree

1 file changed

+10
-9
lines changed

1 file changed

+10
-9
lines changed

pymc/step_methods/arraystep.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -99,26 +99,27 @@ def __init__(self, vars, shared, blocked=True, rng: RandomGenerator = None):
9999
:py:func:`pymc.util.get_random_generator` for more information.
100100
"""
101101
self.vars = vars
102+
self.var_names = tuple(cast(str, var.name) for var in vars)
102103
self.shared = {get_var_name(var): shared for var, shared in shared.items()}
103104
self.blocked = blocked
104105
self.rng = get_random_generator(rng)
105106

106107
def step(self, point: PointType) -> tuple[PointType, StatsType]:
107-
for name, shared_var in self.shared.items():
108-
shared_var.set_value(point[name])
109-
110-
var_dict = {cast(str, v.name): point[cast(str, v.name)] for v in self.vars}
111-
q = DictToArrayBijection.map(var_dict)
112-
108+
full_point = None
109+
if self.shared:
110+
for name, shared_var in self.shared.items():
111+
shared_var.set_value(point[name], borrow=True)
112+
full_point = point
113+
point = {name: point[name] for name in self.var_names}
114+
115+
q = DictToArrayBijection.map(point)
113116
apoint, stats = self.astep(q)
114117

115118
if not isinstance(apoint, RaveledVars):
116119
# We assume that the mapping has stayed the same
117120
apoint = RaveledVars(apoint, q.point_map_info)
118121

119-
new_point = DictToArrayBijection.rmap(apoint, start_point=point)
120-
121-
return new_point, stats
122+
return DictToArrayBijection.rmap(apoint, start_point=full_point), stats
122123

123124
@abstractmethod
124125
def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]:

0 commit comments

Comments
 (0)