|
15 | 15 | from torch.utils.dlpack import from_dlpack, to_dlpack
|
16 | 16 |
|
17 | 17 |
|
| 18 | +# Wraps a tensor, exposing only DLPack methods: |
| 19 | +# - __dlpack__ |
| 20 | +# - __dlpack_device__ |
| 21 | +# |
| 22 | +# This is used for guaranteeing we are going through the DLPack method, and not |
| 23 | +# something else, e.g.: CUDA array interface, buffer protocol, etc. |
| 24 | +class TensorDLPackWrapper: |
| 25 | + def __init__(self, tensor): |
| 26 | + self.tensor = tensor |
| 27 | + |
| 28 | + def __dlpack__(self, *args, **kwargs): |
| 29 | + return self.tensor.__dlpack__(*args, **kwargs) |
| 30 | + |
| 31 | + def __dlpack_device__(self, *args, **kwargs): |
| 32 | + return self.tensor.__dlpack_device__(*args, **kwargs) |
| 33 | + |
| 34 | + |
18 | 35 | class TestTorchDlPack(TestCase):
|
19 | 36 | exact_dtype = True
|
20 | 37 |
|
@@ -251,6 +268,19 @@ def test_dlpack_normalize_strides(self):
|
251 | 268 | # gh-83069, make sure __dlpack__ normalizes strides
|
252 | 269 | self.assertEqual(z.stride(), (1,))
|
253 | 270 |
|
| 271 | + @skipMeta |
| 272 | + @onlyNativeDeviceTypes |
| 273 | + def test_automatically_select_in_creation(self, device): |
| 274 | + # Create a new tensor, and wrap it using TensorDLPackWrapper. |
| 275 | + tensor = torch.rand(10) |
| 276 | + wrap = TensorDLPackWrapper(tensor) |
| 277 | + # Create a new tensor from the wrapper. |
| 278 | + # This should identify that the wrapper class provides the DLPack methods |
| 279 | + # and use them for creating the new tensor, instead of iterating element |
| 280 | + # by element. |
| 281 | + new_tensor = torch.tensor(wrap) |
| 282 | + self.assertEqual(tensor, new_tensor) |
| 283 | + |
254 | 284 |
|
255 | 285 | instantiate_device_type_tests(TestTorchDlPack, globals())
|
256 | 286 |
|
|
0 commit comments