From 10cc350bf819dc2c69d59d0bc66f370a11e07e4c Mon Sep 17 00:00:00 2001 From: deveshbervar Date: Sat, 1 Mar 2025 21:37:53 +0530 Subject: [PATCH] Add type hints to discrete distributions --- pymc/distributions/discrete.py | 58 +++++++++++++++++++++++++----- pymc/distributions/distribution.py | 2 +- 2 files changed, 50 insertions(+), 10 deletions(-) diff --git a/pymc/distributions/discrete.py b/pymc/distributions/discrete.py index d2f35c8007..ec88c19ee5 100644 --- a/pymc/distributions/discrete.py +++ b/pymc/distributions/discrete.py @@ -13,6 +13,8 @@ # limitations under the License. import warnings +from typing import Optional + import numpy as np import pytensor.tensor as pt @@ -118,7 +120,14 @@ class Binomial(Discrete): rv_op = binomial @classmethod - def dist(cls, n, p=None, logit_p=None, *args, **kwargs): + def dist( + cls, + n: DIST_PARAMETER_TYPES, + p: Optional[DIST_PARAMETER_TYPES] = None, + logit_p: Optional[DIST_PARAMETER_TYPES] = None, + *args, + **kwargs, + ): if p is not None and logit_p is not None: raise ValueError("Incompatible parametrization. Can't specify both p and logit_p.") elif p is None and logit_p is None: @@ -234,7 +243,14 @@ def BetaBinom(a, b, n, x): rv_op = betabinom @classmethod - def dist(cls, alpha, beta, n, *args, **kwargs): + def dist( + cls, + alpha: DIST_PARAMETER_TYPES, + beta: DIST_PARAMETER_TYPES, + n: DIST_PARAMETER_TYPES, + *args, + **kwargs, + ): alpha = pt.as_tensor_variable(alpha) beta = pt.as_tensor_variable(beta) n = pt.as_tensor_variable(n, dtype=int) @@ -341,7 +357,13 @@ class Bernoulli(Discrete): rv_op = bernoulli @classmethod - def dist(cls, p=None, logit_p=None, *args, **kwargs): + def dist( + cls, + p: Optional[DIST_PARAMETER_TYPES] = None, + logit_p: Optional[DIST_PARAMETER_TYPES] = None, + *args, + **kwargs, + ): if p is not None and logit_p is not None: raise ValueError("Incompatible parametrization. Can't specify both p and logit_p.") elif p is None and logit_p is None: @@ -465,7 +487,8 @@ def DiscreteWeibull(q, b, x): rv_op = DiscreteWeibullRV.rv_op @classmethod - def dist(cls, q, beta, *args, **kwargs): + def dist(cls, q: DIST_PARAMETER_TYPES, beta: DIST_PARAMETER_TYPES, *args, **kwargs): + return super().dist([q, beta], **kwargs) def support_point(rv, size, q, beta): @@ -553,7 +576,8 @@ class Poisson(Discrete): rv_op = poisson @classmethod - def dist(cls, mu, *args, **kwargs): + def dist(cls, mu: DIST_PARAMETER_TYPES, *args, **kwargs): + mu = pt.as_tensor_variable(mu) return super().dist([mu], *args, **kwargs) @@ -677,7 +701,16 @@ def NegBinom(a, m, x): rv_op = nbinom @classmethod - def dist(cls, mu=None, alpha=None, p=None, n=None, *args, **kwargs): + def dist( + cls, + mu: Optional[DIST_PARAMETER_TYPES] = None, + alpha: Optional[DIST_PARAMETER_TYPES] = None, + p: Optional[DIST_PARAMETER_TYPES] = None, + n: Optional[DIST_PARAMETER_TYPES] = None, + *args, + **kwargs, + ): + n, p = cls.get_n_p(mu=mu, alpha=alpha, p=p, n=n) n = pt.as_tensor_variable(n) p = pt.as_tensor_variable(p) @@ -790,7 +823,8 @@ class Geometric(Discrete): rv_op = geometric @classmethod - def dist(cls, p, *args, **kwargs): + def dist(cls, p: DIST_PARAMETER_TYPES, *args, **kwargs): + p = pt.as_tensor_variable(p) return super().dist([p], *args, **kwargs) @@ -1027,7 +1061,8 @@ class DiscreteUniform(Discrete): rv_op = discrete_uniform @classmethod - def dist(cls, lower, upper, *args, **kwargs): + def dist(cls, lower: DIST_PARAMETER_TYPES, upper: DIST_PARAMETER_TYPES, *args, **kwargs): + lower = pt.floor(lower) upper = pt.floor(upper) return super().dist([lower, upper], **kwargs) @@ -1123,7 +1158,12 @@ class Categorical(Discrete): rv_op = categorical @classmethod - def dist(cls, p=None, logit_p=None, **kwargs): + def dist( + cls, + p: Optional[DIST_PARAMETER_TYPES] = None, + logit_p: Optional[DIST_PARAMETER_TYPES] = None, + **kwargs, + ): if p is not None and logit_p is not None: raise ValueError("Incompatible parametrization. Can't specify both p and logit_p.") elif p is None and logit_p is None: diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index 5ec5df4671..09d399bea6 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -23,7 +23,7 @@ from functools import singledispatch from typing import Any, TypeAlias -import numpy as np +import numpy as np # type: ignore from pytensor import tensor as pt from pytensor.compile.builders import OpFromGraph