+{% endif %}
\ No newline at end of file
diff --git a/doc/.templates/rendered_citation.html b/doc/.templates/rendered_citation.html
new file mode 100644
index 0000000000..ccb53efa6f
--- /dev/null
+++ b/doc/.templates/rendered_citation.html
@@ -0,0 +1,13 @@
+
+{% if pagename in ablog %}
+ {% set post = ablog[pagename] %}
+ {% for coll in post.author %}
+ {% if coll|length %}
+ {{ coll }}
+ {% if loop.index < post.author | length %},{% endif %}
+ {% else %}
+ {{ coll }}
+ {% if loop.index < post.author | length %},{% endif %}
+ {% endif %}
+ {% endfor %}. "{{ title.split(' — ')[0] }}". In: Pytensor Examples. Ed. by Pytensor Team.
+{% endif %}
\ No newline at end of file
diff --git a/doc/blog.md b/doc/blog.md
new file mode 100644
index 0000000000..88ebe9dc5b
--- /dev/null
+++ b/doc/blog.md
@@ -0,0 +1,7 @@
+---
+orphan: true
+---
+
+# Recent updates
+
+
diff --git a/doc/conf.py b/doc/conf.py
index 5b2d0c71a4..1729efc4b1 100644
--- a/doc/conf.py
+++ b/doc/conf.py
@@ -1,31 +1,13 @@
-# pytensor documentation build configuration file, created by
-# sphinx-quickstart on Tue Oct 7 16:34:06 2008.
-#
-# This file is execfile()d with the current directory set to its containing
-# directory.
-#
-# The contents of this file are pickled, so don't put values in the namespace
-# that aren't pickleable (module imports are okay, they're removed
-# automatically).
-#
-# All configuration values have a default value; values that are commented out
-# serve to show the default value.
-
-# If your extensions are in another directory, add it here. If the directory
-# is relative to the documentation root, use Path.absolute to make it
-# absolute, like shown here.
-# sys.path.append(str(Path("some/directory").absolute()))
-
import os
import inspect
import sys
import pytensor
+from pathlib import Path
+
+sys.path.insert(0, str(Path("..").resolve() / "scripts"))
# General configuration
# ---------------------
-
-# Add any Sphinx extension module names here, as strings. They can be
-# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom ones.
extensions = [
"sphinx.ext.autodoc",
"sphinx.ext.todo",
@@ -34,9 +16,22 @@
"sphinx.ext.linkcode",
"sphinx.ext.mathjax",
"sphinx_design",
- "sphinx.ext.intersphinx"
+ "sphinx.ext.intersphinx",
+ "sphinx.ext.autosummary",
+ "sphinx.ext.autosectionlabel",
+ "ablog",
+ "myst_nb",
+ "generate_gallery",
+ "sphinx_sitemap",
]
+# Don't auto-generate summary for class members.
+numpydoc_show_class_members = False
+autosummary_generate = True
+autodoc_typehints = "none"
+remove_from_toctrees = ["**/classmethods/*"]
+
+
intersphinx_mapping = {
"jax": ("https://jax.readthedocs.io/en/latest", None),
"numpy": ("https://numpy.org/doc/stable", None),
@@ -92,6 +87,7 @@
# List of directories, relative to source directories, that shouldn't be
# searched for source files.
exclude_dirs = ["images", "scripts", "sandbox"]
+exclude_patterns = ['page_footer.md', '**/*.myst.md']
# The reST default role (used for this markup: `text`) to use for all
# documents.
@@ -115,19 +111,15 @@
# Options for HTML output
# -----------------------
-# The style sheet to use for HTML and HTML Help pages. A file of that name
-# must exist either in Sphinx' static/ path, or in one of the custom paths
-# given in html_static_path.
-# html_style = 'default.css'
-# html_theme = 'sphinxdoc'
+# The theme to use for HTML and HTML Help pages. See the documentation for
+# a list of builtin themes.
+html_theme = "pymc_sphinx_theme"
+html_logo = "images/PyTensor_RGB.svg"
+
+html_baseurl = "https://pytensor.readthedocs.io"
+sitemap_url_scheme = f"{{lang}}{rtd_version}/{{link}}"
-# html4_writer added to Fix colon & whitespace misalignment
-# https://github.com/readthedocs/sphinx_rtd_theme/issues/766#issuecomment-513852197
-# https://github.com/readthedocs/sphinx_rtd_theme/issues/766#issuecomment-629666319
-# html4_writer = False
-html_logo = "images/PyTensor_RGB.svg"
-html_theme = "pymc_sphinx_theme"
html_theme_options = {
"use_search_override": False,
"icon_links": [
@@ -156,15 +148,27 @@
"type": "fontawesome",
},
],
+ "secondary_sidebar_items": ["page-toc", "edit-this-page", "sourcelink", "donate"],
+ "navbar_start": ["navbar-logo"],
+ "article_header_end": ["nb-badges"],
+ "article_footer_items": ["rendered_citation.html"],
}
html_context = {
+ "github_url": "https://github.com",
"github_user": "pymc-devs",
"github_repo": "pytensor",
- "github_version": "main",
+ "github_version": version if "." in rtd_version else "main",
+ "sandbox_repo": f"pymc-devs/pymc-sandbox/{version}",
"doc_path": "doc",
"default_mode": "light",
}
+# Add any paths that contain custom static files (such as style sheets) here,
+# relative to this directory. They are copied after the builtin static files,
+# so a file named "default.css" will overwrite the builtin "default.css".
+# html_static_path = ["../_static"]
+html_extra_path = ["_thumbnails", 'images', "robots.txt"]
+templates_path = [".templates"]
# The name for this set of Sphinx documents. If None, it defaults to
# " v documentation".
@@ -295,3 +299,62 @@ def find_source():
# If false, no module index is generated.
# latex_use_modindex = True
+
+
+# -- MyST config -------------------------------------------------
+myst_enable_extensions = [
+ "colon_fence",
+ "deflist",
+ "dollarmath",
+ "amsmath",
+ "substitution",
+]
+myst_dmath_double_inline = True
+
+citation_code = f"""
+```bibtex
+@incollection{{citekey,
+ author = "",
+ title = "",
+ editor = "Pytensor Team",
+ booktitle = "Pytensor Examples",
+}}
+```
+"""
+
+myst_substitutions = {
+ "pip_dependencies": "{{ extra_dependencies }}",
+ "conda_dependencies": "{{ extra_dependencies }}",
+ "extra_install_notes": "",
+ "citation_code": citation_code,
+}
+
+nb_execution_mode = "off"
+nbsphinx_execute = "never"
+nbsphinx_allow_errors = True
+
+rediraffe_redirects = {
+ "index.md": "gallery.md",
+}
+
+# -- Bibtex config -------------------------------------------------
+bibtex_bibfiles = ["references.bib"]
+bibtex_default_style = "unsrt"
+bibtex_reference_style = "author_year"
+
+
+# -- ablog config -------------------------------------------------
+blog_baseurl = "https://pytensor.readthedocs.io/en/latest/index.html"
+blog_title = "Pytensor Examples"
+blog_path = "blog"
+blog_authors = {
+ "contributors": ("Pytensor Contributors", "https://pytensor.readthedocs.io"),
+}
+blog_default_author = "contributors"
+post_show_prev_next = False
+fontawesome_included = True
+# post_redirect_refresh = 1
+# post_auto_image = 1
+# post_auto_excerpt = 2
+
+# notfound_urls_prefix = ""
diff --git a/doc/core_development_guide.rst b/doc/core_development_guide.rst
index 082fbaa514..82c15ddc8f 100644
--- a/doc/core_development_guide.rst
+++ b/doc/core_development_guide.rst
@@ -26,12 +26,4 @@ some of them might be outdated though:
* :ref:`unittest` -- Tutorial on how to use unittest in testing PyTensor.
-* :ref:`sandbox_debugging_step_mode` -- How to step through the execution of
- an PyTensor function and print the inputs and outputs of each op.
-
-* :ref:`sandbox_elemwise` -- Description of element wise operations.
-
-* :ref:`sandbox_randnb` -- Description of how PyTensor deals with random
- numbers.
-
* :ref:`sparse` -- Description of the ``sparse`` type in PyTensor.
diff --git a/doc/environment.yml b/doc/environment.yml
index ae17b6379d..d58af79cc6 100644
--- a/doc/environment.yml
+++ b/doc/environment.yml
@@ -14,6 +14,14 @@ dependencies:
- pillow
- pymc-sphinx-theme
- sphinx-design
+ - pygments
+ - pydot
+ - ipython
+ - myst-nb
+ - matplotlib
+ - watermark
+ - ablog
- pip
- pip:
+ - sphinx_sitemap
- -e ..
diff --git a/doc/extending/creating_a_numba_jax_op.rst b/doc/extending/creating_a_numba_jax_op.rst
index 8be08b4953..1fb25f83b6 100644
--- a/doc/extending/creating_a_numba_jax_op.rst
+++ b/doc/extending/creating_a_numba_jax_op.rst
@@ -1,5 +1,5 @@
Adding JAX, Numba and Pytorch support for `Op`\s
-=======================================
+================================================
PyTensor is able to convert its graphs into JAX, Numba and Pytorch compiled functions. In order to do
this, each :class:`Op` in an PyTensor graph must have an equivalent JAX/Numba/Pytorch implementation function.
@@ -7,7 +7,7 @@ this, each :class:`Op` in an PyTensor graph must have an equivalent JAX/Numba/Py
This tutorial will explain how JAX, Numba and Pytorch implementations are created for an :class:`Op`.
Step 1: Identify the PyTensor :class:`Op` you'd like to implement
-------------------------------------------------------------------------
+-----------------------------------------------------------------
Find the source for the PyTensor :class:`Op` you'd like to be supported and
identify the function signature and return values. These can be determined by
@@ -98,7 +98,7 @@ how the inputs and outputs are used to compute the outputs for an :class:`Op`
in Python. This method is effectively what needs to be implemented.
Step 2: Find the relevant method in JAX/Numba/Pytorch (or something close)
----------------------------------------------------------
+--------------------------------------------------------------------------
With a precise idea of what the PyTensor :class:`Op` does we need to figure out how
to implement it in JAX, Numba or Pytorch. In the best case scenario, there is a similarly named
@@ -269,7 +269,7 @@ and :func:`torch.cumprod`
z[0] = np.cumprod(x, axis=self.axis)
Step 3: Register the function with the respective dispatcher
----------------------------------------------------------------
+------------------------------------------------------------
With the PyTensor `Op` replicated, we'll need to register the
function with the backends `Linker`. This is done through the use of
diff --git a/doc/gallery/page_footer.md b/doc/gallery/page_footer.md
new file mode 100644
index 0000000000..6f9c88f801
--- /dev/null
+++ b/doc/gallery/page_footer.md
@@ -0,0 +1,27 @@
+## License notice
+All the notebooks in this example gallery are provided under a
+[3-Clause BSD License](https://github.com/pymc-devs/pytensor/blob/main/doc/LICENSE.txt)
+which allows modification and redistribution for any
+use provided the copyright and license notices are preserved.
+
+## Citing Pytensor Examples
+
+To cite this notebook, please use the suggested citation below.
+
+:::{important}
+Many notebooks are adapted from other sources: blogs, books... In such cases you should
+cite the original source as well.
+
+Also remember to cite the relevant libraries used by your code.
+:::
+
+Here is an example citation template in bibtex:
+
+{{ citation_code }}
+
+which once rendered could look like:
+
+
+
\ No newline at end of file
diff --git a/doc/gallery/rewrites/graph_rewrites.ipynb b/doc/gallery/rewrites/graph_rewrites.ipynb
new file mode 100644
index 0000000000..298e13b95e
--- /dev/null
+++ b/doc/gallery/rewrites/graph_rewrites.ipynb
@@ -0,0 +1,1104 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "(Graph_rewrites)=\n",
+ "\n",
+ "# PyTensor graph rewrites from scratch\n",
+ "\n",
+ ":::{post} Jan 11, 2025 \n",
+ ":tags: Graph rewrites \n",
+ ":category: avanced, explanation \n",
+ ":author: Ricardo Vieira \n",
+ ":::"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Manipulating nodes directly"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "This section walks through the low level details of PyTensor graph manipulation. \n",
+ "Users are not supposed to work or even be aware of these details, but it may be helpful for developers.\n",
+ "We start with very **bad practices** and move on towards the **right** way of doing rewrites.\n",
+ "\n",
+ "* {doc}`Graph structures `\n",
+ "is a required precursor to this guide\n",
+ "* {doc}`Graph rewriting ` provides the user-level summary of what is covered in here. Feel free to revisit once you're done here.\n",
+ "\n",
+ "As described in {doc}`Graph structures`, PyTensor graphs are composed of sequences {class}`Apply` nodes, which link {class}`Variable`s\n",
+ "that form the inputs and outputs of a computational {class}`Op`eration.\n",
+ "\n",
+ "The list of inputs of an {class}`Apply` node can be changed inplace to modify the computational path that leads to it.\n",
+ "Consider the following simple example:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2025-01-11T07:37:46.104335Z",
+ "start_time": "2025-01-11T07:37:46.100021Z"
+ }
+ },
+ "source": [
+ "%env PYTENSOR_FLAGS=cxx=\"\""
+ ],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "env: PYTENSOR_FLAGS=cxx=\"\"\n"
+ ]
+ }
+ ],
+ "execution_count": 1
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2025-01-11T07:37:49.384149Z",
+ "start_time": "2025-01-11T07:37:46.201672Z"
+ }
+ },
+ "source": [
+ "import pytensor\n",
+ "import pytensor.tensor as pt\n",
+ "\n",
+ "x = pt.scalar(\"x\")\n",
+ "y = pt.log(1 + x)\n",
+ "out = y * 2\n",
+ "pytensor.dprint(out, id_type=\"\");"
+ ],
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING (pytensor.tensor.blas): Using NumPy C-API based implementation for BLAS functions.\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Mul\n",
+ " ├─ Log\n",
+ " │ └─ Add\n",
+ " │ ├─ 1\n",
+ " │ └─ x\n",
+ " └─ 2\n"
+ ]
+ }
+ ],
+ "execution_count": 2
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "A standard rewrite replaces `pt.log(1 + x)` by the more stable form `pt.log1p(x)`.\n",
+ "We can do this by changing the inputs of the `out` node inplace."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2025-01-11T07:37:49.924153Z",
+ "start_time": "2025-01-11T07:37:49.920272Z"
+ }
+ },
+ "source": [
+ "out.owner.inputs[0] = pt.log1p(x)\n",
+ "pytensor.dprint(out, id_type=\"\");"
+ ],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Mul\n",
+ " ├─ Log1p\n",
+ " │ └─ x\n",
+ " └─ 2\n"
+ ]
+ }
+ ],
+ "execution_count": 3
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "There are two problems with this direct approach:\n",
+ "1. We are modifying variables in place\n",
+ "2. We have to know which nodes have as input the variable we want to replace\n",
+ "\n",
+ "Point 1. is important because some rewrites are \"destructive\" and the user may want to reuse the same graph in multiple functions.\n",
+ "\n",
+ "Point 2. is important because it forces us to shift the focus of attention from the operation we want to rewrite to the variables where the operation is used. It also risks unneccessary duplication of variables, if we perform the same replacement independently for each use. This could make graph rewriting consideraby slower!\n",
+ "\n",
+ "PyTensor makes use of {class}`FunctionGraph`s to solve these two issues.\n",
+ "By default, a FunctionGraph will clone all the variables between the inputs and outputs,\n",
+ "so that the corresponding graph can be rewritten.\n",
+ "In addition, it will create a {term}`client`s dictionary that maps all the variables to the nodes where they are used.\n",
+ "\n",
+ "\n",
+ "Let's see how we can use a FunctionGraph to achieve the same rewrite:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2025-01-11T07:37:50.005393Z",
+ "start_time": "2025-01-11T07:37:49.997328Z"
+ }
+ },
+ "source": [
+ "from pytensor.graph import FunctionGraph\n",
+ "\n",
+ "x = pt.scalar(\"x\")\n",
+ "y = pt.log(1 + x)\n",
+ "out1 = y * 2\n",
+ "out2 = 2 / y\n",
+ "\n",
+ "# Create an empty dictionary which FunctionGraph will populate\n",
+ "# with the mappings from old variables to cloned ones\n",
+ "memo = {}\n",
+ "fg = FunctionGraph([x], [out1, out2], clone=True, memo=memo)\n",
+ "fg_x = memo[x]\n",
+ "fg_y = memo[y]\n",
+ "print(\"Before:\\n\")\n",
+ "pytensor.dprint(fg.outputs)\n",
+ "\n",
+ "# Create expression of interest with cloned variables\n",
+ "fg_y_repl = pt.log1p(fg_x)\n",
+ "\n",
+ "# Update all uses of old variable to new one\n",
+ "# Each entry in the clients dictionary, \n",
+ "# contains a node and the input index where the variable is used\n",
+ "# Note: Some variables could be used multiple times in a single node\n",
+ "for client, idx in fg.clients[fg_y]:\n",
+ " client.inputs[idx] = fg_y_repl\n",
+ " \n",
+ "print(\"\\nAfter:\\n\")\n",
+ "pytensor.dprint(fg.outputs);"
+ ],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Before:\n",
+ "\n",
+ "Mul [id A]\n",
+ " ├─ Log [id B]\n",
+ " │ └─ Add [id C]\n",
+ " │ ├─ 1 [id D]\n",
+ " │ └─ x [id E]\n",
+ " └─ 2 [id F]\n",
+ "True_div [id G]\n",
+ " ├─ 2 [id H]\n",
+ " └─ Log [id B]\n",
+ " └─ ···\n",
+ "\n",
+ "After:\n",
+ "\n",
+ "Mul [id A]\n",
+ " ├─ Log1p [id B]\n",
+ " │ └─ x [id C]\n",
+ " └─ 2 [id D]\n",
+ "True_div [id E]\n",
+ " ├─ 2 [id F]\n",
+ " └─ Log1p [id B]\n",
+ " └─ ···\n"
+ ]
+ }
+ ],
+ "execution_count": 4
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We can see that both uses of `log(1 + x)` were replaced by the new `log1p(x)`.\n",
+ "\n",
+ "It would probably be a good idea to update the clients dictionary\n",
+ "if we wanted to perform another rewrite.\n",
+ "\n",
+ "There are a couple of other variables in the FunctionGraph that we would also want to update,\n",
+ "but there is no point to doing all this bookeeping manually. \n",
+ "FunctionGraph offers a {meth}`replace ` method that takes care of all this for the user."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2025-01-11T07:37:50.078947Z",
+ "start_time": "2025-01-11T07:37:50.072465Z"
+ }
+ },
+ "source": [
+ "# We didn't modify the variables in place so we can just reuse them!\n",
+ "memo = {}\n",
+ "fg = FunctionGraph([x], [out1, out2], clone=True, memo=memo)\n",
+ "fg_x = memo[x]\n",
+ "fg_y = memo[y]\n",
+ "print(\"Before:\\n\")\n",
+ "pytensor.dprint(fg.outputs)\n",
+ "\n",
+ "# Create expression of interest with cloned variables\n",
+ "fg_y_repl = pt.log1p(fg_x)\n",
+ "fg.replace(fg_y, fg_y_repl)\n",
+ " \n",
+ "print(\"\\nAfter:\\n\")\n",
+ "pytensor.dprint(fg.outputs);"
+ ],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Before:\n",
+ "\n",
+ "Mul [id A]\n",
+ " ├─ Log [id B]\n",
+ " │ └─ Add [id C]\n",
+ " │ ├─ 1 [id D]\n",
+ " │ └─ x [id E]\n",
+ " └─ 2 [id F]\n",
+ "True_div [id G]\n",
+ " ├─ 2 [id H]\n",
+ " └─ Log [id B]\n",
+ " └─ ···\n",
+ "\n",
+ "After:\n",
+ "\n",
+ "Mul [id A]\n",
+ " ├─ Log1p [id B]\n",
+ " │ └─ x [id C]\n",
+ " └─ 2 [id D]\n",
+ "True_div [id E]\n",
+ " ├─ 2 [id F]\n",
+ " └─ Log1p [id B]\n",
+ " └─ ···\n"
+ ]
+ }
+ ],
+ "execution_count": 5
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "There is still one big limitation with this approach.\n",
+ "We have to know in advance \"where\" the variable we want to replace is present.\n",
+ "It also doesn't scale to multiple instances of the same pattern.\n",
+ "\n",
+ "A more sensible approach would be to iterate over the nodes in the FunctionGraph\n",
+ "and apply the rewrite wherever `log(1 + x)` may be present.\n",
+ "\n",
+ "To keep code organized we will create a function \n",
+ "that takes as input a node and returns a valid replacement."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2025-01-11T07:37:50.161507Z",
+ "start_time": "2025-01-11T07:37:50.156975Z"
+ }
+ },
+ "source": [
+ "from pytensor.graph import Constant\n",
+ "\n",
+ "def local_log1p(node):\n",
+ " # Check that this node is a Log op\n",
+ " if node.op != pt.log:\n",
+ " return None\n",
+ " \n",
+ " # Check that the input is another node (it could be an input variable)\n",
+ " add_node = node.inputs[0].owner\n",
+ " if add_node is None:\n",
+ " return None\n",
+ " \n",
+ " # Check that the input to this node is an Add op\n",
+ " # with 2 inputs (Add can have more inputs)\n",
+ " if add_node.op != pt.add or len(add_node.inputs) != 2:\n",
+ " return None\n",
+ " \n",
+ " # Check wether we have add(1, y) or add(x, 1)\n",
+ " [x, y] = add_node.inputs\n",
+ " if isinstance(x, Constant) and x.data == 1:\n",
+ " return [pt.log1p(y)]\n",
+ " if isinstance(y, Constant) and y.data == 1:\n",
+ " return [pt.log1p(x)]\n",
+ "\n",
+ " return None"
+ ],
+ "outputs": [],
+ "execution_count": 6
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2025-01-11T07:37:50.248106Z",
+ "start_time": "2025-01-11T07:37:50.242014Z"
+ }
+ },
+ "source": [
+ "# We no longer need the memo, because our rewrite works with the node information\n",
+ "fg = FunctionGraph([x], [out1, out2], clone=True)\n",
+ "\n",
+ "# Toposort gives a list of all nodes in a graph in topological order\n",
+ "# The strategy of iteration can be important when we are dealing with multiple rewrites\n",
+ "for node in fg.toposort():\n",
+ " repl = local_log1p(node)\n",
+ " if repl is None:\n",
+ " continue\n",
+ " # We should get one replacement of each output of the node\n",
+ " assert len(repl) == len(node.outputs)\n",
+ " # We could use `fg.replace_all` to avoid this loop\n",
+ " for old, new in zip(node.outputs, repl):\n",
+ " fg.replace(old, new)\n",
+ "\n",
+ "pytensor.dprint(fg);"
+ ],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Mul [id A] 1\n",
+ " ├─ Log1p [id B] 0\n",
+ " │ └─ x [id C]\n",
+ " └─ 2 [id D]\n",
+ "True_div [id E] 2\n",
+ " ├─ 2 [id F]\n",
+ " └─ Log1p [id B] 0\n",
+ " └─ ···\n"
+ ]
+ }
+ ],
+ "execution_count": 7
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "This is starting to look much more scalable!\n",
+ "\n",
+ "We are still reinventing may wheels that already exist in PyTensor, but we're getting there.\n",
+ "Before we move up the ladder of abstraction, let's discuss two gotchas:\n",
+ "\n",
+ "1. The replacement variables should have types that are compatible with the original ones.\n",
+ "2. We have to be careful about introducing circular dependencies\n",
+ "\n",
+ "For 1. let's look at a simple graph simplification, where we replace a costly operation that is ultimately multiplied by zero."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2025-01-11T07:37:50.344446Z",
+ "start_time": "2025-01-11T07:37:50.328071Z"
+ }
+ },
+ "source": [
+ "x = pt.vector(\"x\", dtype=\"float32\")\n",
+ "zero = pt.zeros(())\n",
+ "zero.name = \"zero\"\n",
+ "y = pt.exp(x) * zero\n",
+ "\n",
+ "fg = FunctionGraph([x], [y], clone=False)\n",
+ "try:\n",
+ " fg.replace(y, pt.zeros(()))\n",
+ "except TypeError as exc:\n",
+ " print(f\"TypeError: {exc}\")"
+ ],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "TypeError: Cannot convert Type Scalar(float64, shape=()) (of Variable Alloc.0) into Type Vector(float64, shape=(?,)). You can try to manually convert Alloc.0 into a Vector(float64, shape=(?,)).\n"
+ ]
+ }
+ ],
+ "execution_count": 8
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The first achievement of a new PyTensor developer is unlocked by stumbling upon an error like that!\n",
+ "\n",
+ "It's important to keep in mind the Tensor part of PyTensor.\n",
+ "\n",
+ "The problem here is that we are trying to replace the `y` variable which is a float32 vector by the `zero` variable which is a float64 scalar!"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2025-01-11T07:37:50.408682Z",
+ "start_time": "2025-01-11T07:37:50.404355Z"
+ }
+ },
+ "source": [
+ "pytensor.dprint(fg.outputs, id_type=\"\", print_type=True);"
+ ],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Mul \n",
+ " ├─ Exp \n",
+ " │ └─ x \n",
+ " └─ ExpandDims{axis=0} \n",
+ " └─ Alloc 'zero'\n",
+ " └─ 0.0 \n"
+ ]
+ }
+ ],
+ "execution_count": 9
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2025-01-11T07:37:50.512585Z",
+ "start_time": "2025-01-11T07:37:50.488176Z"
+ }
+ },
+ "source": [
+ "vector_zero = pt.zeros(x.shape)\n",
+ "vector_zero.name = \"vector_zero\"\n",
+ "fg.replace(y, vector_zero)\n",
+ "pytensor.dprint(fg.outputs, id_type=\"\", print_type=True);"
+ ],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Alloc 'vector_zero'\n",
+ " ├─ 0.0 \n",
+ " └─ Subtensor{i} \n",
+ " ├─ Shape \n",
+ " │ └─ x \n",
+ " └─ 0 \n"
+ ]
+ }
+ ],
+ "execution_count": 10
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Now to the second (less common) gotcha. Introducing circular dependencies:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2025-01-11T07:37:50.572844Z",
+ "start_time": "2025-01-11T07:37:50.567175Z"
+ }
+ },
+ "source": [
+ "x = pt.scalar(\"x\")\n",
+ "y = x + 1\n",
+ "y.name = \"y\"\n",
+ "z = y + 1\n",
+ "z.name = \"z\"\n",
+ "\n",
+ "fg = FunctionGraph([x], [z], clone=False)\n",
+ "fg.replace(x, z)\n",
+ "pytensor.dprint(fg.outputs);"
+ ],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Add [id A] 'z'\n",
+ " ├─ Add [id B] 'y'\n",
+ " │ ├─ Add [id A] 'z'\n",
+ " │ │ └─ ···\n",
+ " │ └─ 1 [id C]\n",
+ " └─ 1 [id D]\n"
+ ]
+ }
+ ],
+ "execution_count": 11
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Oops! There is not much to say about this one, other than don't do it!"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Using graph rewriters"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2025-01-11T07:37:50.634996Z",
+ "start_time": "2025-01-11T07:37:50.631699Z"
+ }
+ },
+ "source": [
+ "from pytensor.graph.rewriting.basic import NodeRewriter\n",
+ "\n",
+ "class LocalLog1pNodeRewriter(NodeRewriter):\n",
+ " \n",
+ " def tracks(self):\n",
+ " return [pt.log]\n",
+ " \n",
+ " def transform(self, fgraph, node):\n",
+ " return local_log1p(node) \n",
+ " \n",
+ " def __str__(self):\n",
+ " return \"local_log1p\"\n",
+ " \n",
+ " \n",
+ "local_log1p_node_rewriter = LocalLog1pNodeRewriter()"
+ ],
+ "outputs": [],
+ "execution_count": 12
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "A {class}`NodeRewriter` is required to implement only the {meth}`transform ` method.\n",
+ "As before, this method expects a node and should return a valid replacement for each output or `None`.\n",
+ "\n",
+ "We also receive the {class}`FunctionGraph` object, as some node rewriters may want to use global information to decide whether to return a replacement or not.\n",
+ "\n",
+ "For example some rewrites that skip intermediate computations may not be useful if those intermediate computations are used by other variables.\n",
+ "\n",
+ "The {meth}`tracks ` optional method is very useful for filtering out \"useless\" rewrites. When {class}`NodeRewriter`s only applies to a specific rare {class}`Op` it can be ignored completely when that {class}`Op` is not present in the graph.\n",
+ "\n",
+ "On its own, a {class}`NodeRewriter` isn't any better than what we had before. Where it becomes useful is when included inside a {class}`GraphRewriter`, which will apply it to a whole {class}`FunctionGraph `."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2025-01-11T07:37:50.702188Z",
+ "start_time": "2025-01-11T07:37:50.696179Z"
+ }
+ },
+ "source": [
+ "from pytensor.graph.rewriting.basic import in2out\n",
+ "\n",
+ "x = pt.scalar(\"x\")\n",
+ "y = pt.log(1 + x)\n",
+ "out = pt.exp(y)\n",
+ "\n",
+ "fg = FunctionGraph([x], [out])\n",
+ "in2out(local_log1p_node_rewriter, name=\"local_log1p\").rewrite(fg)\n",
+ "\n",
+ "pytensor.dprint(fg.outputs);"
+ ],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Exp [id A]\n",
+ " └─ Log1p [id B]\n",
+ " └─ x [id C]\n"
+ ]
+ }
+ ],
+ "execution_count": 13
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Here we used {func}`in2out` which creates a {class}`GraphRewriter` (specifically a {class}`WalkingGraphRewriter`) which walks from the inputs to the outputs of a FunctionGraph trying to apply whatever nodes are \"registered\" in it.\n",
+ "\n",
+ "Wrapping simple functions in {class}`NodeRewriter`s is so common that PyTensor offers a decorator for it.\n",
+ "\n",
+ "Let's create a new rewrite that removes useless `abs(exp(x)) -> exp(x)`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2025-01-11T07:37:50.761196Z",
+ "start_time": "2025-01-11T07:37:50.757401Z"
+ }
+ },
+ "source": [
+ "from pytensor.graph.rewriting.basic import node_rewriter\n",
+ "\n",
+ "@node_rewriter(tracks=[pt.abs])\n",
+ "def local_useless_abs_exp(fgraph, node):\n",
+ " # Because of the tracks we don't need to check \n",
+ " # that `node` has a `Sign` Op.\n",
+ " # We still need to check whether it's input is an `Abs` Op\n",
+ " exp_node = node.inputs[0].owner\n",
+ " if exp_node is None or exp_node.op != pt.exp:\n",
+ " return None\n",
+ " return exp_node.outputs"
+ ],
+ "outputs": [],
+ "execution_count": 14
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": "Another very useful helper is the {class}`PatternNodeRewriter`, which allows you to specify a rewrite via \"template matching\"."
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2025-01-11T07:37:50.848713Z",
+ "start_time": "2025-01-11T07:37:50.845435Z"
+ }
+ },
+ "source": [
+ "from pytensor.graph.rewriting.basic import PatternNodeRewriter\n",
+ "\n",
+ "local_useless_abs_square = PatternNodeRewriter(\n",
+ " (pt.abs, (pt.pow, \"x\", 2)),\n",
+ " (pt.pow, \"x\", 2),\n",
+ " name=\"local_useless_abs_square\",\n",
+ ")"
+ ],
+ "outputs": [],
+ "execution_count": 15
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "This is very useful for simple Elemwise rewrites, but becomes a bit cumbersome with Ops that must be parametrized\n",
+ "everytime they are used."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2025-01-11T07:37:50.925407Z",
+ "start_time": "2025-01-11T07:37:50.897320Z"
+ }
+ },
+ "source": [
+ "x = pt.scalar(\"x\")\n",
+ "y = pt.exp(x)\n",
+ "z = pt.abs(y)\n",
+ "w = pt.log(1.0 + z)\n",
+ "out = pt.abs(w ** 2)\n",
+ "\n",
+ "fg = FunctionGraph([x], [out])\n",
+ "in2out_rewrite = in2out(\n",
+ " local_log1p_node_rewriter, \n",
+ " local_useless_abs_exp, \n",
+ " local_useless_abs_square,\n",
+ " name=\"custom_rewrites\"\n",
+ ")\n",
+ "in2out_rewrite.rewrite(fg)\n",
+ "\n",
+ "pytensor.dprint(fg.outputs);"
+ ],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Pow [id A]\n",
+ " ├─ Log1p [id B]\n",
+ " │ └─ Exp [id C]\n",
+ " │ └─ x [id D]\n",
+ " └─ 2 [id E]\n"
+ ]
+ }
+ ],
+ "execution_count": 16
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Besides {class}`WalkingGraphRewriter`s, there are:\n",
+ " - {class}`SequentialGraphRewriter`s, which apply a set of {class}`GraphRewriters` sequentially \n",
+ " - {class}`EquilibriumGraphRewriter`s which apply a set of {class}`GraphRewriters` (and {class}`NodeRewriters`) repeatedly until the graph stops changing.\n"
+ ]
+ },
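+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "As a quick illustration, here is a minimal sketch of an {class}`EquilibriumGraphRewriter` built from the node rewriters defined above. The constructor arguments are an assumption here; in particular, `max_use_ratio` (which caps how often each rewriter may fire relative to the graph size) is assumed to be the only required extra argument."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {},
+ "source": [
+ "from pytensor.graph.rewriting.basic import EquilibriumGraphRewriter\n",
+ "\n",
+ "# Apply both node rewriters over and over until the graph stops changing.\n",
+ "# max_use_ratio is assumed here; it bounds rewriter applications per node.\n",
+ "equilibrium_rewriter = EquilibriumGraphRewriter(\n",
+ " [local_log1p_node_rewriter, local_useless_abs_exp],\n",
+ " max_use_ratio=10,\n",
+ ")\n",
+ "\n",
+ "x = pt.scalar(\"x\")\n",
+ "out = pt.log(1 + pt.abs(pt.exp(x)))\n",
+ "\n",
+ "# Both rewrites should fire: abs(exp(x)) -> exp(x), then log(1 + ...) -> log1p(...)\n",
+ "fg = FunctionGraph([x], [out])\n",
+ "equilibrium_rewriter.rewrite(fg)\n",
+ "pytensor.dprint(fg.outputs);"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },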
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Registering graph rewriters in a database"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Finally, at the top of the rewrite mountain, there are {class}`RewriteDatabase`s! These allow \"querying\" for subsets of rewrites registered in a database.\n",
+ "\n",
+ "Most users trigger this when they change the `mode` of a PyTensor function `mode=\"FAST_COMPILE\"` or `mode=\"FAST_RUN\"`, or `mode=\"JAX\"` will lead to a different rewrite database query to be applied to the function before compilation.\n",
+ "\n",
+ "The most relevant {class}`RewriteDatabase` is called `optdb` and contains all the standard rewrites in PyTensor. You can manually register your {class}`GraphRewriter` in it. \n",
+ "\n",
+ "More often than not, you will want to register your rewrite in a pre-existing sub-database, like {term}`canonicalize`, {term}`stabilize`, or {term}`specialize`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2025-01-11T07:37:50.979283Z",
+ "start_time": "2025-01-11T07:37:50.976168Z"
+ }
+ },
+ "source": [
+ "from pytensor.compile.mode import optdb"
+ ],
+ "outputs": [],
+ "execution_count": 17
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2025-01-11T07:37:51.032996Z",
+ "start_time": "2025-01-11T07:37:51.029510Z"
+ }
+ },
+ "source": [
+ "optdb[\"canonicalize\"].register(\n",
+ " \"local_log1p_node_rewriter\",\n",
+ " local_log1p_node_rewriter,\n",
+ " \"fast_compile\",\n",
+ " \"fast_run\",\n",
+ " \"custom\",\n",
+ ")"
+ ],
+ "outputs": [],
+ "execution_count": 18
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2025-01-11T07:37:51.156080Z",
+ "start_time": "2025-01-11T07:37:51.095154Z"
+ }
+ },
+ "source": [
+ "with pytensor.config.change_flags(optimizer_verbose = True):\n",
+ " fn = pytensor.function([x], out, mode=\"FAST_COMPILE\")\n",
+ " \n",
+ "print(\"\")\n",
+ "pytensor.dprint(fn);"
+ ],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "rewriting: rewrite local_log1p replaces Log.0 of Log(Add.0) with Log1p.0 of Log1p(Abs.0)\n",
+ "\n",
+ "Abs [id A] 4\n",
+ " └─ Pow [id B] 3\n",
+ " ├─ Log1p [id C] 2\n",
+ " │ └─ Abs [id D] 1\n",
+ " │ └─ Exp [id E] 0\n",
+ " │ └─ x [id F]\n",
+ " └─ 2 [id G]\n"
+ ]
+ }
+ ],
+ "execution_count": 19
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": "There's also a decorator, {func}`register_canonicalize`, that automatically registers a {class}`NodeRewriter` in one of these standard databases. (It's placed in a weird location)"
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2025-01-11T07:37:51.220260Z",
+ "start_time": "2025-01-11T07:37:51.216259Z"
+ }
+ },
+ "source": [
+ "from pytensor.tensor.rewriting.basic import register_canonicalize\n",
+ "\n",
+ "@register_canonicalize(\"custom\")\n",
+ "@node_rewriter(tracks=[pt.abs])\n",
+ "def local_useless_abs_exp(fgraph, node):\n",
+ " # Because of the tracks we don't need to check \n",
+ " # that `node` has a `Sign` Op.\n",
+ " # We still need to check whether it's input is an `Abs` Op\n",
+ " exp_node = node.inputs[0].owner\n",
+ " if exp_node is None or exp_node.op != pt.exp:\n",
+ " return None\n",
+ " return exp_node.outputs"
+ ],
+ "outputs": [],
+ "execution_count": 20
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "And you can also use the decorator directly"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2025-01-11T07:37:51.292003Z",
+ "start_time": "2025-01-11T07:37:51.286043Z"
+ }
+ },
+ "source": [
+ "register_canonicalize(local_useless_abs_square, \"custom\")"
+ ],
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "local_useless_abs_square"
+ ]
+ },
+ "execution_count": 21,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "execution_count": 21
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2025-01-11T07:37:51.380138Z",
+ "start_time": "2025-01-11T07:37:51.362056Z"
+ }
+ },
+ "source": [
+ "with pytensor.config.change_flags(optimizer_verbose = True):\n",
+ " fn = pytensor.function([x], out, mode=\"FAST_COMPILE\")\n",
+ " \n",
+ "print(\"\")\n",
+ "pytensor.dprint(fn);"
+ ],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "rewriting: rewrite local_useless_abs_square replaces Abs.0 of Abs(Pow.0) with Pow.0 of Pow(Log.0, 2)\n",
+ "rewriting: rewrite local_log1p replaces Log.0 of Log(Add.0) with Log1p.0 of Log1p(Abs.0)\n",
+ "rewriting: rewrite local_useless_abs_exp replaces Abs.0 of Abs(Exp.0) with Exp.0 of Exp(x)\n",
+ "\n",
+ "Pow [id A] 2\n",
+ " ├─ Log1p [id B] 1\n",
+ " │ └─ Exp [id C] 0\n",
+ " │ └─ x [id D]\n",
+ " └─ 2 [id E]\n"
+ ]
+ }
+ ],
+ "execution_count": 22
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "And if you wanted to exclude your custom rewrites you can do it like this:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2025-01-11T07:37:51.487102Z",
+ "start_time": "2025-01-11T07:37:51.459955Z"
+ }
+ },
+ "source": [
+ "from pytensor.compile.mode import get_mode\n",
+ "\n",
+ "with pytensor.config.change_flags(optimizer_verbose = True):\n",
+ " fn = pytensor.function([x], out, mode=get_mode(\"FAST_COMPILE\").excluding(\"custom\"))\n",
+ " \n",
+ "print(\"\")\n",
+ "pytensor.dprint(fn);"
+ ],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "rewriting: rewrite local_upcast_elemwise_constant_inputs replaces Add.0 of Add(1.0, Abs.0) with Add.0 of Add(Cast{float64}.0, Abs.0)\n",
+ "rewriting: rewrite constant_folding replaces Cast{float64}.0 of Cast{float64}(1.0) with 1.0 of None\n",
+ "\n",
+ "Abs [id A] 5\n",
+ " └─ Pow [id B] 4\n",
+ " ├─ Log [id C] 3\n",
+ " │ └─ Add [id D] 2\n",
+ " │ ├─ 1.0 [id E]\n",
+ " │ └─ Abs [id F] 1\n",
+ " │ └─ Exp [id G] 0\n",
+ " │ └─ x [id H]\n",
+ " └─ 2 [id I]\n"
+ ]
+ }
+ ],
+ "execution_count": 23
+ },
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": [
+ "## Authors\n",
+ "\n",
+ "- Authored by Ricardo Vieira in May 2023"
+ ]
+ },
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": [
+ "## References\n",
+ "\n",
+ ":::{bibliography} :filter: docname in docnames"
+ ]
+ },
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": "## Watermark "
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2025-01-11T07:37:51.621272Z",
+ "start_time": "2025-01-11T07:37:51.580753Z"
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "%load_ext watermark\n",
+ "%watermark -n -u -v -iv -w -p pytensor"
+ ],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Last updated: Sat Jan 11 2025\n",
+ "\n",
+ "Python implementation: CPython\n",
+ "Python version : 3.12.0\n",
+ "IPython version : 8.31.0\n",
+ "\n",
+ "pytensor: 2.26.4+16.g8be5c5323.dirty\n",
+ "\n",
+ "sys : 3.12.0 | packaged by conda-forge | (main, Oct 3 2023, 08:43:22) [GCC 12.3.0]\n",
+ "pytensor: 2.26.4+16.g8be5c5323.dirty\n",
+ "\n",
+ "Watermark: 2.5.0\n",
+ "\n"
+ ]
+ }
+ ],
+ "execution_count": 24
+ },
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": [
+ ":::{include} ../page_footer.md \n",
+ ":::"
+ ]
+ }
+ ],
+ "metadata": {
+ "hide_input": false,
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.8"
+ },
+ "toc": {
+ "base_numbering": 1,
+ "nav_menu": {},
+ "number_sections": true,
+ "sideBar": true,
+ "skip_h1_title": false,
+ "title_cell": "Table of Contents",
+ "title_sidebar": "Contents",
+ "toc_cell": false,
+ "toc_position": {},
+ "toc_section_display": true,
+ "toc_window_display": true
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/doc/gallery/scan/scan_tutorial.ipynb b/doc/gallery/scan/scan_tutorial.ipynb
new file mode 100644
index 0000000000..3428698450
--- /dev/null
+++ b/doc/gallery/scan/scan_tutorial.ipynb
@@ -0,0 +1,852 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "(Scan_tutorial)=\n",
+ "# Introduction to Scan\n",
+ ":::{post} Jan 11, 2025 \n",
+ ":tags: scan, worked examples, tutorial\n",
+ ":category: beginner, explanation \n",
+ ":author: Pascal Lamblin, Jesse Grabowski\n",
+ ":::\n",
+ "\n",
+ "A Pytensor function graph is composed of two types of nodes: Variable nodes which represent data, and Apply node which apply Ops (which represent some computation) to Variables to produce new Variables.\n",
+ "\n",
+ "From this point of view, a node that applies a Scan Op is just like any other. Internally, however, it is very different from most Ops.\n",
+ "\n",
+ "Inside a Scan op is yet another Pytensor graph which represents the computation to be performed at every iteration of the loop. During compilation, that graph is compiled into a function. During execution, the Scan Op will call that function repeatedly on its inputs to produce its outputs.\n",
+ "\n",
+ "## Examples\n",
+ "\n",
+ "Scan's interface is complex and, thus, best introduced by examples. \n"
+ ]
+ },
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": [
+ "### Example 1: As Simple as it Gets\n",
+ "So, let's dive right in and start with a simple example; perform an element-wise multiplication between two vectors. \n",
+ "\n",
+ "This particular example is simple enough that Scan is not the best way to do things but we'll gradually work our way to more complex examples where Scan gets more interesting.\n",
+ "\n",
+ "Let's first setup our use case by defining Pytensor variables for the inputs :"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "collapsed": true,
+ "ExecuteTime": {
+ "end_time": "2025-01-10T17:39:58.951346Z",
+ "start_time": "2025-01-10T17:39:53.088554Z"
+ }
+ },
+ "source": [
+ "import pytensor\n",
+ "import pytensor.tensor as pt\n",
+ "import numpy as np\n",
+ "\n",
+ "vector1 = pt.dvector('vector1')\n",
+ "vector2 = pt.dvector('vector2')"
+ ],
+ "outputs": [],
+ "execution_count": 1
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Next, we call the `scan` function. It has many parameters but, because our use case is simple, we only need two of them. We'll introduce other parameters in the next examples.\n",
+ "\n",
+ "The parameter `sequences` allows us to specify variables that Scan should iterate over as it loops. The first iteration will take as input the first element of every sequence, the second iteration will take as input the second element of every sequence, etc. These individual element have will have one less dimension than the original sequences. For example, for a matrix sequence, the individual elements will be vectors.\n",
+ "\n",
+ "The parameter `fn` receives a function or lambda expression that expresses the computation to do at every iteration. It operates on the symbolic inputs to produce symbolic outputs. It will **only ever be called once**, to assemble the Pytensor graph used by Scan at every the iterations.\n",
+ "\n",
+ "Since we wish to iterate over both `vector1` and `vector2` simultaneously, we provide them as sequences. This means that every iteration will operate on two inputs: an element from `vector1` and the corresponding element from `vector2`. \n",
+ "\n",
+ "Because what we want is the elementwise product between the vectors, we provide a lambda expression that takes an element `a` from `vector1` and an element `b` from `vector2` then computes and return the product."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "collapsed": true,
+ "ExecuteTime": {
+ "end_time": "2025-01-10T17:39:59.004407Z",
+ "start_time": "2025-01-10T17:39:58.955818Z"
+ }
+ },
+ "source": [
+ "output, updates = pytensor.scan(fn=lambda a, b : a * b,\n",
+ " sequences=[vector1, vector2])"
+ ],
+ "outputs": [],
+ "execution_count": 2
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Calling `scan`, we see that it returns two outputs.\n",
+ "\n",
+ "The first output contains the outputs of `fn` from every timestep concatenated into a tensor. In our case, the output of a single timestep is a scalar so output is a vector where `output[i]` is the output of the i-th iteration.\n",
+ "\n",
+ "The second output details if and how the execution of the `Scan` updates any shared variable in the graph. It should be provided as an argument when compiling the Pytensor function."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "collapsed": false,
+ "scrolled": true,
+ "ExecuteTime": {
+ "end_time": "2025-01-10T17:40:00.081533Z",
+ "start_time": "2025-01-10T17:39:59.741663Z"
+ }
+ },
+ "source": [
+ "f = pytensor.function(inputs=[vector1, vector2],\n",
+ " outputs=output,\n",
+ " updates=updates)"
+ ],
+ "outputs": [],
+ "execution_count": 3
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "If `updates` is omitted, the state of any shared variables modified by `Scan` will not be updated properly. Random number sampling, for instance, relies on shared variables. If `updates` is not provided, the state of the random number generator won't be updated properly and the same numbers might be sampled repeatedly. **Always** provide `updates` when compiling your Pytensor function, unless you are sure that you don't need it!\n",
+ "\n",
+ "Now that we've defined how to do elementwise multiplication with Scan, we can see that the result is as expected :"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "collapsed": false,
+ "ExecuteTime": {
+ "end_time": "2025-01-10T17:40:00.128785Z",
+ "start_time": "2025-01-10T17:40:00.125260Z"
+ }
+ },
+ "source": [
+ "floatX = pytensor.config.floatX\n",
+ "\n",
+ "vector1_value = np.arange(0, 5).astype(floatX) # [0,1,2,3,4]\n",
+ "vector2_value = np.arange(1, 6).astype(floatX) # [1,2,3,4,5]\n",
+ "print(f(vector1_value, vector2_value))"
+ ],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[ 0. 2. 6. 12. 20.]\n"
+ ]
+ }
+ ],
+ "execution_count": 4
+ },
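+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "To make the `updates` warning above concrete, here is a minimal sketch of a Scan that samples random numbers (assuming the `RandomStream` helper from `pytensor.tensor.random.utils`). Passing `updates` to `pytensor.function` is what advances the generator state, so repeated calls return fresh draws; without it, every call would return the same numbers."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {},
+ "source": [
+ "from pytensor.tensor.random.utils import RandomStream\n",
+ "\n",
+ "srng = RandomStream(seed=123)\n",
+ "\n",
+ "# Each iteration draws one standard normal sample\n",
+ "samples, rng_updates = pytensor.scan(fn=lambda: srng.normal(), n_steps=3)\n",
+ "\n",
+ "sample_fn = pytensor.function(inputs=[], outputs=samples, updates=rng_updates)\n",
+ "print(sample_fn()) # three draws\n",
+ "print(sample_fn()) # three new draws, because the rng state was updated"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },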
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": "An interesting thing is that we never explicitly told Scan how many iteration it needed to run. It was automatically inferred; when given sequences, Scan will run as many iterations as the length of the shortest sequence. Here we just truncate one of the sequences to 4 elements, and we get only 4 outputs."
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "collapsed": false,
+ "ExecuteTime": {
+ "end_time": "2025-01-10T17:40:00.199150Z",
+ "start_time": "2025-01-10T17:40:00.195450Z"
+ }
+ },
+ "source": [
+ "print(f(vector1_value, vector2_value[:4]))"
+ ],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[ 0. 2. 6. 12.]\n"
+ ]
+ }
+ ],
+ "execution_count": 5
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Example 2: Non-sequences\n",
+ "\n",
+ "In this example, we introduce another of Scan's features; non-sequences. To demonstrate how to use them, we use Scan to compute the activations of a linear MLP layer over a minibatch.\n",
+ "\n",
+ "It is not yet a use case where Scan is truly useful but it introduces a requirement that sequences cannot fulfill; if we want to use Scan to iterate over the minibatch elements and compute the activations for each of them, then we need some variables (the parameters of the layer), to be available 'as is' at every iteration of the loop. We do *not* want Scan to iterate over them and give only part of them at every iteration.\n",
+ "\n",
+ "Once again, we begin by setting up our Pytensor variables :"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "collapsed": true,
+ "ExecuteTime": {
+ "end_time": "2025-01-10T17:40:00.263086Z",
+ "start_time": "2025-01-10T17:40:00.259308Z"
+ }
+ },
+ "source": [
+ "X = pt.dmatrix('X') # Minibatch of data\n",
+ "W = pt.dmatrix('W') # Weights of the layer\n",
+ "b = pt.dvector('b') # Biases of the layer"
+ ],
+ "outputs": [],
+ "execution_count": 6
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "For the sake of variety, in this example we define the computation to be done at every iteration of the loop using a Python function, `step()`, instead of a lambda expression.\n",
+ "\n",
+ "To have the full weight matrix W and the full bias vector b available at every iteration, we use the argument `non_sequences`. Contrary to `sequences`, `non_sequences` are not iterated upon by Scan. Every non-sequence is passed as input to every iteration.\n",
+ "\n",
+ "This means that our `step()` function will need to operate on three symbolic inputs; one for our sequence X and one for each of our non-sequences W and b. \n",
+ "\n",
+ "The inputs that correspond to the non-sequences are **always** last and in the same order at the non-sequences are provided to Scan. This means that the correspondence between the inputs of the `step()` function and the arguments to `scan()` is the following : \n",
+ "\n",
+ "* `v` : individual element of the sequence `X` \n",
+ "* `W` and `b` : non-sequences `W` and `b`, respectively"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "collapsed": false,
+ "ExecuteTime": {
+ "end_time": "2025-01-10T17:40:00.366395Z",
+ "start_time": "2025-01-10T17:40:00.316085Z"
+ }
+ },
+ "source": [
+ "def step(v, W, b):\n",
+ " return v @ W + b\n",
+ "\n",
+ "output, updates = pytensor.scan(fn=step,\n",
+ " sequences=[X],\n",
+ " non_sequences=[W, b])\n",
+ "print(updates)"
+ ],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{}\n"
+ ]
+ }
+ ],
+ "execution_count": 7
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": "We can now compile our Pytensor function and see that it gives the expected results."
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "collapsed": false,
+ "ExecuteTime": {
+ "end_time": "2025-01-10T17:40:00.666677Z",
+ "start_time": "2025-01-10T17:40:00.403399Z"
+ }
+ },
+ "source": [
+ "f = pytensor.function(inputs=[X, W, b],\n",
+ " outputs=output,\n",
+ " updates=updates)\n",
+ "\n",
+ "X_value = np.arange(-3, 3).reshape(3, 2).astype(floatX)\n",
+ "W_value = np.eye(2).astype(floatX)\n",
+ "b_value = np.arange(2).astype(floatX)\n",
+ "\n",
+ "print(f(X_value, W_value, b_value))"
+ ],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[[-3. -1.]\n",
+ " [-1. 1.]\n",
+ " [ 1. 3.]]\n"
+ ]
+ }
+ ],
+ "execution_count": 8
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Example 3 : Reusing outputs from the previous iterations\n",
+ "\n",
+ "In this example, we will use Scan to compute a cumulative sum over the first dimension of a matrix $M$. This means that the output will be a matrix $S$ in which the first row will be equal to the first row of $M$, the second row will be equal to the sum of the two first rows of $M$, and so on.\n",
+ "\n",
+ "Another way to express this, which is the way we will implement here, is that $S_t = S_{t-1} + M_t$. Implementing this with Scan would involve iterating over the rows of the matrix $M$ and, at every iteration, reuse the cumulative row that was output at the previous iteration and return the sum of it and the current row of $M$.\n",
+ "\n",
+ "If we assume for a moment that we can get Scan to provide the output value from the previous iteration as an input for every iteration, implementing a step function is simple :"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "collapsed": false,
+ "ExecuteTime": {
+ "end_time": "2025-01-10T17:40:00.698967Z",
+ "start_time": "2025-01-10T17:40:00.695951Z"
+ }
+ },
+ "source": [
+ "def step(m_row, cumulative_sum):\n",
+ " return m_row + cumulative_sum"
+ ],
+ "outputs": [],
+ "execution_count": 9
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The trick part is informing Scan that our step function expects as input the output of a previous iteration. To achieve this, we need to use a new parameter of the `scan()` function: `outputs_info`. This parameter is used to tell Scan how we intend to use each of the outputs that are computed at each iteration.\n",
+ "\n",
+ "This parameter can be omitted (like we did so far) when the step function doesn't depend on any output of a previous iteration. However, now that we wish to have recurrent outputs, we need to start using it.\n",
+ "\n",
+ "`outputs_info` takes a sequence with one element for every output of the `step()` function :\n",
+ "* For a **non-recurrent output** (like in every example before this one), the element should be `None`.\n",
+ "* For a **simple recurrent output** (iteration $t$ depends on the value at iteration $t-1$), the element must be a tensor. Scan will interpret it as being an initial state for a recurrent output and give it as input to the first iteration, pretending it is the output value from a previous iteration. For subsequent iterations, Scan will automatically handle giving the previous output value as an input.\n",
+ "\n",
+ "The `step()` function needs to expect one additional input for each simple recurrent output. These inputs correspond to outputs from previous iteration and are **always** after the inputs that correspond to sequences but before those that correspond to non-sequences. The are received by the `step()` function in the order in which the recurrent outputs are declared in the outputs_info sequence."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "collapsed": true,
+ "ExecuteTime": {
+ "end_time": "2025-01-10T17:40:00.767156Z",
+ "start_time": "2025-01-10T17:40:00.740203Z"
+ }
+ },
+ "source": [
+ "M = pt.dmatrix('X')\n",
+ "s = pt.dvector('s') # Initial value for the cumulative sum\n",
+ "\n",
+ "output, updates = pytensor.scan(fn=step,\n",
+ " sequences=[M],\n",
+ " outputs_info=[s])"
+ ],
+ "outputs": [],
+ "execution_count": 10
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": "We can now compile and test the Pytensor function :"
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "collapsed": false,
+ "ExecuteTime": {
+ "end_time": "2025-01-10T17:40:00.933590Z",
+ "start_time": "2025-01-10T17:40:00.814705Z"
+ }
+ },
+ "source": [
+ "f = pytensor.function(inputs=[M, s],\n",
+ " outputs=output,\n",
+ " updates=updates)\n",
+ "\n",
+ "M_value = np.arange(9).reshape(3, 3).astype(floatX)\n",
+ "s_value = np.zeros((3, ), dtype=floatX)\n",
+ "\n",
+ "print(f(M_value, s_value))"
+ ],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[[ 0. 1. 2.]\n",
+ " [ 3. 5. 7.]\n",
+ " [ 9. 12. 15.]]\n"
+ ]
+ }
+ ],
+ "execution_count": 11
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "An important thing to notice here, is that the output computed by the Scan does **not** include the initial state that we provided. It only outputs the states that it has computed itself.\n",
+ "\n",
+ "If we want to have both the initial state and the computed states in the same Pytensor variable, we have to join them ourselves."
+ ]
+ },
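+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "For example, we can prepend the initial state ourselves by adding a leading axis to `s` and concatenating it with the Scan output:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {},
+ "source": [
+ "# Stack the initial state s on top of the states computed by Scan\n",
+ "full_output = pt.concatenate([s[None, :], output], axis=0)\n",
+ "\n",
+ "f_with_init = pytensor.function(inputs=[M, s],\n",
+ " outputs=full_output,\n",
+ " updates=updates)\n",
+ "print(f_with_init(M_value, s_value))"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },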
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Example 4 : Reusing outputs from multiple past iterations\n",
+ "\n",
+ "The Fibonacci sequence is a sequence of numbers F where the two first numbers both 1 and every subsequence number is defined as such : $F_n = F_{n-1} + F_{n-2}$. Thus, the Fibonacci sequence goes : 1, 1, 2, 3, 5, 8, 13, ...\n",
+ "\n",
+ "In this example, we will cover how to compute part of the Fibonacci sequence using Scan. Most of the tools required to achieve this have been introduced in the previous examples. The only one missing is the ability to use, at iteration $i$, outputs from iterations older than $i-1$.\n",
+ "\n",
+ "Also, since every example so far had only one output at every iteration of the loop, we will also compute, at each timestep, the ratio between the new term of the Fibonacci sequence and the previous term.\n",
+ "\n",
+ "Writing an appropriate step function given two inputs, representing the two previous terms of the Fibonacci sequence, is easy:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "collapsed": true,
+ "ExecuteTime": {
+ "end_time": "2025-01-10T17:40:00.960658Z",
+ "start_time": "2025-01-10T17:40:00.956657Z"
+ }
+ },
+ "source": [
+ "def step(f_minus2, f_minus1):\n",
+ " new_f = f_minus2 + f_minus1\n",
+ " ratio = new_f / f_minus1\n",
+ " return new_f, ratio"
+ ],
+ "outputs": [],
+ "execution_count": 12
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The next step is defining the value of `outputs_info`.\n",
+ "\n",
+ "Recall that, for **non-recurrent outputs**, the value is `None` and, for **simple recurrent outputs**, the value is a single initial state. For **general recurrent outputs**, where iteration $t$ may depend on multiple past values, the value is a dictionary. That dictionary has two values:\n",
+ "* taps : list declaring which previous values of that output every iteration will need. `[-3, -2, -1]` would mean every iteration should take as input the last 3 values of that output. `[-2]` would mean every iteration should take as input the value of that output from two iterations ago.\n",
+ "* initial : tensor of initial values. If every initial value has $n$ dimensions, `initial` will be a single tensor of $n+1$ dimensions with as many initial values as the oldest requested tap. In the case of the Fibonacci sequence, the individual initial values are scalars so the `initial` will be a vector. \n",
+ "\n",
+ "In our example, we have two outputs. The first output is the next computed term of the Fibonacci sequence so every iteration should take as input the two last values of that output. The second output is the ratio between successive terms and we don't reuse its value so this output is non-recurrent. We define the value of `outputs_info` as such :"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "collapsed": false,
+ "ExecuteTime": {
+ "end_time": "2025-01-10T17:40:01.023497Z",
+ "start_time": "2025-01-10T17:40:01.019867Z"
+ }
+ },
+ "source": [
+ "f_init = pt.fvector()\n",
+ "outputs_info = [dict(initial=f_init, taps=[-2, -1]),\n",
+ " None]"
+ ],
+ "outputs": [],
+ "execution_count": 13
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Now that we've defined the step function and the properties of our outputs, we can call the `scan()` function. Because the `step()` function has multiple outputs, the first output of `scan()` function will be a list of tensors: the first tensor containing all the states of the first output and the second tensor containing all the states of the second input.\n",
+ "\n",
+ "In every previous example, we used sequences and Scan automatically inferred the number of iterations it needed to run from the length of these\n",
+ "sequences. Now that we have no sequence, we need to explicitly tell Scan how many iterations to run using the `n_step` parameter. The value can be real or symbolic."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "collapsed": false,
+ "ExecuteTime": {
+ "end_time": "2025-01-10T17:40:01.080129Z",
+ "start_time": "2025-01-10T17:40:01.069348Z"
+ }
+ },
+ "source": [
+ "output, updates = pytensor.scan(fn=step,\n",
+ " outputs_info=outputs_info,\n",
+ " n_steps=10)\n",
+ "\n",
+ "next_fibonacci_terms = output[0]\n",
+ "ratios_between_terms = output[1]"
+ ],
+ "outputs": [],
+ "execution_count": 14
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": "Let's compile our Pytensor function which will take a vector of consecutive values from the Fibonacci sequence and compute the next 10 values :"
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "collapsed": false,
+ "ExecuteTime": {
+ "end_time": "2025-01-10T17:40:01.254196Z",
+ "start_time": "2025-01-10T17:40:01.134565Z"
+ }
+ },
+ "source": [
+ "f = pytensor.function(inputs=[f_init],\n",
+ " outputs=[next_fibonacci_terms, ratios_between_terms],\n",
+ " updates=updates)\n",
+ "\n",
+ "out = f([1, 1])\n",
+ "print(out[0])\n",
+ "print(out[1])"
+ ],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[ 2. 3. 5. 8. 13. 21. 34. 55. 89. 144.]\n",
+ "[2. 1.5 1.6666666 1.6 1.625 1.6153846 1.6190476\n",
+ " 1.617647 1.6181818 1.6179775]\n"
+ ]
+ }
+ ],
+ "execution_count": 15
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n",
+ "## Order of inputs \n",
+ "\n",
+ "When we start using many sequences, recurrent outputs and non-sequences, it's easy to get confused regarding the order in which the step function receives the corresponding inputs. Below is the full order:\n",
+ "\n",
+ "* Element from the first sequence\n",
+ "* ...\n",
+ "* Element from the last sequence\n",
+ "* First requested tap from first recurrent output\n",
+ "* ...\n",
+ "* Last requested tap from first recurrent output\n",
+ "* ...\n",
+ "* First requested tap from last recurrent output\n",
+ "* ...\n",
+ "* Last requested tap from last recurrent output\n",
+ "* First non-sequence\n",
+ "* ...\n",
+ "* Last non-sequence"
+ ]
+ },
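+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": "As a concrete illustration (a sketch with made-up names, not part of the original tutorial), the step function below receives its arguments in exactly the order listed above: two sequence elements, then the two requested taps of the recurrent output (oldest first), then the non-sequence:"
+  },
+  {
+   "cell_type": "code",
+   "metadata": {},
+   "source": [
+    "seq1 = pt.dvector(\"seq1\")\n",
+    "seq2 = pt.dvector(\"seq2\")\n",
+    "init = pt.dvector(\"init\")  # two initial values, one per tap in [-2, -1]\n",
+    "w = pt.dscalar(\"w\")\n",
+    "\n",
+    "def demo_step(s1_t, s2_t, out_tm2, out_tm1, w_t):\n",
+    "    # sequences first, then recurrent-output taps (oldest first), then non-sequences\n",
+    "    return s1_t + s2_t + out_tm2 + out_tm1 + w_t\n",
+    "\n",
+    "demo_output, demo_updates = pytensor.scan(fn=demo_step,\n",
+    "                                          sequences=[seq1, seq2],\n",
+    "                                          outputs_info=[dict(initial=init, taps=[-2, -1])],\n",
+    "                                          non_sequences=[w])"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },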
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## When to use Scan \n",
+ "\n",
+ "Scan is not appropriate for every problem. Here's some information to help you figure out if Scan is the best solution for a given use case.\n",
+ "\n",
+ "### Execution speed\n",
+ "\n",
+ "Using Scan in a Pytensor function typically makes it slightly slower compared to the equivalent Pytensor graph in which the loop is unrolled. Both of these approaches tend to be much slower than a vectorized implementation in which large chunks of the computation can be done in parallel.\n",
+ "\n",
+ "### Compilation speed\n",
+ "\n",
+ "Scan also adds an overhead to the compilation, potentially making it slower, but using it can also dramatically reduce the size of your graph, making compilation much faster. In the end, the effect of Scan on compilation speed will heavily depend on the size of the graph with and without Scan.\n",
+ "\n",
+ "The compilation speed of a Pytensor function using Scan will usually be comparable to one in which the loop is unrolled if the number of iterations is small. It the number of iterations is large, however, the compilation will usually be much faster with Scan.\n",
+ "\n",
+ "### In summary\n",
+ "\n",
+ "If you have one of the following cases, Scan can help :\n",
+ "* A vectorized implementation is not possible (due to the nature of the computation and/or memory usage)\n",
+ "* You want to do a large or variable number of iterations\n",
+ "\n",
+ "If you have one of the following cases, you should consider other options :\n",
+ "* A vectorized implementation could perform the same computation => Use the vectorized approach. It will often be faster during both compilation and execution.\n",
+ "* You want to do a small, fixed, number of iterations (ex: 2 or 3) => It's probably better to simply unroll the computation"
+ ]
+ },
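+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": "For instance (a sketch, not a benchmark, and not part of the original tutorial), a cumulative sum can be written either with Scan or with the vectorized `pt.cumsum`; the vectorized form is usually preferable when it exists:"
+  },
+  {
+   "cell_type": "code",
+   "metadata": {},
+   "source": [
+    "v = pt.dvector(\"v\")\n",
+    "\n",
+    "# Scan version: carry a running sum as a recurrent output\n",
+    "csum_scan, csum_updates = pytensor.scan(fn=lambda v_t, acc: acc + v_t,\n",
+    "                                        sequences=v,\n",
+    "                                        outputs_info=pt.zeros((), dtype=v.dtype))\n",
+    "\n",
+    "# Vectorized version: usually faster to compile and to run\n",
+    "csum_vec = pt.cumsum(v)"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },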
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Exercises\n",
+ "\n",
+ "### Exercise 1 - Computing a polynomial\n",
+ "\n",
+ "In this exercise, the initial version already works. It computes the value of a polynomial ($n_0 + n_1 x + n_2 x^2 + ... $) of at most 10000 degrees given the coefficients of the various terms and the value of x.\n",
+ "\n",
+ "You must modify it such that the reduction (the sum() call) is done by Scan."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "collapsed": false,
+ "ExecuteTime": {
+ "end_time": "2025-01-10T17:40:01.466495Z",
+ "start_time": "2025-01-10T17:40:01.288716Z"
+ }
+ },
+ "source": [
+ "coefficients = pt.dvector(\"coefficients\")\n",
+ "x = pt.dscalar(\"x\")\n",
+ "max_coefficients_supported = 10000\n",
+ "\n",
+ "def step(coeff, power, free_var):\n",
+ " return coeff * free_var ** power\n",
+ "\n",
+ "# Generate the components of the polynomial\n",
+ "full_range = pt.arange(max_coefficients_supported)\n",
+ "components, updates = pytensor.scan(fn=step,\n",
+ " outputs_info=None,\n",
+ " sequences=[coefficients, full_range],\n",
+ " non_sequences=x)\n",
+ "\n",
+ "polynomial = components.sum()\n",
+ "calculate_polynomial = pytensor.function(inputs=[coefficients, x],\n",
+ " outputs=polynomial,\n",
+ " updates=updates)\n",
+ "\n",
+ "test_coeff = np.asarray([1, 0, 2], dtype=floatX)\n",
+ "print(calculate_polynomial(test_coeff, 3))\n",
+ "# 19.0"
+ ],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "19.0\n"
+ ]
+ }
+ ],
+ "execution_count": 16
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "**Solution** : run the cell below to display the solution to this exercise."
+ ]
+ },
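+  {
+   "cell_type": "code",
+   "metadata": {},
+   "source": [
+    "# One possible solution (a sketch): carry a running sum as a recurrent\n",
+    "# output so that the reduction itself happens inside Scan.\n",
+    "def step_sum(coeff, power, prior_sum, free_var):\n",
+    "    return prior_sum + coeff * free_var ** power\n",
+    "\n",
+    "result, updates_sum = pytensor.scan(fn=step_sum,\n",
+    "                                    sequences=[coefficients, full_range],\n",
+    "                                    outputs_info=pt.zeros((), dtype=\"float64\"),\n",
+    "                                    non_sequences=x)\n",
+    "\n",
+    "polynomial_scan = result[-1]  # the last accumulated value is the full sum\n",
+    "calculate_polynomial_scan = pytensor.function(inputs=[coefficients, x],\n",
+    "                                              outputs=polynomial_scan,\n",
+    "                                              updates=updates_sum)\n",
+    "print(calculate_polynomial_scan(test_coeff, 3))  # expected: 19.0"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },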
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Exercise 2 - Sampling without replacement\n",
+ "\n",
+ "In this exercise, the goal is to implement a Pytensor function that :\n",
+ "* takes as input a vector of probabilities and a scalar\n",
+ "* performs sampling without replacements from those probabilities as many times as the value of the scalar\n",
+ "* returns a vector containing the indices of the sampled elements.\n",
+ "\n",
+ "Partial code is provided to help with the sampling of random numbers since this is not something that was covered in this tutorial."
+ ]
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2025-01-10T17:40:01.513298Z",
+ "start_time": "2025-01-10T17:40:01.482238Z"
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "rng = pytensor.shared(np.random.default_rng(1234))\n",
+ "p_vec = pt.dvector(\"p_vec\")\n",
+ "next_rng, onehot_sample = pt.random.multinomial(n=1, p=p_vec, rng=rng).owner.outputs\n",
+ "f = pytensor.function([p_vec], onehot_sample, updates={rng:next_rng})"
+ ],
+ "outputs": [],
+ "execution_count": 17
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "collapsed": false,
+ "ExecuteTime": {
+ "end_time": "2025-01-10T17:40:01.703547Z",
+ "start_time": "2025-01-10T17:40:01.536499Z"
+ }
+ },
+ "source": [
+ "def sample_from_pvect(p, rng):\n",
+ " \"\"\" Provided utility function: given a symbolic vector of\n",
+ " probabilities (which MUST sum to 1), sample one element\n",
+ " and return its index.\n",
+ " \"\"\"\n",
+ " next_rng, onehot_sample = pt.random.multinomial(n=1, p=p, rng=rng).owner.outputs\n",
+ " idx = onehot_sample.argmax()\n",
+ " \n",
+ " return idx, {rng: next_rng}\n",
+ "\n",
+ "def set_p_to_zero(p, i):\n",
+ " \"\"\" Provided utility function: given a symbolic vector of\n",
+ " probabilities and an index 'i', set the probability of the\n",
+ " i-th element to 0 and renormalize the probabilities so they\n",
+ " sum to 1.\n",
+ " \"\"\"\n",
+ " new_p = p[i].set(0.)\n",
+ " new_p = new_p / new_p.sum()\n",
+ " return new_p\n",
+ "\n",
+ "def sample(p, rng):\n",
+ " idx, updates = sample_from_pvect(p, rng)\n",
+ " p = set_p_to_zero(p, idx)\n",
+ " return (p, idx), updates\n",
+ "\n",
+ "probabilities = pt.dvector()\n",
+ "nb_samples = pt.iscalar()\n",
+ "\n",
+ "SEED = sum(map(ord, 'PyTensor Scan'))\n",
+ "rng = pytensor.shared(np.random.default_rng(SEED))\n",
+ "\n",
+ "\n",
+ "# TODO use Scan to sample from the vector of probabilities and\n",
+ "# symbolically obtain 'samples' the vector of sampled indices.\n",
+ "[probs, samples], updates = pytensor.scan(fn=sample,\n",
+ " outputs_info=[probabilities, None],\n",
+ " non_sequences=[rng],\n",
+ " n_steps=nb_samples)\n",
+ "\n",
+ "# Compiling the function\n",
+ "f = pytensor.function(inputs=[probabilities, nb_samples], outputs=samples, updates=updates)\n",
+ "\n",
+ "# Testing the function\n",
+ "test_probs = np.asarray([0.6, 0.3, 0.1], dtype=floatX)\n",
+ "\n",
+ "for i in range(10):\n",
+ " print(f(test_probs, 2))"
+ ],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[0 1]\n",
+ "[0 1]\n",
+ "[2 1]\n",
+ "[2 0]\n",
+ "[0 1]\n",
+ "[0 1]\n",
+ "[0 1]\n",
+ "[0 1]\n",
+ "[0 1]\n",
+ "[0 1]\n"
+ ]
+ }
+ ],
+ "execution_count": 18
+ },
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": [
+ "## Authors\n",
+ "\n",
+ "- Authored by Pascal Lamblin in Feburary 2016\n",
+ "- Updated by Jesse Grabowski in January 2025"
+ ]
+ },
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": [
+ "## References\n",
+ "\n",
+ ":::{bibliography} :filter: docname in docnames"
+ ]
+ },
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": "## Watermark "
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2025-01-11T07:50:45.845462Z",
+ "start_time": "2025-01-11T07:50:45.809393Z"
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "%load_ext watermark\n",
+ "%watermark -n -u -v -iv -w -p pytensor"
+ ],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "The watermark extension is already loaded. To reload it, use:\n",
+ " %reload_ext watermark\n",
+ "Last updated: Sat Jan 11 2025\n",
+ "\n",
+ "Python implementation: CPython\n",
+ "Python version : 3.12.0\n",
+ "IPython version : 8.31.0\n",
+ "\n",
+ "pytensor: 2.26.4+16.g8be5c5323.dirty\n",
+ "\n",
+ "numpy : 1.26.4\n",
+ "pytensor: 2.26.4+16.g8be5c5323.dirty\n",
+ "sys : 3.12.0 | packaged by conda-forge | (main, Oct 3 2023, 08:43:22) [GCC 12.3.0]\n",
+ "\n",
+ "Watermark: 2.5.0\n",
+ "\n"
+ ]
+ }
+ ],
+ "execution_count": 20
+ },
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": [
+ ":::{include} ../page_footer.md \n",
+ ":::"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 2
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython2",
+ "version": "2.7.10"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/doc/images/PyTensor.png b/doc/images/PyTensor.png
new file mode 100644
index 0000000000..e6097693af
Binary files /dev/null and b/doc/images/PyTensor.png differ
diff --git a/doc/images/PyTensor_logo.png b/doc/images/PyTensor_logo.png
new file mode 100644
index 0000000000..c8947735de
Binary files /dev/null and b/doc/images/PyTensor_logo.png differ
diff --git a/doc/images/binder.svg b/doc/images/binder.svg
new file mode 100644
index 0000000000..327f6b639a
--- /dev/null
+++ b/doc/images/binder.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/doc/images/colab.svg b/doc/images/colab.svg
new file mode 100644
index 0000000000..c08066ee33
--- /dev/null
+++ b/doc/images/colab.svg
@@ -0,0 +1 @@
+
diff --git a/doc/images/github.svg b/doc/images/github.svg
new file mode 100644
index 0000000000..e02d8ed55b
--- /dev/null
+++ b/doc/images/github.svg
@@ -0,0 +1 @@
+View On GitHubView On GitHub
\ No newline at end of file
diff --git a/doc/index.rst b/doc/index.rst
index ac5bc0876c..a70a28df82 100644
--- a/doc/index.rst
+++ b/doc/index.rst
@@ -80,6 +80,7 @@ Community
introduction
user_guide
API
+ Examples
Contributing
.. _Theano: https://github.com/Theano/Theano
diff --git a/doc/library/index.rst b/doc/library/index.rst
index 6a05a5a7bf..08a5b51c34 100644
--- a/doc/library/index.rst
+++ b/doc/library/index.rst
@@ -22,7 +22,6 @@ Modules
gradient
misc/pkl_utils
printing
- sandbox/index
scalar/index
scan
sparse/index
diff --git a/doc/library/misc/pkl_utils.rst b/doc/library/misc/pkl_utils.rst
index 0299d15204..f22e5e8bd7 100644
--- a/doc/library/misc/pkl_utils.rst
+++ b/doc/library/misc/pkl_utils.rst
@@ -9,10 +9,6 @@
from pytensor.misc.pkl_utils import *
-.. autofunction:: pytensor.misc.pkl_utils.dump
-
-.. autofunction:: pytensor.misc.pkl_utils.load
-
.. autoclass:: pytensor.misc.pkl_utils.StripPickler
.. seealso::
diff --git a/doc/library/sandbox/index.rst b/doc/library/sandbox/index.rst
deleted file mode 100644
index b4012cd9df..0000000000
--- a/doc/library/sandbox/index.rst
+++ /dev/null
@@ -1,16 +0,0 @@
-
-.. _libdoc_sandbox:
-
-==============================================================
-:mod:`sandbox` -- Experimental Code
-==============================================================
-
-.. module:: sandbox
- :platform: Unix, Windows
- :synopsis: Experimental code
-.. moduleauthor:: LISA
-
-.. toctree::
- :maxdepth: 1
-
- linalg
diff --git a/doc/library/sandbox/linalg.rst b/doc/library/sandbox/linalg.rst
deleted file mode 100644
index 9ee5fe9f51..0000000000
--- a/doc/library/sandbox/linalg.rst
+++ /dev/null
@@ -1,19 +0,0 @@
-.. ../../../../pytensor/sandbox/linalg/ops.py
-.. ../../../../pytensor/sandbox/linalg
-
-.. _libdoc_sandbox_linalg:
-
-===================================================================
-:mod:`sandbox.linalg` -- Linear Algebra Ops
-===================================================================
-
-.. module:: sandbox.linalg
- :platform: Unix, Windows
- :synopsis: Linear Algebra Ops
-.. moduleauthor:: LISA
-
-API
-===
-
-.. automodule:: pytensor.sandbox.linalg.ops
- :members:
diff --git a/doc/library/tensor/basic.rst b/doc/library/tensor/basic.rst
index 50da46449a..8d22c1e577 100644
--- a/doc/library/tensor/basic.rst
+++ b/doc/library/tensor/basic.rst
@@ -477,7 +477,7 @@ them perfectly, but a `dscalar` otherwise.
you'll want to call.
-.. autoclass:: pytensor.tensor.var._tensor_py_operators
+.. autoclass:: pytensor.tensor.variable._tensor_py_operators
:members:
This mix-in class adds convenient attributes, methods, and support
diff --git a/doc/robots.txt b/doc/robots.txt
new file mode 100644
index 0000000000..73cf5dba3b
--- /dev/null
+++ b/doc/robots.txt
@@ -0,0 +1,3 @@
+User-agent: *
+
+Sitemap: https://pytensor.readthedocs.io/en/latest/sitemap.xml
diff --git a/doc/tutorial/loading_and_saving.rst b/doc/tutorial/loading_and_saving.rst
index dc6eb9b097..d099ecb026 100644
--- a/doc/tutorial/loading_and_saving.rst
+++ b/doc/tutorial/loading_and_saving.rst
@@ -145,7 +145,7 @@ might not have PyTensor installed, who are using a different Python version, or
you are planning to save your model for a long time (in which case version
mismatches might make it difficult to unpickle objects).
-See :func:`pytensor.misc.pkl_utils.dump` and :func:`pytensor.misc.pkl_utils.load`.
+See :meth:`pytensor.misc.pkl_utils.StripPickler.dump` and :meth:`pytensor.misc.pkl_utils.StripPickler.load`.
Long-Term Serialization
diff --git a/environment.yml b/environment.yml
index 4b213fd851..1571ae0d11 100644
--- a/environment.yml
+++ b/environment.yml
@@ -43,6 +43,10 @@ dependencies:
- ipython
- pymc-sphinx-theme
- sphinx-design
+ - myst-nb
+ - matplotlib
+ - watermark
+
# code style
- ruff
# developer tools
diff --git a/pytensor/graph/utils.py b/pytensor/graph/utils.py
index d797504ae6..9c2eef5049 100644
--- a/pytensor/graph/utils.py
+++ b/pytensor/graph/utils.py
@@ -107,8 +107,6 @@ def add_tag_trace(thing: T, user_line: int | None = None) -> T:
"pytensor\\graph\\",
"pytensor/scalar/basic.py",
"pytensor\\scalar\\basic.py",
- "pytensor/sandbox/",
- "pytensor\\sandbox\\",
"pytensor/scan/",
"pytensor\\scan\\",
"pytensor/sparse/",
diff --git a/scripts/generate_gallery.py b/scripts/generate_gallery.py
new file mode 100644
index 0000000000..5cd78d8494
--- /dev/null
+++ b/scripts/generate_gallery.py
@@ -0,0 +1,185 @@
+"""
+Sphinx plugin to generate a gallery for notebooks
+
+Modified from the pymc project, which modified the seaborn project, which modified the mpld3 project.
+"""
+
+import base64
+import json
+import os
+import shutil
+from pathlib import Path
+
+import matplotlib
+
+
+matplotlib.use("Agg")
+import matplotlib.pyplot as plt
+import sphinx
+from matplotlib import image
+
+
+logger = sphinx.util.logging.getLogger(__name__)
+
+DOC_SRC = Path(__file__).resolve().parent.parent
+DEFAULT_IMG_LOC = DOC_SRC / "doc" / "images" / "PyTensor_logo.png"
+
+external_nbs = {}
+
+HEAD = """
+Example Gallery
+===============
+
+.. toctree::
+ :hidden:
+
+"""
+
+SECTION_TEMPLATE = """
+.. _{section_id}:
+
+{section_title}
+{underlines}
+
+.. grid:: 1 2 3 3
+ :gutter: 4
+
+"""
+
+ITEM_TEMPLATE = """
+ .. grid-item-card:: :doc:`{doc_name}`
+ :img-top: {image}
+ :link: {doc_reference}
+ :link-type: {link_type}
+ :shadow: none
+"""
+
+folder_title_map = {
+ "introduction": "Introduction",
+ "rewrites": "Graph Rewriting",
+ "scan": "Looping in Pytensor",
+}
+
+
+def create_thumbnail(infile, width=275, height=275, cx=0.5, cy=0.5, border=4):
+ """Overwrites `infile` with a new file of the given size"""
+ im = image.imread(infile)
+ rows, cols = im.shape[:2]
+ size = min(rows, cols)
+ if size == cols:
+ xslice = slice(0, size)
+        # use cy (vertical center), not cx, when cropping along the row axis
+        ymin = min(max(0, int(cy * rows - size // 2)), rows - size)
+ yslice = slice(ymin, ymin + size)
+ else:
+ yslice = slice(0, size)
+ xmin = min(max(0, int(cx * cols - size // 2)), cols - size)
+ xslice = slice(xmin, xmin + size)
+ thumb = im[yslice, xslice]
+ thumb[:border, :, :3] = thumb[-border:, :, :3] = 0
+ thumb[:, :border, :3] = thumb[:, -border:, :3] = 0
+
+ dpi = 100
+ fig = plt.figure(figsize=(width / dpi, height / dpi), dpi=dpi)
+
+ ax = fig.add_axes([0, 0, 1, 1], aspect="auto", frameon=False, xticks=[], yticks=[])
+ ax.imshow(thumb, aspect="auto", resample=True, interpolation="bilinear")
+ fig.savefig(infile, dpi=dpi)
+ plt.close(fig)
+ return fig
+
+
+class NotebookGenerator:
+ """Tools for generating an example page from a file"""
+
+ def __init__(self, filename, root_dir, folder):
+ self.folder = folder
+
+ self.basename = Path(filename).name
+ self.stripped_name = Path(filename).stem
+ self.image_dir = Path(root_dir) / "doc" / "_thumbnails" / folder
+ self.png_path = self.image_dir / f"{self.stripped_name}.png"
+
+ with filename.open(encoding="utf-8") as fid:
+ self.json_source = json.load(fid)
+ self.default_image_loc = DEFAULT_IMG_LOC
+
+ def extract_preview_pic(self):
+ """By default, just uses the last image in the notebook."""
+ pic = None
+ for cell in self.json_source["cells"]:
+ for output in cell.get("outputs", []):
+ if "image/png" in output.get("data", []):
+ pic = output["data"]["image/png"]
+ if pic is not None:
+ return base64.b64decode(pic)
+ return None
+
+ def gen_previews(self):
+ preview = self.extract_preview_pic()
+ if preview is not None:
+ with self.png_path.open("wb") as buff:
+ buff.write(preview)
+ else:
+ logger.warning(
+ f"Didn't find any pictures in {self.basename}",
+ type="thumbnail_extractor",
+ )
+ shutil.copy(self.default_image_loc, self.png_path)
+ create_thumbnail(self.png_path)
+
+
+def main(app):
+ logger.info("Starting thumbnail extractor.")
+
+ working_dir = Path.cwd()
+ os.chdir(app.builder.srcdir)
+
+ file = [HEAD]
+
+ for folder, title in folder_title_map.items():
+ file.append(
+ SECTION_TEMPLATE.format(
+ section_title=title, section_id=folder, underlines="-" * len(title)
+ )
+ )
+
+ thumbnail_dir = Path("_thumbnails") / folder
+        thumbnail_dir.mkdir(parents=True, exist_ok=True)
+
+        if folder in external_nbs:
+ file += [
+ ITEM_TEMPLATE.format(
+ doc_name=descr["doc_name"],
+ image=descr["image"],
+ doc_reference=descr["doc_reference"],
+ link_type=descr["link_type"],
+ )
+ for descr in external_nbs[folder]
+ ]
+
+ nb_paths = sorted(Path("gallery", folder).glob("*.ipynb"))
+
+ for nb_path in nb_paths:
+ nbg = NotebookGenerator(
+ filename=nb_path, root_dir=Path(".."), folder=folder
+ )
+ nbg.gen_previews()
+
+ file.append(
+ ITEM_TEMPLATE.format(
+ doc_name=Path(folder) / nbg.stripped_name,
+ image="/" + str(nbg.png_path),
+ doc_reference=Path(folder) / nbg.stripped_name,
+ link_type="doc",
+ )
+ )
+
+ with Path("gallery", "gallery.rst").open("w", encoding="utf-8") as f:
+ f.write("\n".join(file))
+
+ os.chdir(working_dir)
+
+
+def setup(app):
+ app.connect("builder-inited", main)