Add initial support for PyTorch backend #764

Merged · 41 commits · merged Jun 20, 2024
Changes from 12 commits

Commits
27e2526
Add pytorch support for some basic Ops
HarshvirSandhu May 13, 2024
629d00b
update variable names, docstrings
HarshvirSandhu May 13, 2024
3eceb56
Avoid numpy conversion of torch Tensors
HarshvirSandhu May 17, 2024
3cde964
Fix typify and CheckAndRaise
HarshvirSandhu May 17, 2024
c003aa5
Fix Elemwise Ops
HarshvirSandhu May 17, 2024
8dc406e
Fix Scalar Ops
HarshvirSandhu May 17, 2024
a8f6ddb
Fix ruff-format
HarshvirSandhu May 17, 2024
9d535f5
Initial setup for pytorch tests
HarshvirSandhu May 23, 2024
c5600da
Fix mode parameters for pytorch
HarshvirSandhu May 23, 2024
54b6248
Prevent conversion of scalars to numpy
HarshvirSandhu May 23, 2024
19454b3
Update TensorConstantSignature and map dtypes to Tensor types
HarshvirSandhu May 23, 2024
92d7114
Add tests for basic ops
HarshvirSandhu May 23, 2024
5aae0e5
Remove torch from user facing API
HarshvirSandhu May 29, 2024
8c174dd
Add function to convert numpy arrays to pytorch tensors
HarshvirSandhu May 29, 2024
0977c3a
Avoid copy when converting to tensor
HarshvirSandhu May 29, 2024
1c23825
Fix tests
HarshvirSandhu May 29, 2024
c9195a8
Remove dispatches that are not tested
HarshvirSandhu May 31, 2024
b07805c
set path for pytorch tests
HarshvirSandhu May 31, 2024
9e8d3fc
Remove tensorflow probability from yml
HarshvirSandhu Jun 4, 2024
a2d3afa
Add checks for runtime broadcasting
HarshvirSandhu Jun 4, 2024
a577a80
Remove IfElse
HarshvirSandhu Jun 4, 2024
499a174
Remove dev notebook
HarshvirSandhu Jun 12, 2024
2826613
Fix check and raise
HarshvirSandhu Jun 12, 2024
62ffcec
Fix compare_pytorch_and_py
HarshvirSandhu Jun 12, 2024
acdbba1
Fix DimShuffle
HarshvirSandhu Jun 12, 2024
2519c65
Add tests for Elemwise operations
HarshvirSandhu Jun 12, 2024
eb6d5c2
Fix test for CheckAndRaise
HarshvirSandhu Jun 14, 2024
9f02a4f
Remove duplicate function
HarshvirSandhu Jun 14, 2024
caf2965
Remove device from pytorch_typify
HarshvirSandhu Jun 15, 2024
bf87eb9
Merge branch 'main' of https://github.com/HarshvirSandhu/pytensor int…
HarshvirSandhu Jun 15, 2024
2c27683
Solve merge conflict
HarshvirSandhu Jun 15, 2024
c603c6b
Use micromamba for pytorch install
HarshvirSandhu Jun 15, 2024
3f17107
Fix pytorch linker
HarshvirSandhu Jun 16, 2024
e850d8d
Fix typify and deepcopy
HarshvirSandhu Jun 16, 2024
e682fc4
Parametrize device in all tests
HarshvirSandhu Jun 16, 2024
bf4cf92
Install torch with cuda
HarshvirSandhu Jun 16, 2024
899e7f9
Fix test_pytorch_FunctionGraph_once
HarshvirSandhu Jun 16, 2024
04d2935
Remove device argument from test
HarshvirSandhu Jun 16, 2024
8ec7661
remove device from elemwise tests and add assertions
HarshvirSandhu Jun 17, 2024
bb7df41
skip tests if cuda is not available
HarshvirSandhu Jun 17, 2024
0441cf2
Fix tests
HarshvirSandhu Jun 18, 2024
11 changes: 10 additions & 1 deletion .github/workflows/test.yml
@@ -76,6 +76,7 @@ jobs:
float32: [0,1]
install-numba: [0]
install-jax: [0]
install-torch: [0]
part:
- "tests --ignore=tests/tensor --ignore=tests/scan --ignore=tests/sparse"
- "tests/scan"
@@ -116,6 +117,11 @@ jobs:
fast-compile: 0
float32: 0
part: "tests/link/jax"
- install-torch: 1
python-version: "3.10"
fast-compile: 0
float32: 0
# part: "tests/link/pytorch"
steps:
- uses: actions/checkout@v4
with:
@@ -143,6 +149,7 @@ jobs:
mamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock sympy
if [[ $INSTALL_NUMBA == "1" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numba>=0.57"; fi
if [[ $INSTALL_JAX == "1" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib numpyro && pip install tensorflow-probability; fi
if [[ $INSTALL_TORCH == "1" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" pytorch && pip install tensorflow-probability; fi
pip install -e ./
mamba list && pip freeze
python -c 'import pytensor; print(pytensor.config.__str__(print_doc=False))'
@@ -151,6 +158,7 @@ jobs:
PYTHON_VERSION: ${{ matrix.python-version }}
INSTALL_NUMBA: ${{ matrix.install-numba }}
INSTALL_JAX: ${{ matrix.install-jax }}
INSTALL_TORCH: ${{ matrix.install-torch}}

- name: Run tests
shell: bash -l {0}
@@ -195,7 +203,7 @@ jobs:
- name: Install dependencies
shell: bash -l {0}
run: |
mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service cython pytest "numba>=0.57" jax jaxlib pytest-benchmark
mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service cython pytest "numba>=0.57" jax jaxlib pytorch pytest-benchmark
pip install -e ./
mamba list && pip freeze
python -c 'import pytensor; print(pytensor.config.__str__(print_doc=False))'
@@ -264,3 +272,4 @@ jobs:
directory: ./coverage/
fail_ci_if_error: true
token: ${{ secrets.CODECOV_TOKEN }}

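A quick local sanity check, once pytorch is installed as in the workflow above, is to resolve the new mode by name. A minimal sketch (get_mode is the existing lookup helper in pytensor.compile.mode; the "PYTORCH" entry is added in the next file):

import torch  # only to confirm pytorch is importable, e.g. after the mamba install above

from pytensor.compile.mode import get_mode

# Resolves the "PYTORCH" entry registered in predefined_modes below.
print(get_mode("PYTORCH"))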
15 changes: 15 additions & 0 deletions pytensor/compile/mode.py
@@ -28,6 +28,7 @@
from pytensor.link.c.basic import CLinker, OpWiseCLinker
from pytensor.link.jax.linker import JAXLinker
from pytensor.link.numba.linker import NumbaLinker
from pytensor.link.pytorch.linker import PytorchLinker
from pytensor.link.vm import VMLinker


@@ -47,6 +48,7 @@
"vm_nogc": VMLinker(allow_gc=False, use_cloop=False),
"cvm_nogc": VMLinker(allow_gc=False, use_cloop=True),
"jax": JAXLinker(),
"pytorch": PytorchLinker(),
"numba": NumbaLinker(),
}

@@ -462,6 +464,18 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
],
),
)
PYTORCH = Mode(
PytorchLinker(),
RewriteDatabaseQuery(
include=["fast_run"],
exclude=[
"cxx_only",
"BlasOpt",
"fusion",
"inplace",
],
),
)
NUMBA = Mode(
NumbaLinker(),
RewriteDatabaseQuery(
@@ -476,6 +490,7 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
"FAST_RUN": FAST_RUN,
"JAX": JAX,
"NUMBA": NUMBA,
"PYTORCH": PYTORCH,
}

instantiated_default_mode = None
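With "PYTORCH" registered in predefined_modes, the backend can be selected by name at compile time. A minimal usage sketch (softmax is chosen because its dispatch is part of this PR; the input data is illustrative):

import pytensor
import pytensor.tensor as pt
from pytensor.tensor.special import softmax

x = pt.matrix("x")
y = softmax(x, axis=1)

# mode="PYTORCH" routes compilation through PytorchLinker with the rewrite
# query defined above (fast_run minus cxx_only/BlasOpt/fusion/inplace).
f = pytensor.function([x], y, mode="PYTORCH")
print(f([[0.0, 1.0], [2.0, 3.0]]))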
19 changes: 19 additions & 0 deletions pytensor/link/pytorch/dispatch/__init__.py
@@ -0,0 +1,19 @@
# isort: off
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify, pytorch_typify

# # Load dispatch specializations
import pytensor.link.pytorch.dispatch.scalar

# import pytensor.link.jax.dispatch.tensor_basic
# import pytensor.link.jax.dispatch.subtensor
# import pytensor.link.jax.dispatch.shape
# import pytensor.link.jax.dispatch.extra_ops
# import pytensor.link.jax.dispatch.nlinalg
# import pytensor.link.jax.dispatch.slinalg
# import pytensor.link.jax.dispatch.random
import pytensor.link.pytorch.dispatch.elemwise
# import pytensor.link.jax.dispatch.scan
# import pytensor.link.jax.dispatch.sparse
# import pytensor.link.jax.dispatch.blockwise

# isort: on
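This module is imported for its side effects: loading it executes every @pytorch_funcify.register(...) decorator in the submodules. A sketch of the intended entry point:

# Importing the package wires up all registrations; the two dispatchers
# are then the only names callers need.
from pytensor.link.pytorch.dispatch import pytorch_funcify, pytorch_typify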
105 changes: 105 additions & 0 deletions pytensor/link/pytorch/dispatch/basic.py
@@ -0,0 +1,105 @@
import warnings
from functools import singledispatch

import torch

from pytensor.compile.ops import DeepCopyOp, ViewOp
from pytensor.graph.fg import FunctionGraph
from pytensor.ifelse import IfElse
from pytensor.link.utils import fgraph_to_python
from pytensor.raise_op import CheckAndRaise


@singledispatch
def pytorch_typify(data, dtype=None, **kwargs):
r"""Convert instances of PyTensor `Type`\s to PyTorch types."""
if dtype is None:
return data
else:
return torch.tensor(data, dtype=dtype)


@pytorch_typify.register(torch.Tensor)
def pytorch_typify_tensor(data, dtype=None, **kwargs):
# if len(data.shape) == 0:
# return data.item()
return torch.tensor(data, dtype=dtype)


@singledispatch
def pytorch_funcify(op, node=None, storage_map=None, **kwargs):
"""Create a PyTorch compatible function from an PyTensor `Op`."""
raise NotImplementedError(f"No PyTorch conversion for the given `Op`: {op}")


@pytorch_funcify.register(FunctionGraph)
def pytorch_funcify_FunctionGraph(
fgraph,
node=None,
fgraph_name="pytorch_funcified_fgraph",
**kwargs,
):
return fgraph_to_python(
fgraph,
pytorch_funcify,
type_conversion_fn=pytorch_typify,
fgraph_name=fgraph_name,
**kwargs,
)


@pytorch_funcify.register(IfElse)
def pytorch_funcify_IfElse(op, **kwargs):
n_outs = op.n_outs

def ifelse(cond, *args, n_outs=n_outs):
res = torch.where(
cond,
args[:n_outs][0],
args[n_outs:][0],
)
return res

return ifelse


@pytorch_funcify.register(CheckAndRaise)
def pytorch_funcify_CheckAndRaise(op, **kwargs):
def assert_fn(x, *conditions):
for cond in conditions:
assert cond.item()
return x

return assert_fn


def pytorch_safe_copy(x):
try:
res = x.clone()
except NotImplementedError:
# warnings.warn(
# "`jnp.copy` is not implemented yet. Using the object's `copy` method."
# )
if hasattr(x, "copy"):
res = torch.tensor(x.copy())
else:
warnings.warn(f"Object has no `copy` method: {x}")
res = x

return res


@pytorch_funcify.register(DeepCopyOp)
def pytorch_funcify_DeepCopyOp(op, **kwargs):
def deepcopyop(x):
return pytorch_safe_copy(x)

return deepcopyop


@pytorch_funcify.register(ViewOp)
def pytorch_funcify_ViewOp(op, **kwargs):
def viewop(x):
return x

return viewop
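Because pytorch_funcify is a functools.singledispatch function, support for further Ops can be layered on without touching this file. A hedged sketch of what a future registration could look like — Alloc is purely illustrative and is not dispatched in this PR:

import torch

from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.tensor.basic import Alloc  # illustrative Op, not covered by this PR


@pytorch_funcify.register(Alloc)
def pytorch_funcify_Alloc(op, **kwargs):
    def alloc(x, *shape):
        # Alloc broadcasts x to a shape that is only known at runtime.
        return torch.broadcast_to(x, tuple(int(s) for s in shape))

    return alloc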
68 changes: 68 additions & 0 deletions pytensor/link/pytorch/dispatch/elemwise.py
@@ -0,0 +1,68 @@
import torch

from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad


@pytorch_funcify.register(Elemwise)
def pytorch_funcify_Elemwise(op, node, **kwargs):
scalar_op = op.scalar_op
base_fn = pytorch_funcify(scalar_op, node=node, **kwargs)

def elemwise_fn(*inputs):
# Elemwise._check_runtime_broadcast(node, tuple(map(torch.tensor, inputs)))
return base_fn(*inputs)

return elemwise_fn


@pytorch_funcify.register(DimShuffle)
def pytorch_funcify_DimShuffle(op, **kwargs):
def dimshuffle(x):
res = torch.transpose(x, *op.transposition)

shape = list(res.shape[: len(op.shuffle)])

for augm in op.augment:
shape.insert(augm, 1)

res = torch.reshape(res, shape)

if not op.inplace:
res = res.clone()

return res

return dimshuffle


@pytorch_funcify.register(Softmax)
def pytorch_funcify_Softmax(op, **kwargs):
axis = op.axis

def softmax(x):
return torch.nn.functional.softmax(x, dim=axis)

return softmax


@pytorch_funcify.register(SoftmaxGrad)
def pytorch_funcify_SoftmaxGrad(op, **kwargs):
axis = op.axis

def softmax_grad(dy, sm):
dy_times_sm = dy * sm
return dy_times_sm - torch.sum(dy_times_sm, dim=axis, keepdims=True) * sm

return softmax_grad


@pytorch_funcify.register(LogSoftmax)
def pytorch_funcify_LogSoftmax(op, **kwargs):
axis = op.axis

def log_softmax(x):
return torch.nn.functional.log_softmax(x, dim=axis)

return log_softmax
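One caveat in this revision: torch.transpose swaps exactly two dimensions, while DimShuffle requires an arbitrary permutation — the later "Fix DimShuffle" commit in the list above addresses this. A small sketch of the distinction:

import torch

x = torch.zeros(2, 3, 4)

# torch.transpose exchanges two dimensions only...
print(torch.transpose(x, 0, 2).shape)     # torch.Size([4, 3, 2])

# ...whereas a full DimShuffle-style reordering needs torch.permute.
print(torch.permute(x, (2, 0, 1)).shape)  # torch.Size([4, 2, 3])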