48
48
from aesara .tensor .var import TensorConstant , TensorVariable
49
49
50
50
from pymc .aesaraf import (
51
+ PointFunc ,
51
52
compile_pymc ,
52
53
convert_observed_data ,
53
54
gradient ,
@@ -640,7 +641,7 @@ def compile_logp(
640
641
vars : Optional [Union [Variable , Sequence [Variable ]]] = None ,
641
642
jacobian : bool = True ,
642
643
sum : bool = True ,
643
- ):
644
+ ) -> PointFunc :
644
645
"""Compiled log probability density function.
645
646
646
647
Parameters
@@ -660,7 +661,7 @@ def compile_dlogp(
660
661
self ,
661
662
vars : Optional [Union [Variable , Sequence [Variable ]]] = None ,
662
663
jacobian : bool = True ,
663
- ):
664
+ ) -> PointFunc :
664
665
"""Compiled log probability density gradient function.
665
666
666
667
Parameters
@@ -677,7 +678,7 @@ def compile_d2logp(
677
678
self ,
678
679
vars : Optional [Union [Variable , Sequence [Variable ]]] = None ,
679
680
jacobian : bool = True ,
680
- ):
681
+ ) -> PointFunc :
681
682
"""Compiled log probability density hessian function.
682
683
683
684
Parameters
@@ -1597,7 +1598,7 @@ def compile_fn(
1597
1598
mode = None ,
1598
1599
point_fn : bool = True ,
1599
1600
** kwargs ,
1600
- ) -> Union [" PointFunc" , Callable [[Sequence [np .ndarray ]], Sequence [np .ndarray ]]]:
1601
+ ) -> Union [PointFunc , Callable [[Sequence [np .ndarray ]], Sequence [np .ndarray ]]]:
1601
1602
"""Compiles an Aesara function
1602
1603
1603
1604
Parameters
@@ -1913,16 +1914,6 @@ def Point(*args, filter_model_vars=False, **kwargs) -> Dict[str, np.ndarray]:
1913
1914
}
1914
1915
1915
1916
1916
- class PointFunc :
1917
- """Wraps so a function so it takes a dict of arguments instead of arguments."""
1918
-
1919
- def __init__ (self , f ):
1920
- self .f = f
1921
-
1922
- def __call__ (self , state ):
1923
- return self .f (** state )
1924
-
1925
-
1926
1917
def Deterministic (name , var , model = None , dims = None , auto = False ):
1927
1918
"""Create a named deterministic variable
1928
1919
0 commit comments