We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 08f7a89 commit 4954fd4Copy full SHA for 4954fd4
pymc/aesaraf.py
@@ -572,7 +572,7 @@ def join_nonshared_inputs(
572
for var in vars:
573
shape = point[var.name].shape
574
arr_len = np.prod(shape, dtype=int)
575
- replace[var] = reshape_t(inarray[last_idx : last_idx + arr_len], shape).astype(var.dtype)
+ replace[var] = inarray[last_idx : last_idx + arr_len].reshape(shape).astype(var.dtype)
576
last_idx += arr_len
577
578
replace.update(shared)
@@ -581,14 +581,6 @@ def join_nonshared_inputs(
581
return xs_special, inarray
582
583
584
-def reshape_t(x, shape):
585
- """Work around fact that x.reshape(()) doesn't work"""
586
- if shape != ():
587
- return x.reshape(shape)
588
- else:
589
- return x[0]
590
-
591
592
class PointFunc:
593
"""Wraps so a function so it takes a dict of arguments instead of arguments."""
594
0 commit comments