Skip to content

Commit 4954fd4

Browse files
authored
Remove reshape_t (#6118)
* Remove reshape_t function * Remove reshape_t reference
1 parent 08f7a89 commit 4954fd4

File tree

1 file changed

+1
-9
lines changed

1 file changed

+1
-9
lines changed

pymc/aesaraf.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -572,7 +572,7 @@ def join_nonshared_inputs(
572572
for var in vars:
573573
shape = point[var.name].shape
574574
arr_len = np.prod(shape, dtype=int)
575-
replace[var] = reshape_t(inarray[last_idx : last_idx + arr_len], shape).astype(var.dtype)
575+
replace[var] = inarray[last_idx : last_idx + arr_len].reshape(shape).astype(var.dtype)
576576
last_idx += arr_len
577577

578578
replace.update(shared)
@@ -581,14 +581,6 @@ def join_nonshared_inputs(
581581
return xs_special, inarray
582582

583583

584-
def reshape_t(x, shape):
585-
"""Work around fact that x.reshape(()) doesn't work"""
586-
if shape != ():
587-
return x.reshape(shape)
588-
else:
589-
return x[0]
590-
591-
592584
class PointFunc:
593585
"""Wraps so a function so it takes a dict of arguments instead of arguments."""
594586

0 commit comments

Comments
 (0)