Skip to content

Commit 4d0a421

Browse files
ricardoV94twiecki
authored andcommitted
Fix _check_start_shape
1 parent fb8d38b commit 4d0a421

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

pymc3/sampling.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -684,6 +684,11 @@ def sample(
684684
def _check_start_shape(model, start):
685685
if not isinstance(start, dict):
686686
raise TypeError("start argument must be a dict or an array-like of dicts")
687+
688+
# Filter "non-input" variables
689+
initial_point = model.initial_point
690+
start = {k: v for k, v in start.items() if k in initial_point}
691+
687692
e = ""
688693
for var in model.basic_RVs:
689694
var_shape = model.fastfn(var.shape)(start)

0 commit comments

Comments
 (0)