Skip to content

Commit 69cc93b

Browse files
committed
fix astype bug
1 parent 762a03c commit 69cc93b

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

array_api_compat/common/_aliases.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -327,9 +327,16 @@ def _asarray(
327327
if (hasattr(xp, "ndarray") and isinstance(obj, xp.ndarray)) or hasattr(obj, "__array__"):
328328
if dtype is not None and obj.dtype != dtype:
329329
copy = True
330-
# Dask arrays are immutable, so copy doesn't do anything
331-
if copy in COPY_TRUE and namespace != "dask.array":
332-
return xp.array(obj, copy=True, dtype=dtype)
330+
if copy in COPY_TRUE:
331+
copy_kwargs = {}
332+
if namespace != "dask.array":
333+
copy_kwargs["copy"] = True
334+
else:
335+
# No copy kw in dask.asarray so we go thorugh np.asarray first
336+
# (like dask also does) but copy after
337+
import numpy as np
338+
obj = np.asarray(obj).copy()
339+
return xp.array(obj, dtype=dtype, **copy_kwargs)
333340
return obj
334341

335342
return xp.asarray(obj, dtype=dtype, **kwargs)

0 commit comments

Comments
 (0)