diff --git a/pytensor/link/pytorch/dispatch/basic.py b/pytensor/link/pytorch/dispatch/basic.py
index 11e1d6c63a..2dcf2dfc36 100644
--- a/pytensor/link/pytorch/dispatch/basic.py
+++ b/pytensor/link/pytorch/dispatch/basic.py
@@ -123,7 +123,10 @@ def arange(start, stop, step):
 def pytorch_funcify_Join(op, **kwargs):
     def join(axis, *tensors):
         # tensors could also be tuples, and in this case they don't have a ndim
-        tensors = [torch.tensor(tensor) for tensor in tensors]
+        tensors = [
+            torch.tensor(tensor) if not torch.is_tensor(tensor) else tensor
+            for tensor in tensors
+        ]
 
         return torch.cat(tensors, dim=axis)
 
diff --git a/pytensor/link/pytorch/linker.py b/pytensor/link/pytorch/linker.py
index d47aa43dda..48675c5a4d 100644
--- a/pytensor/link/pytorch/linker.py
+++ b/pytensor/link/pytorch/linker.py
@@ -51,7 +51,8 @@ class wrapper:
             """
 
             def __init__(self, fn, gen_functors):
-                self.fn = torch.compile(fn)
+                with torch.no_grad():
+                    self.fn = torch.compile(fn)
                 self.gen_functors = gen_functors.copy()
 
             def __call__(self, *inputs, **kwargs):
@@ -62,7 +63,9 @@ def __call__(self, *inputs, **kwargs):
                     setattr(pytensor.link.utils, n[1:], fn)
 
                 # Torch does not accept numpy inputs and may return GPU objects
-                outs = self.fn(*(pytorch_typify(inp) for inp in inputs), **kwargs)
+                with torch.no_grad():
+                    ins = (pytorch_typify(inp) for inp in inputs)
+                    outs = self.fn(*ins, **kwargs)
 
                 # unset attrs
                 for n, _ in self.gen_functors: