@@ -1603,10 +1603,13 @@ def compile_fn(
1603
1603
1604
1604
Parameters
1605
1605
----------
1606
- outs: Aesara variable or iterable of Aesara variables
1607
- inputs: Aesara input variables, defaults to aesaraf.inputvars(outs).
1608
- mode: Aesara compilation mode, default=None
1609
- point_fn:
1606
+ outs
1607
+ Aesara variable or iterable of Aesara variables.
1608
+ inputs
1609
+ Aesara input variables, defaults to aesaraf.inputvars(outs).
1610
+ mode
1611
+ Aesara compilation mode, default=None.
1612
+ point_fn : bool
1610
1613
Whether to wrap the compiled function in a PointFunc, which takes a Point
1611
1614
dictionary with model variable names and values as input.
1612
1615
@@ -1872,16 +1875,24 @@ def set_data(new_data, model=None, *, coords=None):
1872
1875
model .set_data (variable_name , new_value , coords = coords )
1873
1876
1874
1877
1875
- def compile_fn (outs , mode = None , point_fn = True , model = None , ** kwargs ):
1878
+ def compile_fn (
1879
+ outs , mode = None , point_fn : bool = True , model : Optional [Model ] = None , ** kwargs
1880
+ ) -> Union [PointFunc , Callable [[Sequence [np .ndarray ]], Sequence [np .ndarray ]]]:
1876
1881
"""Compiles an Aesara function which returns ``outs`` and takes values of model
1877
1882
vars as a dict as an argument.
1883
+
1878
1884
Parameters
1879
1885
----------
1880
- outs: Aesara variable or iterable of Aesara variables
1881
- mode: Aesara compilation mode
1882
- point_fn:
1886
+ outs
1887
+ Aesara variable or iterable of Aesara variables.
1888
+ mode
1889
+ Aesara compilation mode, default=None.
1890
+ point_fn : bool
1883
1891
Whether to wrap the compiled function in a PointFunc, which takes a Point
1884
1892
dictionary with model variable names and values as input.
1893
+ model : Model, optional
1894
+ Current model on stack.
1895
+
1885
1896
Returns
1886
1897
-------
1887
1898
Compiled Aesara function as point function.
0 commit comments