Skip to content

Commit 1d928af

Browse files
committed
Bump PyMC dependency
1 parent ae7b80f commit 1d928af

File tree

5 files changed

+26
-35
lines changed

5 files changed

+26
-35
lines changed

conda-envs/environment-test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,5 @@ dependencies:
99
- dask
1010
- xhistogram
1111
- pip:
12-
- pymc>=5.2.0 # CI was failing to resolve
12+
- pymc>=5.4.1 # CI was failing to resolve
1313
- blackjax

conda-envs/windows-environment-test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,4 @@ dependencies:
99
- dask
1010
- xhistogram
1111
- pip:
12-
- pymc>=5.2.0 # CI was failing to resolve
12+
- pymc>=5.4.1 # CI was failing to resolve

pymc_experimental/marginal_model.py

Lines changed: 22 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -395,44 +395,38 @@ def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> Tuple[int, ...]:
395395
def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
396396
# Clone the inner RV graph of the Marginalized RV
397397
marginalized_rvs_node = op.make_node(*inputs)
398-
marginalized_rv, *dependent_rvs = clone_replace(
398+
inner_rvs = clone_replace(
399399
op.inner_outputs,
400400
replace={u: v for u, v in zip(op.inner_inputs, marginalized_rvs_node.inputs)},
401401
)
402+
marginalized_rv = inner_rvs[0]
402403

403404
# Obtain the joint_logp graph of the inner RV graph
404-
# Some inputs are not root inputs (such as transformed projections of value variables)
405-
# Or cannot be used as inputs to an OpFromGraph (shared variables and constants)
406-
inputs = list(inputvars(inputs))
407-
rvs_to_values = {}
408-
dummy_marginalized_value = marginalized_rv.clone()
409-
rvs_to_values[marginalized_rv] = dummy_marginalized_value
410-
rvs_to_values.update(zip(dependent_rvs, values))
411-
logps_dict = factorized_joint_logprob(rv_values=rvs_to_values, **kwargs)
405+
inner_rvs_to_values = {rv: rv.clone() for rv in inner_rvs}
406+
logps_dict = factorized_joint_logprob(rv_values=inner_rvs_to_values, **kwargs)
412407

413408
# Reduce logp dimensions corresponding to broadcasted variables
414-
values_axis_bcast = []
415-
for value in values:
416-
vbcast = value.type.broadcastable
417-
mbcast = dummy_marginalized_value.type.broadcastable
409+
joint_logp = logps_dict[inner_rvs_to_values[marginalized_rv]]
410+
for inner_rv, inner_value in inner_rvs_to_values.items():
411+
if inner_rv is marginalized_rv:
412+
continue
413+
vbcast = inner_value.type.broadcastable
414+
mbcast = marginalized_rv.type.broadcastable
418415
mbcast = (True,) * (len(vbcast) - len(mbcast)) + mbcast
419-
values_axis_bcast.append([i for i, (m, v) in enumerate(zip(mbcast, vbcast)) if m != v])
420-
joint_logp = logps_dict[dummy_marginalized_value]
421-
for value, values_axis_bcast in zip(values, values_axis_bcast):
422-
joint_logp += logps_dict[value].sum(values_axis_bcast, keepdims=True)
416+
values_axis_bcast = [i for i, (m, v) in enumerate(zip(mbcast, vbcast)) if m != v]
417+
joint_logp += logps_dict[inner_value].sum(values_axis_bcast, keepdims=True)
423418

424419
# Wrap the joint_logp graph in an OpFromGrah, so that we can evaluate it at different
425420
# values of the marginalized RV
426-
# OpFromGraph does not accept constant inputs
427-
non_const_values = [
428-
value
429-
for value in rvs_to_values.values()
430-
if not isinstance(value, (Constant, SharedVariable))
431-
]
432-
joint_logp_op = OpFromGraph([*non_const_values, *inputs], [joint_logp], inline=True)
421+
# Some inputs are not root inputs (such as transformed projections of value variables)
422+
# Or cannot be used as inputs to an OpFromGraph (shared variables and constants)
423+
inputs = list(inputvars(inputs))
424+
joint_logp_op = OpFromGraph(
425+
list(inner_rvs_to_values.values()) + inputs, [joint_logp], inline=True
426+
)
433427

434428
# Compute the joint_logp for all possible n values of the marginalized RV. We assume
435-
# each original dimension is independent so that it sufficies to evaluate the graph
429+
# each original dimension is independent so that it suffices to evaluate the graph
436430
# n times, once with each possible value of the marginalized RV replicated across
437431
# batched dimensions of the marginalized RV
438432

@@ -449,18 +443,14 @@ def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
449443
axis2=-1,
450444
)
451445

452-
# OpFromGraph does not accept constant inputs
453-
non_const_values = [
454-
value for value in values if not isinstance(value, (Constant, SharedVariable))
455-
]
456446
# Arbitrary cutoff to switch to Scan implementation to keep graph size under control
457447
if len(marginalized_rv_domain) <= 10:
458448
joint_logps = [
459-
joint_logp_op(marginalized_rv_domain_tensor[i], *non_const_values, *inputs)
449+
joint_logp_op(marginalized_rv_domain_tensor[i], *values, *inputs)
460450
for i in range(len(marginalized_rv_domain))
461451
]
462452
else:
463-
# Make sure this is rewrite is registered
453+
# Make sure this rewrite is registered
464454
from pymc.pytensorf import local_remove_check_parameter
465455

466456
def logp_fn(marginalized_rv_const, *non_sequences):
@@ -469,7 +459,7 @@ def logp_fn(marginalized_rv_const, *non_sequences):
469459
joint_logps, _ = scan_map(
470460
fn=logp_fn,
471461
sequences=marginalized_rv_domain_tensor,
472-
non_sequences=[*non_const_values, *inputs],
462+
non_sequences=[*values, *inputs],
473463
mode=Mode().including("local_remove_check_parameter"),
474464
)
475465

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
pymc>=5.2.0
1+
pymc>=5.4.1

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
"Programming Language :: Python :: 3.8",
3333
"Programming Language :: Python :: 3.9",
3434
"Programming Language :: Python :: 3.10",
35+
"Programming Language :: Python :: 3.11",
3536
"License :: OSI Approved :: Apache Software License",
3637
"Intended Audience :: Science/Research",
3738
"Topic :: Scientific/Engineering",

0 commit comments

Comments
 (0)