Skip to content

Commit d9f0775

Browse files
committed
Rename transform_args to bound_args_indices
1 parent fa4bdb8 commit d9f0775

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

pymc3/distributions/continuous.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,8 @@ def circ_cont_transform(op):
134134
class BoundedContinuous(Continuous):
135135
"""Base class for bounded continuous distributions"""
136136

137-
transform_args = None
137+
# Indices of the arguments that define the lower and upper bounds of the distribution
138+
bound_args_indices = None
138139

139140
def __new__(cls, *args, **kwargs):
140141
transform = kwargs.get("transform", UNSET)
@@ -144,13 +145,15 @@ def __new__(cls, *args, **kwargs):
144145

145146
@classmethod
146147
def default_transform(cls):
147-
if cls.transform_args is None:
148-
raise ValueError(f"Must specify transform args for {cls.__name__} bounded distribution")
148+
if cls.bound_args_indices is None:
149+
raise ValueError(
150+
f"Must specify bound_args_indices for {cls.__name__} bounded distribution"
151+
)
149152

150153
def transform_params(rv_var):
151154
_, _, _, *args = rv_var.owner.inputs
152-
lower = args[cls.transform_args[0]]
153-
upper = args[cls.transform_args[1]]
155+
lower = args[cls.bound_args_indices[0]]
156+
upper = args[cls.bound_args_indices[1]]
154157
lower = at.as_tensor_variable(lower) if lower is not None else None
155158
upper = at.as_tensor_variable(upper) if upper is not None else None
156159
return lower, upper
@@ -244,7 +247,7 @@ class Uniform(BoundedContinuous):
244247
Upper limit.
245248
"""
246249
rv_op = uniform
247-
transform_args = [0, 1] # Lower, Upper
250+
bound_args_indices = (0, 1) # Lower, Upper
248251

249252
@classmethod
250253
def dist(cls, lower=0, upper=1, **kwargs):
@@ -3339,7 +3342,7 @@ class Triangular(BoundedContinuous):
33393342
"""
33403343

33413344
rv_op = triangular
3342-
transform_args = [0, 2] # lower, upper
3345+
bound_args_indices = (0, 2) # lower, upper
33433346

33443347
@classmethod
33453348
def dist(cls, lower=0, upper=1, c=0.5, *args, **kwargs):

0 commit comments

Comments
 (0)