Skip to content

Commit e76bba9

Browse files
committed
Fix bug in transform_scan_values
1 parent 534a9ae commit e76bba9

File tree

2 files changed

+7
-2
lines changed

2 files changed

+7
-2
lines changed

pymc/logprob/transforms.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ def transform_scan_values(fgraph: FunctionGraph, node: Node) -> Optional[List[No
237237
return None
238238

239239
transforms = [
240-
values_to_transforms.get(rv_map_feature.original_values[value], None)
240+
values_to_transforms.get(rv_map_feature.original_values[value_var], None)
241241
for value_var in value_vars
242242
]
243243

tests/logprob/test_transforms.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -914,8 +914,13 @@ def test_scan_transform():
914914
init = at.random.beta(1, 1, name="init")
915915
init_vv = init.clone()
916916

917+
def scan_step(prev_innov):
918+
next_innov = at.random.beta(prev_innov * 10, (1 - prev_innov) * 10)
919+
update = {next_innov.owner.inputs[0]: next_innov.owner.outputs[0]}
920+
return next_innov, update
921+
917922
innov, _ = scan(
918-
fn=lambda prev_innov: at.random.beta(prev_innov * 10, (1 - prev_innov) * 10),
923+
fn=scan_step,
919924
outputs_info=[init],
920925
n_steps=4,
921926
)

0 commit comments

Comments
 (0)