Commit 36b9e1e

Use static shape in join_nonshared_inputs

Parent: 19389fc

1 file changed: 8 additions, 8 deletions

pymc/pytensorf.py

Lines changed: 8 additions & 8 deletions

@@ -563,24 +563,24 @@ def join_nonshared_inputs(
         raise ValueError("Empty list of input variables.")
 
     raveled_inputs = pt.concatenate([var.ravel() for var in inputs])
+    input_sizes = [point[var_name].size for var_name in point]
+    size = sum(input_sizes)
 
     if not make_inputs_shared:
-        tensor_type = raveled_inputs.type
-        joined_inputs = tensor_type("joined_inputs")
+        joined_inputs = pt.tensor("joined_inputs", shape=(size,), dtype=raveled_inputs.dtype)
     else:
         joined_values = np.concatenate([point[var.name].ravel() for var in inputs])
-        joined_inputs = pytensor.shared(joined_values, "joined_inputs")
+        joined_inputs = pytensor.shared(joined_values, "joined_inputs", shape=(size,))
 
     if pytensor.config.compute_test_value != "off":
         joined_inputs.tag.test_value = raveled_inputs.tag.test_value
 
     replace: dict[TensorVariable, TensorVariable] = {}
-    last_idx = 0
-    for var in inputs:
+    for var, flat_var in zip(
+        inputs, pt.split(joined_inputs, input_sizes, len(inputs)), strict=True
+    ):
         shape = point[var.name].shape
-        arr_len = np.prod(shape, dtype=int)
-        replace[var] = joined_inputs[last_idx : last_idx + arr_len].reshape(shape).astype(var.dtype)
-        last_idx += arr_len
+        replace[var] = flat_var.reshape(shape).astype(var.dtype)
 
     if shared_inputs is not None:
        replace.update(shared_inputs)
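
For reference, here is a minimal usage sketch (not part of the commit) showing how the static shape surfaces for a caller of join_nonshared_inputs. The variables x and y, the cost expression, and the values in point are made-up examples; only the function and its point/outputs/inputs arguments come from pymc/pytensorf.py.

# Illustrative sketch only; x, y, cost, and the point values are hypothetical.
import numpy as np
import pytensor.tensor as pt

from pymc.pytensorf import join_nonshared_inputs

x = pt.vector("x")
y = pt.matrix("y")
cost = (x.sum() + y.sum()) ** 2

# The point dict supplies concrete values, so the flattened size (3 + 4 = 7)
# is known up front and can be baked into the joined input's static shape.
point = {"x": np.zeros(3), "y": np.zeros((2, 2))}

[new_cost], joined_inputs = join_nonshared_inputs(
    point=point, outputs=[cost], inputs=[x, y]
)

# After this commit the joined vector has a static length rather than an
# unknown one, which downstream graph rewrites can exploit.
print(joined_inputs.type.shape)  # expected: (7,)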
