Skip to content

Commit e214fde

Browse files
committed
support copy in from_dlpack
1 parent 83420d2 commit e214fde

File tree

2 files changed

+43
-7
lines changed

2 files changed

+43
-7
lines changed

src/array_api_stubs/_draft/array_object.py

+17-2
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,8 @@ def __dlpack__(
293293
*,
294294
stream: Optional[Union[int, Any]] = None,
295295
max_version: Optional[tuple[int, int]] = None,
296+
dl_device: Optional[Tuple[Enum, int]] = None,
297+
copy: Optional[bool] = False
296298
) -> PyCapsule:
297299
"""
298300
Exports the array for consumption by :func:`~array_api.from_dlpack` as a DLPack capsule.
@@ -339,6 +341,17 @@ def __dlpack__(
339341
if it does support that), or of a different version.
340342
This means the consumer must verify the version even when
341343
`max_version` is passed.
344+
dl_device: Optional[Tuple[Enum, int]]
345+
The DLPack device type. Default is ``None``, meaning the exported capsule
346+
should be on the same device as ``self`` is. When specified, the format
347+
must follow that of the return value of :meth:`array.__dlpack_device__`.
348+
If the device type cannot be handled by the producer, this function must
349+
raise `BufferError`.
350+
copy: Optional[bool]
351+
Whether or not a copy should be made. Default is ``False`` to enable
352+
zero-copy data exchange. However, a user can request a copy to be made
353+
by the producer (through the consumer's :func:`~array_api.from_dlpack`)
354+
to move data across the library (and/or device) boundary.
342355
343356
Returns
344357
-------
@@ -394,7 +407,7 @@ def __dlpack__(
394407
# here to tell users that the consumer's max_version is too
395408
# old to allow the data exchange to happen.
396409
397-
And this logic for the consumer in ``from_dlpack``:
410+
And this logic for the consumer in :func:`~array_api.from_dlpack`:
398411
399412
.. code:: python
400413
@@ -409,7 +422,7 @@ def __dlpack__(
409422
Added BufferError.
410423
411424
.. versionchanged:: 2023.12
412-
Added the ``max_version`` keyword.
425+
Added the ``max_version``, ``dl_device``, and ``copy`` keywords.
413426
"""
414427

415428
def __dlpack_device__(self: array, /) -> Tuple[Enum, int]:
@@ -436,6 +449,8 @@ def __dlpack_device__(self: array, /) -> Tuple[Enum, int]:
436449
METAL = 8
437450
VPI = 9
438451
ROCM = 10
452+
CUDA_MANAGED = 13
453+
ONE_API = 14
439454
"""
440455

441456
def __eq__(self: array, other: Union[int, float, bool, array], /) -> array:

src/array_api_stubs/_draft/creation_functions.py

+26-5
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020

2121
from ._types import (
22+
Any,
2223
List,
2324
NestedSequence,
2425
Optional,
@@ -214,19 +215,36 @@ def eye(
214215
"""
215216

216217

217-
def from_dlpack(x: object, /) -> array:
218+
def from_dlpack(
219+
x: object, /, *,
220+
device: Optional[device] = None,
221+
copy: Optional[bool] = False,
222+
) -> Union[array, Any]:
218223
"""
219224
Returns a new array containing the data from another (array) object with a ``__dlpack__`` method.
220225
221226
Parameters
222227
----------
223228
x: object
224229
input (array) object.
230+
device: Optional[device]
231+
device on which to place the created array. If ``device`` is ``None`` and ``x`` supports DLPack, the output array device must be inferred from ``x``. Default: ``None``.
232+
233+
The v2023.12 standard only mandates that a compliant library must offer a way for ``from_dlpack`` to create an array on CPU (using
234+
the library-chosen way to represent the CPU device - ``kDLCPU`` in DLPack - e.g. a ``"CPU"`` string or a ``Device("CPU")`` object).
235+
If the compliant library does not support the CPU device and needs to outsource to another (compliant) array library, it may do so
236+
with a clear user documentation and/or run-time warning. If a copy must be made to enable this, and ``copy`` is set to ``False``,
237+
the function must raise ``ValueError``.
238+
239+
Other kinds of devices will be considered for standardization in a future version.
240+
copy: Optional[bool]
241+
boolean indicating whether or not to copy the input. If ``True``, the function must always copy. If ``False``, the function must never copy and must raise a ``BufferError`` in case a copy would be necessary (e.g. the producer disallows views). Default: ``False``.
225242
226243
Returns
227244
-------
228-
out: array
229-
an array containing the data in `x`.
245+
out: Union[array, Any]
246+
an array containing the data in ``x``. In the case that the compliant library does not support the given ``device`` out of box
247+
and must oursource to another (compliant) library, the output will be that library's compliant array object.
230248
231249
.. admonition:: Note
232250
:class: note
@@ -238,9 +256,9 @@ def from_dlpack(x: object, /) -> array:
238256
BufferError
239257
The ``__dlpack__`` and ``__dlpack_device__`` methods on the input array
240258
may raise ``BufferError`` when the data cannot be exported as DLPack
241-
(e.g., incompatible dtype or strides). It may also raise other errors
259+
(e.g., incompatible dtype, strides, or device). It may also raise other errors
242260
when export fails for other reasons (e.g., not enough memory available
243-
to materialize the data). ``from_dlpack`` must propagate such
261+
to materialize the data, a copy must made, etc). ``from_dlpack`` must propagate such
244262
exceptions.
245263
AttributeError
246264
If the ``__dlpack__`` and ``__dlpack_device__`` methods are not present
@@ -251,6 +269,9 @@ def from_dlpack(x: object, /) -> array:
251269
-----
252270
See :meth:`array.__dlpack__` for implementation suggestions for `from_dlpack` in
253271
order to handle DLPack versioning correctly.
272+
273+
.. versionchanged:: 2023.12
274+
Added device and copy support.
254275
"""
255276

256277

0 commit comments

Comments
 (0)