|
14 | 14 | from typing import Any, Dict, List, Tuple
|
15 | 15 |
|
16 | 16 | import aesara
|
17 |
| -import aesara.tensor as at |
18 | 17 | import numpy as np
|
19 | 18 | import numpy.random as nr
|
20 | 19 | import scipy.linalg
|
21 | 20 |
|
22 |
| -from aesara.tensor.random.basic import CategoricalRV |
| 21 | +from aesara.graph.fg import MissingInputError |
| 22 | +from aesara.tensor.random.basic import BernoulliRV, CategoricalRV |
23 | 23 |
|
24 | 24 | import pymc3 as pm
|
25 | 25 |
|
@@ -362,13 +362,22 @@ def competence(var):
|
362 | 362 | and Categorical variables with k=1.
|
363 | 363 | """
|
364 | 364 | distribution = getattr(var.owner, "op", None)
|
365 |
| - if isinstance(distribution, pm.Bernoulli) or (var.dtype in pm.bool_types): |
366 |
| - return Competence.IDEAL |
| 365 | + |
| 366 | + if isinstance(distribution, BernoulliRV): |
| 367 | + return Competence.COMPATIBLE |
367 | 368 |
|
368 | 369 | if isinstance(distribution, CategoricalRV):
|
369 |
| - k = at.get_scalar_constant_value(distribution.owner.inputs[2]) |
370 |
| - if k == 2: |
371 |
| - return Competence.IDEAL |
| 370 | + # TODO: We could compute the initial value of `k` |
| 371 | + # if we had a model object. |
| 372 | + # k_graph = var.owner.inputs[3].shape[-1] |
| 373 | + # (k_graph,), _ = rvs_to_value_vars((k_graph,), apply_transforms=True) |
| 374 | + # k = model.fn(k_graph)(initial_point) |
| 375 | + try: |
| 376 | + k = var.owner.inputs[3].shape[-1].eval() |
| 377 | + if k == 2: |
| 378 | + return Competence.COMPATIBLE |
| 379 | + except MissingInputError: |
| 380 | + pass |
372 | 381 | return Competence.INCOMPATIBLE
|
373 | 382 |
|
374 | 383 |
|
@@ -449,13 +458,22 @@ def competence(var):
|
449 | 458 | and Categorical variables with k=2.
|
450 | 459 | """
|
451 | 460 | distribution = getattr(var.owner, "op", None)
|
452 |
| - if isinstance(distribution, pm.Bernoulli) or (var.dtype in pm.bool_types): |
| 461 | + |
| 462 | + if isinstance(distribution, BernoulliRV): |
453 | 463 | return Competence.IDEAL
|
454 | 464 |
|
455 | 465 | if isinstance(distribution, CategoricalRV):
|
456 |
| - k = at.get_scalar_constant_value(distribution.owner.inputs[2]) |
457 |
| - if k == 2: |
458 |
| - return Competence.IDEAL |
| 466 | + # TODO: We could compute the initial value of `k` |
| 467 | + # if we had a model object. |
| 468 | + # k_graph = var.owner.inputs[3].shape[-1] |
| 469 | + # (k_graph,), _ = rvs_to_value_vars((k_graph,), apply_transforms=True) |
| 470 | + # k = model.fn(k_graph)(initial_point) |
| 471 | + try: |
| 472 | + k = var.owner.inputs[3].shape[-1].eval() |
| 473 | + if k == 2: |
| 474 | + return Competence.IDEAL |
| 475 | + except MissingInputError: |
| 476 | + pass |
459 | 477 | return Competence.INCOMPATIBLE
|
460 | 478 |
|
461 | 479 |
|
@@ -585,13 +603,23 @@ def competence(var):
|
585 | 603 | Categorical variables.
|
586 | 604 | """
|
587 | 605 | distribution = getattr(var.owner, "op", None)
|
| 606 | + |
588 | 607 | if isinstance(distribution, CategoricalRV):
|
589 |
| - k = at.get_scalar_constant_value(distribution.owner.inputs[2]) |
590 |
| - if k == 2: |
591 |
| - return Competence.IDEAL |
| 608 | + # TODO: We could compute the initial value of `k` |
| 609 | + # if we had a model object. |
| 610 | + # k_graph = var.owner.inputs[3].shape[-1] |
| 611 | + # (k_graph,), _ = rvs_to_value_vars((k_graph,), apply_transforms=True) |
| 612 | + # k = model.fn(k_graph)(initial_point) |
| 613 | + try: |
| 614 | + k = var.owner.inputs[3].shape[-1].eval() |
| 615 | + if k > 2: |
| 616 | + return Competence.IDEAL |
| 617 | + except MissingInputError: |
| 618 | + pass |
| 619 | + |
592 | 620 | return Competence.COMPATIBLE
|
593 | 621 |
|
594 |
| - if isinstance(distribution, pm.Bernoulli) or (var.dtype in pm.bool_types): |
| 622 | + if isinstance(distribution, BernoulliRV): |
595 | 623 | return Competence.COMPATIBLE
|
596 | 624 |
|
597 | 625 | return Competence.INCOMPATIBLE
|
|
0 commit comments