Skip to content

Commit e6537e4

Browse files
michaelosthegericardoV94
authored andcommitted
Relocate PointFunc definition for nicer type hints
1 parent a4ace35 commit e6537e4

File tree

2 files changed

+15
-14
lines changed

2 files changed

+15
-14
lines changed

pymc/aesaraf.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -677,6 +677,16 @@ def reshape_t(x, shape):
677677
return x[0]
678678

679679

680+
class PointFunc:
681+
"""Wraps so a function so it takes a dict of arguments instead of arguments."""
682+
683+
def __init__(self, f):
684+
self.f = f
685+
686+
def __call__(self, state):
687+
return self.f(**state)
688+
689+
680690
class CallableTensor:
681691
"""Turns a symbolic variable with one input into a function that returns symbolic arguments
682692
with the one variable replaced with the input.

pymc/model.py

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
from aesara.tensor.var import TensorConstant, TensorVariable
4949

5050
from pymc.aesaraf import (
51+
PointFunc,
5152
compile_pymc,
5253
convert_observed_data,
5354
gradient,
@@ -640,7 +641,7 @@ def compile_logp(
640641
vars: Optional[Union[Variable, Sequence[Variable]]] = None,
641642
jacobian: bool = True,
642643
sum: bool = True,
643-
):
644+
) -> PointFunc:
644645
"""Compiled log probability density function.
645646
646647
Parameters
@@ -660,7 +661,7 @@ def compile_dlogp(
660661
self,
661662
vars: Optional[Union[Variable, Sequence[Variable]]] = None,
662663
jacobian: bool = True,
663-
):
664+
) -> PointFunc:
664665
"""Compiled log probability density gradient function.
665666
666667
Parameters
@@ -677,7 +678,7 @@ def compile_d2logp(
677678
self,
678679
vars: Optional[Union[Variable, Sequence[Variable]]] = None,
679680
jacobian: bool = True,
680-
):
681+
) -> PointFunc:
681682
"""Compiled log probability density hessian function.
682683
683684
Parameters
@@ -1597,7 +1598,7 @@ def compile_fn(
15971598
mode=None,
15981599
point_fn: bool = True,
15991600
**kwargs,
1600-
) -> Union["PointFunc", Callable[[Sequence[np.ndarray]], Sequence[np.ndarray]]]:
1601+
) -> Union[PointFunc, Callable[[Sequence[np.ndarray]], Sequence[np.ndarray]]]:
16011602
"""Compiles an Aesara function
16021603
16031604
Parameters
@@ -1913,16 +1914,6 @@ def Point(*args, filter_model_vars=False, **kwargs) -> Dict[str, np.ndarray]:
19131914
}
19141915

19151916

1916-
class PointFunc:
1917-
"""Wraps so a function so it takes a dict of arguments instead of arguments."""
1918-
1919-
def __init__(self, f):
1920-
self.f = f
1921-
1922-
def __call__(self, state):
1923-
return self.f(**state)
1924-
1925-
19261917
def Deterministic(name, var, model=None, dims=None, auto=False):
19271918
"""Create a named deterministic variable
19281919

0 commit comments

Comments
 (0)