Skip to content

Commit 20e6e3b

Browse files
committed
[DO NOT MERGE WITHOUT SQUASHING with Tweaks to SymbolicRandomVariables] Get rid of magic
1 parent 6b43938 commit 20e6e3b

File tree

1 file changed

+41
-38
lines changed

1 file changed

+41
-38
lines changed

pymc/distributions/distribution.py

Lines changed: 41 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -272,21 +272,6 @@ def fn(*args, **kwargs):
272272
return fn
273273

274274

275-
class _class_or_instancemethod(classmethod):
276-
"""Allow a method to be called both as a classmethod and an instancemethod,
277-
giving priority to the instancemethod.
278-
279-
This is used to allow extracting information from the signature of a SymbolicRandomVariable
280-
which may be provided either as a class attribute or as an instance attribute.
281-
282-
Adapted from https://stackoverflow.com/a/28238047
283-
"""
284-
285-
def __get__(self, instance, type_):
286-
descr_get = super().__get__ if instance is None else self.__func__.__get__
287-
return descr_get(instance, type_)
288-
289-
290275
class SymbolicRandomVariable(OpFromGraph):
291276
"""Symbolic Random Variable
292277
@@ -343,37 +328,40 @@ def _parse_params_signature(signature):
343328
else:
344329
return _parse_gufunc_signature(params_signature)
345330

346-
@_class_or_instancemethod
347-
@property
348-
def ndims_params(cls_or_self) -> Sequence[int] | None:
331+
@staticmethod
332+
def _ndims_params_from_signature(signature: str) -> Sequence[int]:
349333
"""Number of core dimensions of the distribution's parameters."""
350-
signature = cls_or_self.signature
351-
if signature is None:
352-
return None
353-
inputs_signature, _ = cls_or_self._parse_params_signature(signature)
334+
inputs_signature, _ = SymbolicRandomVariable._parse_params_signature(signature)
354335
return [len(sig) for sig in inputs_signature]
355336

356-
@_class_or_instancemethod
337+
@classmethod
357338
@property
358-
def ndim_supp(cls_or_self) -> int | None:
359-
"""Number of support dimensions of the RandomVariable
360-
361-
(0 for scalar, 1 for vector, ...)
362-
"""
363-
signature = cls_or_self.signature
339+
def ndims_params(cls) -> Sequence[int] | None:
340+
signature = cls.signature
364341
if signature is None:
365342
return None
366-
_, outputs_params_signature = cls_or_self._parse_params_signature(signature)
343+
return cls._ndims_params_from_signature(signature)
344+
345+
@staticmethod
346+
def _ndim_supp_from_signature(signature: str) -> int:
347+
_, outputs_params_signature = SymbolicRandomVariable._parse_params_signature(signature)
367348
return max(len(out_sig) for out_sig in outputs_params_signature)
368349

369-
@_class_or_instancemethod
350+
@classmethod
370351
@property
371-
def default_output(cls_or_self) -> int | None:
372-
signature = cls_or_self.signature
352+
def ndim_supp(cls) -> int | None:
353+
"""Number of support dimensions of the RandomVariable
354+
355+
(0 for scalar, 1 for vector, ...)
356+
"""
357+
signature = cls.signature
373358
if signature is None:
374359
return None
360+
return cls._ndim_supp_from_signature(signature)
375361

376-
_, outputs_signature = cls_or_self._parse_signature(signature)
362+
@staticmethod
363+
def _default_output_from_signature(signature: str) -> int | None:
364+
_, outputs_signature = SymbolicRandomVariable._parse_signature(signature)
377365

378366
# If there is a single non `[rng]` outputs, that is the default one!
379367
candidate_default_output = [
@@ -384,6 +372,14 @@ def default_output(cls_or_self) -> int | None:
384372
else:
385373
return None
386374

375+
@classmethod
376+
@property
377+
def default_output(cls) -> int | None:
378+
signature = cls.signature
379+
if signature is None:
380+
return None
381+
return cls._default_output_from_signature(signature)
382+
387383
@staticmethod
388384
def get_idxs(signature: str) -> tuple[tuple[int], int | None, tuple[int]]:
389385
"""Parse signature and return indexes for *[rng], [size] and parameters"""
@@ -406,10 +402,17 @@ def __init__(
406402
**kwargs,
407403
):
408404
"""Initialize a SymbolicRandomVariable class."""
409-
if "signature" in kwargs:
410-
self.signature = kwargs.pop("signature")
411-
412-
if "ndim_supp" in kwargs:
405+
signature = kwargs.pop("signature", None)
406+
if signature is not None:
407+
self.signature = signature
408+
# Override class properties with instance properties
409+
self.ndims_params = self._ndims_params_from_signature(signature)
410+
self.ndim_supp = self._ndim_supp_from_signature(self.signature)
411+
self.default_output = self._default_output_from_signature(self.signature)
412+
413+
elif "ndim_supp" in kwargs:
414+
# For backwards compatibility we allow passing ndim_supp without signature
415+
# This is the only variable that PyMC absolutely needs to work with SymbolicRandomVariables
413416
self.ndim_supp = kwargs.pop("ndim_supp")
414417

415418
if self.ndim_supp is None:

0 commit comments

Comments
 (0)