Skip to content

Commit 0f3b04e

Browse files
Finish refactoring BlockedStep.competence implementations
1 parent 3b401c1 commit 0f3b04e

File tree

3 files changed

+48
-19
lines changed

3 files changed

+48
-19
lines changed

pymc3/sampling.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -201,13 +201,14 @@ def assign_step_methods(model, step=None, methods=STEP_METHODS, step_kwargs=None
201201
has_gradient = var.dtype not in discrete_types
202202
if has_gradient:
203203
try:
204-
tg.grad(model.logpt, var.tag.value_var)
205-
except (AttributeError, NotImplementedError, tg.NullTypeGradError):
204+
tg.grad(model.logpt, var)
205+
except (NotImplementedError, tg.NullTypeGradError):
206206
has_gradient = False
207207
# select the best method
208+
rv_var = model.values_to_rvs[var]
208209
selected = max(
209210
methods,
210-
key=lambda method, var=var, has_gradient=has_gradient: method._competence(
211+
key=lambda method, var=rv_var, has_gradient=has_gradient: method._competence(
211212
var, has_gradient
212213
),
213214
)

pymc3/step_methods/metropolis.py

Lines changed: 43 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@
1414
from typing import Any, Dict, List, Tuple
1515

1616
import aesara
17-
import aesara.tensor as at
1817
import numpy as np
1918
import numpy.random as nr
2019
import scipy.linalg
2120

22-
from aesara.tensor.random.basic import CategoricalRV
21+
from aesara.graph.fg import MissingInputError
22+
from aesara.tensor.random.basic import BernoulliRV, CategoricalRV
2323

2424
import pymc3 as pm
2525

@@ -362,13 +362,22 @@ def competence(var):
362362
and Categorical variables with k=1.
363363
"""
364364
distribution = getattr(var.owner, "op", None)
365-
if isinstance(distribution, pm.Bernoulli) or (var.dtype in pm.bool_types):
366-
return Competence.IDEAL
365+
366+
if isinstance(distribution, BernoulliRV):
367+
return Competence.COMPATIBLE
367368

368369
if isinstance(distribution, CategoricalRV):
369-
k = at.get_scalar_constant_value(distribution.owner.inputs[2])
370-
if k == 2:
371-
return Competence.IDEAL
370+
# TODO: We could compute the initial value of `k`
371+
# if we had a model object.
372+
# k_graph = var.owner.inputs[3].shape[-1]
373+
# (k_graph,), _ = rvs_to_value_vars((k_graph,), apply_transforms=True)
374+
# k = model.fn(k_graph)(initial_point)
375+
try:
376+
k = var.owner.inputs[3].shape[-1].eval()
377+
if k == 2:
378+
return Competence.COMPATIBLE
379+
except MissingInputError:
380+
pass
372381
return Competence.INCOMPATIBLE
373382

374383

@@ -449,13 +458,22 @@ def competence(var):
449458
and Categorical variables with k=2.
450459
"""
451460
distribution = getattr(var.owner, "op", None)
452-
if isinstance(distribution, pm.Bernoulli) or (var.dtype in pm.bool_types):
461+
462+
if isinstance(distribution, BernoulliRV):
453463
return Competence.IDEAL
454464

455465
if isinstance(distribution, CategoricalRV):
456-
k = at.get_scalar_constant_value(distribution.owner.inputs[2])
457-
if k == 2:
458-
return Competence.IDEAL
466+
# TODO: We could compute the initial value of `k`
467+
# if we had a model object.
468+
# k_graph = var.owner.inputs[3].shape[-1]
469+
# (k_graph,), _ = rvs_to_value_vars((k_graph,), apply_transforms=True)
470+
# k = model.fn(k_graph)(initial_point)
471+
try:
472+
k = var.owner.inputs[3].shape[-1].eval()
473+
if k == 2:
474+
return Competence.IDEAL
475+
except MissingInputError:
476+
pass
459477
return Competence.INCOMPATIBLE
460478

461479

@@ -585,13 +603,23 @@ def competence(var):
585603
Categorical variables.
586604
"""
587605
distribution = getattr(var.owner, "op", None)
606+
588607
if isinstance(distribution, CategoricalRV):
589-
k = at.get_scalar_constant_value(distribution.owner.inputs[2])
590-
if k == 2:
591-
return Competence.IDEAL
608+
# TODO: We could compute the initial value of `k`
609+
# if we had a model object.
610+
# k_graph = var.owner.inputs[3].shape[-1]
611+
# (k_graph,), _ = rvs_to_value_vars((k_graph,), apply_transforms=True)
612+
# k = model.fn(k_graph)(initial_point)
613+
try:
614+
k = var.owner.inputs[3].shape[-1].eval()
615+
if k > 2:
616+
return Competence.IDEAL
617+
except MissingInputError:
618+
pass
619+
592620
return Competence.COMPATIBLE
593621

594-
if isinstance(distribution, pm.Bernoulli) or (var.dtype in pm.bool_types):
622+
if isinstance(distribution, BernoulliRV):
595623
return Competence.COMPATIBLE
596624

597625
return Competence.INCOMPATIBLE

pymc3/step_methods/slicer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def astep(self, q0, logp):
119119
@staticmethod
120120
def competence(var, has_grad):
121121
if var.dtype in continuous_types:
122-
if not has_grad and (var.shape is None or var.shape.ndim == 1):
122+
if not has_grad and var.ndim == 0:
123123
return Competence.PREFERRED
124124
return Competence.COMPATIBLE
125125
return Competence.INCOMPATIBLE

0 commit comments

Comments
 (0)