Skip to content

Commit 05e39e1

Browse files
ricardoV94twiecki
authored andcommitted
Remove internal uses of patternbroadcast and unbroadcast
1 parent e25e42d commit 05e39e1

File tree

3 files changed

+15
-8
lines changed

3 files changed

+15
-8
lines changed

pymc/data.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,7 @@ def __init__(
310310
batch_size=128,
311311
dtype=None,
312312
broadcastable=None,
313+
shape=None,
313314
name="Minibatch",
314315
random_seed=42,
315316
update_shared_f=None,
@@ -324,9 +325,15 @@ def __init__(
324325
self.update_shared_f = update_shared_f
325326
self.random_slc = self.make_random_slices(self.shared.shape, batch_size, random_seed)
326327
minibatch = self.shared[self.random_slc]
327-
if broadcastable is None:
328-
broadcastable = (False,) * minibatch.ndim
329-
minibatch = at.patternbroadcast(minibatch, broadcastable)
328+
if broadcastable is not None:
329+
warnings.warn(
330+
"Minibatch `broadcastable` argument is deprecated. Use `shape` instead",
331+
FutureWarning,
332+
)
333+
assert shape is None
334+
shape = [1 if b else None for b in broadcastable]
335+
if shape is not None:
336+
minibatch = at.specify_shape(minibatch, shape)
330337
self.minibatch = minibatch
331338
super().__init__(self.minibatch.type, None, None, name=name)
332339
Apply(aesara.compile.view_op, inputs=[self.minibatch], outputs=[self])

pymc/ode/ode.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -158,11 +158,11 @@ def __call__(self, y0, theta, return_sens=False, **kwargs):
158158
)
159159

160160
# convert inputs to tensors (and check their types)
161-
y0 = at.cast(at.unbroadcast(at.as_tensor_variable(y0), 0), floatX)
162-
theta = at.cast(at.unbroadcast(at.as_tensor_variable(theta), 0), floatX)
161+
y0 = at.cast(at.as_tensor_variable(y0), floatX)
162+
theta = at.cast(at.as_tensor_variable(theta), floatX)
163163
inputs = [y0, theta]
164164
for i, (input_val, itype) in enumerate(zip(inputs, self._itypes)):
165-
if not input_val.type.in_same_class(itype):
165+
if not itype.is_super(input_val.type):
166166
raise ValueError(
167167
f"Input {i} of type {input_val.type} does not have the expected type of {itype}"
168168
)

pymc/variational/opvi.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1030,7 +1030,7 @@ def symbolic_sample_over_posterior(self, node):
10301030
"""
10311031
node = self.to_flat_input(node)
10321032
random = self.symbolic_random.astype(self.symbolic_initial.dtype)
1033-
random = at.patternbroadcast(random, self.symbolic_initial.broadcastable)
1033+
random = at.specify_shape(random, self.symbolic_initial.type.shape)
10341034

10351035
def sample(post, node):
10361036
return aesara.clone_replace(node, {self.input: post})
@@ -1065,7 +1065,7 @@ def make_size_and_deterministic_replacements(self, s, d, more_replacements=None)
10651065
dict with replacements for initial
10661066
"""
10671067
initial = self._new_initial(s, d, more_replacements)
1068-
initial = at.patternbroadcast(initial, self.symbolic_initial.broadcastable)
1068+
initial = at.specify_shape(initial, self.symbolic_initial.type.shape)
10691069
if more_replacements:
10701070
initial = aesara.clone_replace(initial, more_replacements)
10711071
return {self.symbolic_initial: initial}

0 commit comments

Comments
 (0)