Skip to content

Commit 80f8195

Browse files
ricardoV94twiecki
authored andcommitted
Rename logp_transform to _get_default_transform and move it to transforms.py
1 parent 02860d3 commit 80f8195

File tree

5 files changed

+16
-15
lines changed

5 files changed

+16
-15
lines changed

pymc/distributions/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from pymc.distributions.logprob import ( # isort:skip
1616
logcdf,
1717
logp,
18-
logp_transform,
1918
joint_logpt,
2019
)
2120

@@ -195,6 +194,5 @@
195194
"PolyaGamma",
196195
"joint_logpt",
197196
"logp",
198-
"logp_transform",
199197
"logcdf",
200198
]

pymc/distributions/continuous.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def polyagamma_cdf(*args, **kwargs):
7373
from scipy.special import expit
7474

7575
from pymc.aesaraf import floatX
76-
from pymc.distributions import logp_transform, transforms
76+
from pymc.distributions import transforms
7777
from pymc.distributions.dist_math import (
7878
SplineWrapper,
7979
check_parameters,
@@ -87,6 +87,7 @@ def polyagamma_cdf(*args, **kwargs):
8787
)
8888
from pymc.distributions.distribution import DIST_PARAMETER_TYPES, Continuous
8989
from pymc.distributions.shape_utils import rv_size_is_none
90+
from pymc.distributions.transforms import _get_default_transform
9091
from pymc.math import invlogit, logdiffexp, logit
9192
from pymc.util import UNSET
9293

@@ -139,17 +140,17 @@ class CircularContinuous(Continuous):
139140
"""Base class for circular continuous distributions"""
140141

141142

142-
@logp_transform.register(PositiveContinuous)
143+
@_get_default_transform.register(PositiveContinuous)
143144
def pos_cont_transform(op):
144145
return transforms.log
145146

146147

147-
@logp_transform.register(UnitContinuous)
148+
@_get_default_transform.register(UnitContinuous)
148149
def unit_cont_transform(op):
149150
return transforms.logodds
150151

151152

152-
@logp_transform.register(CircularContinuous)
153+
@_get_default_transform.register(CircularContinuous)
153154
def circ_cont_transform(op):
154155
return transforms.circular
155156

pymc/distributions/logprob.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
# limitations under the License.
1414

1515
from collections.abc import Mapping
16-
from functools import singledispatch
1716
from typing import Dict, List, Optional, Sequence, Union
1817

1918
import aesara
@@ -25,7 +24,6 @@
2524
from aeppl.logprob import logprob as logp_aeppl
2625
from aeppl.transforms import TransformValuesOpt
2726
from aesara.graph.basic import graph_inputs, io_toposort
28-
from aesara.graph.op import Op
2927
from aesara.tensor.subtensor import (
3028
AdvancedIncSubtensor,
3129
AdvancedIncSubtensor1,
@@ -39,11 +37,6 @@
3937
from pymc.aesaraf import floatX
4038

4139

42-
@singledispatch
43-
def logp_transform(op: Op):
44-
return None
45-
46-
4740
def _get_scaling(total_size: Optional[Union[int, Sequence[int]]], shape, ndim: int):
4841
"""
4942
Gets scaling constant for logp.

pymc/distributions/transforms.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
from functools import singledispatch
1415

1516
import aesara.tensor as at
1617
import numpy as np
@@ -23,6 +24,7 @@
2324
RVTransform,
2425
Simplex,
2526
)
27+
from aesara.graph import Op
2628

2729
__all__ = [
2830
"RVTransform",
@@ -39,6 +41,12 @@
3941
]
4042

4143

44+
@singledispatch
45+
def _get_default_transform(op: Op):
46+
"""Return default transform for a given Distribution `Op`"""
47+
return None
48+
49+
4250
class LogExpM1(RVTransform):
4351
name = "log_exp_m1"
4452

pymc/model.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,9 @@
5858
)
5959
from pymc.blocking import DictToArrayBijection, RaveledVars
6060
from pymc.data import GenTensorVariable, Minibatch
61-
from pymc.distributions import joint_logpt, logp_transform
61+
from pymc.distributions import joint_logpt
6262
from pymc.distributions.logprob import _get_scaling
63+
from pymc.distributions.transforms import _get_default_transform
6364
from pymc.exceptions import ImputationWarning, SamplingError, ShapeError
6465
from pymc.initial_point import make_initial_point_fn
6566
from pymc.math import flatten_list
@@ -1421,7 +1422,7 @@ def create_value_var(
14211422
# Make the value variable a transformed value variable,
14221423
# if there's an applicable transform
14231424
if transform is UNSET and rv_var.owner:
1424-
transform = logp_transform(rv_var.owner.op)
1425+
transform = _get_default_transform(rv_var.owner.op)
14251426

14261427
if transform is not None and transform is not UNSET:
14271428
value_var.tag.transform = transform

0 commit comments

Comments
 (0)