
Commit 565a53d

ysiraichi and wjakob authored and committed
Use DLPack for creating tensors out of custom classes, when available. (#138697)
Fixes #120614
Takes over #120615

In summary, this PR:

- Adds a `__dlpack__` attribute check in the tensor creation path (i.e. [`internal_new_from_data` @ tensor_new.cpp](https://github.com/pytorch/pytorch/blob/cdfe1bffd16bdd28adbe5518038f68e6ac45de8d/torch/csrc/utils/tensor_new.cpp#L266))
- Creates the tensor by using the DLPack machinery, instead of an element-by-element copy
  - No changes since #120615
- Adds a test, making sure the DLPack machinery is used
  - Wraps a tensor in a fresh `TensorDLPackWrapper` class that implements only the DLPack methods
  - Creates a new tensor from an instance of `TensorDLPackWrapper`

Pull Request resolved: #138697
Approved by: https://github.com/ezyang

Co-authored-by: Wenzel Jakob <[email protected]>
1 parent e299193 commit 565a53d
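
For context, a minimal illustrative sketch of the behavior this change enables (not part of the diff; the class name below is arbitrary, and only the `__dlpack__`/`__dlpack_device__` methods matter for the new check in `internal_new_from_data`):

```python
import torch

# Illustrative class that exposes its data only through the DLPack protocol.
class MyDLPackArray:
    def __init__(self, tensor):
        self._t = tensor

    def __dlpack__(self, *args, **kwargs):
        return self._t.__dlpack__(*args, **kwargs)

    def __dlpack_device__(self):
        return self._t.__dlpack_device__()

obj = MyDLPackArray(torch.rand(10))
# With this change, torch.tensor() detects __dlpack__ on the input and builds the
# result via torch.utils.dlpack.from_dlpack instead of an element-by-element copy.
t = torch.tensor(obj)
```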

File tree

2 files changed: +47 -0 lines changed


test/test_dlpack.py

Lines changed: 30 additions & 0 deletions
@@ -15,6 +15,23 @@
 from torch.utils.dlpack import from_dlpack, to_dlpack


+# Wraps a tensor, exposing only DLPack methods:
+# - __dlpack__
+# - __dlpack_device__
+#
+# This is used for guaranteeing we are going through the DLPack method, and not
+# something else, e.g.: CUDA array interface, buffer protocol, etc.
+class TensorDLPackWrapper:
+    def __init__(self, tensor):
+        self.tensor = tensor
+
+    def __dlpack__(self, *args, **kwargs):
+        return self.tensor.__dlpack__(*args, **kwargs)
+
+    def __dlpack_device__(self, *args, **kwargs):
+        return self.tensor.__dlpack_device__(*args, **kwargs)
+
+
 class TestTorchDlPack(TestCase):
     exact_dtype = True

@@ -251,6 +268,19 @@ def test_dlpack_normalize_strides(self):
         # gh-83069, make sure __dlpack__ normalizes strides
         self.assertEqual(z.stride(), (1,))

+    @skipMeta
+    @onlyNativeDeviceTypes
+    def test_automatically_select_in_creation(self, device):
+        # Create a new tensor, and wrap it using TensorDLPackWrapper.
+        tensor = torch.rand(10)
+        wrap = TensorDLPackWrapper(tensor)
+        # Create a new tensor from the wrapper.
+        # This should identify that the wrapper class provides the DLPack methods
+        # and use them for creating the new tensor, instead of iterating element
+        # by element.
+        new_tensor = torch.tensor(wrap)
+        self.assertEqual(tensor, new_tensor)
+

 instantiate_device_type_tests(TestTorchDlPack, globals())
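
As a side note (illustrative only, not part of the PR's test), the two DLPack methods implemented by the wrapper above should also be sufficient for the public `torch.from_dlpack` entry point:

```python
import torch

t = torch.arange(4, dtype=torch.float32)
wrapped = TensorDLPackWrapper(t)          # class defined in test/test_dlpack.py above
view = torch.from_dlpack(wrapped)         # consumes the wrapper via __dlpack__
# from_dlpack imports the producer's memory, so this should share storage with `t`.
assert view.data_ptr() == t.data_ptr()
```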

torch/csrc/utils/tensor_new.cpp

Lines changed: 17 additions & 0 deletions
@@ -345,6 +345,23 @@ Tensor internal_new_from_data(
   }
 #endif

+  if (PyObject_HasAttrString(data, "__dlpack__")) {
+    py::object tensor_o =
+        py::module::import("torch").attr("utils").attr("dlpack").attr(
+            "from_dlpack")(py::handle(data));
+    Tensor tensor = py::cast<Tensor>(tensor_o);
+    const auto& inferred_scalar_type =
+        type_inference ? tensor.scalar_type() : scalar_type;
+    auto device = device_opt.has_value() ? *device_opt : tensor.device();
+    pybind11::gil_scoped_release no_gil;
+    maybe_initialize_device(device);
+    return tensor.to(
+        device,
+        inferred_scalar_type,
+        /*non_blocking=*/false,
+        /*copy=*/copy_variables);
+  }
+
   auto device = device_opt.has_value() ? *device_opt : options.device();

   auto sizes = compute_sizes(data, scalar_type);
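
In Python terms, the new branch is roughly equivalent to the sketch below (an approximation for illustration only; `scalar_type`, `device_opt`, `type_inference`, and `copy_variables` stand in for the corresponding `internal_new_from_data` arguments):

```python
import torch

def new_from_dlpack(data, scalar_type, device_opt, type_inference, copy_variables):
    # Rough Python rendering of the __dlpack__ branch added above.
    if hasattr(data, "__dlpack__"):
        tensor = torch.utils.dlpack.from_dlpack(data)
        dtype = tensor.dtype if type_inference else scalar_type
        device = device_opt if device_opt is not None else tensor.device
        # Mirrors tensor.to(device, inferred_scalar_type, /*non_blocking=*/false, /*copy=*/copy_variables)
        return tensor.to(device, dtype, non_blocking=False, copy=copy_variables)
    return None  # fall through to the pre-existing creation paths
```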

0 commit comments