@@ -563,24 +563,24 @@ def join_nonshared_inputs(
563
563
raise ValueError ("Empty list of input variables." )
564
564
565
565
raveled_inputs = pt .concatenate ([var .ravel () for var in inputs ])
566
+ input_sizes = [point [var_name ].size for var_name in point ]
567
+ size = sum (input_sizes )
566
568
567
569
if not make_inputs_shared :
568
- tensor_type = raveled_inputs .type
569
- joined_inputs = tensor_type ("joined_inputs" )
570
+ joined_inputs = pt .tensor ("joined_inputs" , shape = (size ,), dtype = raveled_inputs .dtype )
570
571
else :
571
572
joined_values = np .concatenate ([point [var .name ].ravel () for var in inputs ])
572
- joined_inputs = pytensor .shared (joined_values , "joined_inputs" )
573
+ joined_inputs = pytensor .shared (joined_values , "joined_inputs" , shape = ( size ,) )
573
574
574
575
if pytensor .config .compute_test_value != "off" :
575
576
joined_inputs .tag .test_value = raveled_inputs .tag .test_value
576
577
577
578
replace : dict [TensorVariable , TensorVariable ] = {}
578
- last_idx = 0
579
- for var in inputs :
579
+ for var , flat_var in zip (
580
+ inputs , pt .split (joined_inputs , input_sizes , len (inputs )), strict = True
581
+ ):
580
582
shape = point [var .name ].shape
581
- arr_len = np .prod (shape , dtype = int )
582
- replace [var ] = joined_inputs [last_idx : last_idx + arr_len ].reshape (shape ).astype (var .dtype )
583
- last_idx += arr_len
583
+ replace [var ] = flat_var .reshape (shape ).astype (var .dtype )
584
584
585
585
if shared_inputs is not None :
586
586
replace .update (shared_inputs )
0 commit comments