Skip to content

Commit 379b939

Browse files
Merge pull request #1653 from IntelPython/document-race-condition-in-put
Addresses gh-1360 with documentation
2 parents ef5a751 + 18df92a commit 379b939

File tree

1 file changed

+102
-54
lines changed

1 file changed

+102
-54
lines changed

dpctl/tensor/_indexing_functions.py

Lines changed: 102 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -41,25 +41,29 @@ def take(x, indices, /, *, axis=None, mode="wrap"):
4141
Takes elements from an array along a given axis at given indices.
4242
4343
Args:
44-
x (usm_ndarray):
45-
The array that elements will be taken from.
46-
indices (usm_ndarray):
47-
One-dimensional array of indices.
48-
axis:
49-
The axis along which the values will be selected.
50-
If ``x`` is one-dimensional, this argument is optional.
51-
Default: ``None``.
52-
mode:
53-
How out-of-bounds indices will be handled.
54-
``"wrap"`` - clamps indices to (-n <= i < n), then wraps
55-
negative indices.
56-
``"clip"`` - clips indices to (0 <= i < n)
57-
Default: ``"wrap"``.
44+
x (usm_ndarray):
45+
The array that elements will be taken from.
46+
indices (usm_ndarray):
47+
One-dimensional array of indices.
48+
axis (int, optional):
49+
The axis along which the values will be selected.
50+
If ``x`` is one-dimensional, this argument is optional.
51+
Default: ``None``.
52+
mode (str, optional):
53+
How out-of-bounds indices will be handled. Possible values
54+
are:
55+
56+
- ``"wrap"``: clamps indices to (``-n <= i < n``), then wraps
57+
negative indices.
58+
- ``"clip"``: clips indices to (``0 <= i < n``).
59+
60+
Default: ``"wrap"``.
5861
5962
Returns:
6063
usm_ndarray:
61-
Array with shape x.shape[:axis] + indices.shape + x.shape[axis + 1:]
62-
filled with elements from x.
64+
Array with shape
65+
``x.shape[:axis] + indices.shape + x.shape[axis + 1:]``
66+
filled with elements from ``x``.
6367
"""
6468
if not isinstance(x, dpt.usm_ndarray):
6569
raise TypeError(
@@ -128,30 +132,71 @@ def put(x, indices, vals, /, *, axis=None, mode="wrap"):
128132
Puts values into an array along a given axis at given indices.
129133
130134
Args:
131-
x (usm_ndarray):
132-
The array the values will be put into.
133-
indices (usm_ndarray)
134-
One-dimensional array of indices.
135-
136-
Note that if indices are not unique, a race
137-
condition will result, and the value written to
138-
``x`` will not be deterministic.
139-
:py:func:`dpctl.tensor.unique` can be used to
140-
guarantee unique elements in ``indices``.
141-
vals:
142-
Array of values to be put into ``x``.
143-
Must be broadcastable to the result shape
144-
``x.shape[:axis] + indices.shape + x.shape[axis+1:]``.
145-
axis:
146-
The axis along which the values will be placed.
147-
If ``x`` is one-dimensional, this argument is optional.
148-
Default: ``None``.
149-
mode:
150-
How out-of-bounds indices will be handled.
151-
``"wrap"`` - clamps indices to (-n <= i < n), then wraps
152-
negative indices.
153-
``"clip"`` - clips indices to (0 <= i < n)
154-
Default: ``"wrap"``.
135+
x (usm_ndarray):
136+
The array the values will be put into.
137+
indices (usm_ndarray):
138+
One-dimensional array of indices.
139+
vals (usm_ndarray):
140+
Array of values to be put into ``x``.
141+
Must be broadcastable to the result shape
142+
``x.shape[:axis] + indices.shape + x.shape[axis+1:]``.
143+
axis (int, optional):
144+
The axis along which the values will be placed.
145+
If ``x`` is one-dimensional, this argument is optional.
146+
Default: ``None``.
147+
mode (str, optional):
148+
How out-of-bounds indices will be handled. Possible values
149+
are:
150+
151+
- ``"wrap"``: clamps indices to (``-n <= i < n``), then wraps
152+
negative indices.
153+
- ``"clip"``: clips indices to (``0 <= i < n``).
154+
155+
Default: ``"wrap"``.
156+
157+
.. note::
158+
159+
If input array ``indices`` contains duplicates, a race condition
160+
occurs, and the value written into corresponding positions in ``x``
161+
may vary from run to run. Preserving sequential semantics in handling
162+
the duplicates to achieve deterministic behavior requires additional
163+
work, e.g.
164+
165+
:Example:
166+
167+
.. code-block:: python
168+
169+
from dpctl import tensor as dpt
170+
171+
def put_vec_duplicates(vec, ind, vals):
172+
"Put values into vec, handling possible duplicates in ind"
173+
assert (vec.ndim, ind.ndim, vals.ndim) == (1, 1, 1)
174+
175+
# find positions of last occurrences of each
176+
# unique index
177+
ind_flipped = dpt.flip(ind)
178+
ind_uniq = dpt.unique_all(ind_flipped).indices
179+
has_dups = len(ind) != len(ind_uniq)
180+
181+
if has_dups:
182+
ind_uniq = dpt.subtract(ind.size - 1, ind_uniq)
183+
ind = dpt.take(ind, ind_uniq)
184+
vals = dpt.take(vals, ind_uniq)
185+
186+
dpt.put(vec, ind, vals)
187+
188+
n = 512
189+
ind = dpt.concat((dpt.arange(n), dpt.arange(n, -1, step=-1)))
190+
x = dpt.zeros(ind.size, dtype="int32")
191+
vals = dpt.arange(ind.size, dtype=x.dtype)
192+
193+
# Values corresponding to last positions of
194+
# duplicate indices are written into the vector x
195+
put_vec_duplicates(x, ind, vals)
196+
197+
parts = (vals[-1:-n-2:-1], dpt.zeros(n, dtype=x.dtype))
198+
expected = dpt.concat(parts)
199+
assert dpt.all(x == expected)
155200
"""
156201
if not isinstance(x, dpt.usm_ndarray):
157202
raise TypeError(
@@ -237,22 +282,24 @@ def extract(condition, arr):
237282
238283
Returns the elements of an array that satisfy the condition.
239284
240-
If `condition` is boolean ``dpctl.tensor.extract`` is
285+
If ``condition`` is boolean ``dpctl.tensor.extract`` is
241286
equivalent to ``arr[condition]``.
242287
243288
Note that ``dpctl.tensor.place`` does the opposite of
244289
``dpctl.tensor.extract``.
245290
246291
Args:
247292
condition (usm_ndarray):
248-
An array whose non-zero or True entries indicate the element
249-
of `arr` to extract.
293+
An array whose non-zero or ``True`` entries indicate the elements
294+
of ``arr`` to extract.
295+
250296
arr (usm_ndarray):
251-
Input array of the same size as `condition`.
297+
Input array of the same size as ``condition``.
252298
253299
Returns:
254300
usm_ndarray:
255-
Rank 1 array of values from `arr` where `condition` is True.
301+
Rank 1 array of values from ``arr`` where ``condition`` is
302+
``True``.
256303
"""
257304
if not isinstance(condition, dpt.usm_ndarray):
258305
raise TypeError(
@@ -280,20 +327,20 @@ def place(arr, mask, vals):
280327
281328
Change elements of an array based on conditional and input values.
282329
283-
If `mask` is boolean ``dpctl.tensor.place`` is
330+
If ``mask`` is boolean ``dpctl.tensor.place`` is
284331
equivalent to ``arr[mask] = vals``.
285332
286333
Args:
287334
arr (usm_ndarray):
288335
Array to put data into.
289336
mask (usm_ndarray):
290-
Boolean mask array. Must have the same size as `arr`.
337+
Boolean mask array. Must have the same size as ``arr``.
291338
vals (usm_ndarray, sequence):
292-
Values to put into `arr`. Only the first N elements are
293-
used, where N is the number of True values in `mask`. If
294-
`vals` is smaller than N, it will be repeated, and if
295-
elements of `arr` are to be masked, this sequence must be
296-
non-empty. Array `vals` must be one dimensional.
339+
Values to put into ``arr``. Only the first N elements are
340+
used, where N is the number of True values in ``mask``. If
341+
``vals`` is smaller than N, it will be repeated, and if
342+
elements of ``arr`` are to be masked, this sequence must be
343+
non-empty. Array ``vals`` must be one-dimensional.
297344
"""
298345
if not isinstance(arr, dpt.usm_ndarray):
299346
raise TypeError(
@@ -345,13 +392,14 @@ def nonzero(arr):
345392
Return the indices of non-zero elements.
346393
347394
Returns a tuple of usm_ndarrays, one for each dimension
348-
of `arr`, containing the indices of the non-zero elements
349-
in that dimension. The values of `arr` are always tested in
395+
of ``arr``, containing the indices of the non-zero elements
396+
in that dimension. The values of ``arr`` are always tested in
350397
row-major, C-style order.
351398
352399
Args:
353400
arr (usm_ndarray):
354401
Input array, which has non-zero array rank.
402+
355403
Returns:
356404
Tuple[usm_ndarray, ...]:
357405
Indices of non-zero array elements.

0 commit comments

Comments
 (0)