Make ones/zeros/empty/full dtype handling more uniform #51

Merged: 7 commits, Feb 16, 2023
6 changes: 6 additions & 0 deletions torch_np/_dtypes.py
@@ -197,6 +197,7 @@ class bool_(generic):
_typecodes = {st.typecode: st for cat in sctypes for st in sctypes[cat]}
_torch_dtypes = {st.torch_dtype: st for cat in sctypes for st in sctypes[cat]}


_aliases = {
"u1": uint8,
"i1": int8,
@@ -285,6 +286,11 @@ def name(self):
def type(self):
return self._scalar_type

@property
def kind(self):
# https://numpy.org/doc/stable/reference/generated/numpy.dtype.kind.html
return _torch_dtypes[self.torch_dtype].name[0]

@property
def typecode(self):
return self._scalar_type.typecode
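For reference, the one-character codes the new `kind` property returns match NumPy's `dtype.kind`. The check below exercises NumPy's own behavior, not code from this PR:

```python
import numpy as np

# NumPy's dtype.kind is the first letter of the dtype family name,
# which is exactly what name[0] extracts for the torch-backed dtypes.
for scalar_type, code in [(np.uint8, "u"), (np.int8, "i"),
                          (np.float32, "f"), (np.complex64, "c"),
                          (np.bool_, "b")]:
    assert np.dtype(scalar_type).kind == code
```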
14 changes: 7 additions & 7 deletions torch_np/_ndarray.py
@@ -493,13 +493,6 @@ def wrapped(x, *args, **kwds):
###### dtype routines


def can_cast(from_, to, casting="safe"):
from_ = from_.dtype if isinstance(from_, ndarray) else _dtypes.dtype(from_)
to_ = to.dtype if isinstance(to, ndarray) else _dtypes.dtype(to)

return _dtypes_impl.can_cast_impl(from_.torch_dtype, to_.torch_dtype, casting)


def _extract_dtype(entry):
try:
dty = _dtypes.dtype(entry)
@@ -508,6 +501,13 @@ def _extract_dtype(entry):
return dty


def can_cast(from_, to, casting="safe"):
from_ = _extract_dtype(from_)
to_ = _extract_dtype(to)

return _dtypes_impl.can_cast_impl(from_.torch_dtype, to_.torch_dtype, casting)


def result_type(*arrays_and_dtypes):
dtypes = []

Expand Down
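Moving `can_cast` below `_extract_dtype` lets it accept anything `_extract_dtype` can parse, not only ndarrays and dtype objects. As a reference point, NumPy's own `can_cast` accepts the same variety of dtype specifiers:

```python
import numpy as np

# Reference semantics for the wrapper: dtype objects, scalar types, and
# dtype strings are all valid arguments on either side.
assert np.can_cast(np.int8, np.int16)                    # safe widening
assert not np.can_cast(np.float64, np.float32)           # narrowing is not "safe"
assert np.can_cast(np.float64, np.float32, casting="same_kind")
assert np.can_cast("u1", "i2")                           # string specifiers
```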
34 changes: 23 additions & 11 deletions torch_np/_wrapper.py
@@ -288,12 +288,15 @@ def arange(start=None, stop=None, step=1, dtype=None, *, like=None):
raise ValueError("Maximum allowed size exceeded")


+@_decorators.dtype_to_torch
 def empty(shape, dtype=float, order="C", *, like=None):
     _util.subok_not_ok(like)
Collaborator: There's a bit of a mismatch here?

Author: What mismatch?

Collaborator: Shouldn't we have a function that's `_util.like_not_implemented(like)` or so? Why are we passing `like` to `subok`?

Author: Ah, it's the util name that is confusing. Conceptually, I'd say that `like` and `subok` are similar: one is specifically for ndarray subclasses, the other is for a newer `__array_wrap__` protocol (I think). Most functions which work for subclasses or the protocol have both arguments; some have one or the other. We are not going to support either of them. The utility is literally two lines:
https://github.com/Quansight-Labs/numpy_pytorch_interop/blob/main/torch_np/_detail/_util.py#L22

So if we go with utility functions I'd keep it as is; I can rename it if really wanted. If we transform to the normalization approach of gh-32, these two would likely get their own type annotations.

So overall I'd rather not touch it now.

     if order != "C":
         raise NotImplementedError
-    torch_dtype = _dtypes.torch_dtype_from(dtype)
-    return asarray(torch.empty(shape, dtype=torch_dtype))
+    if dtype is None:
+        dtype = _dtypes_impl.default_float_dtype
+    result = torch.empty(shape, dtype=dtype)
+    return asarray(result)


# NB: *_like functions deliberately deviate from numpy: they have subok=True
@@ -310,17 +313,22 @@ def empty_like(prototype, dtype=None, order="K", subok=False, shape=None):
return result


+@_decorators.dtype_to_torch
 def full(shape, fill_value, dtype=None, order="C", *, like=None):
     _util.subok_not_ok(like)
     if order != "C":
         raise NotImplementedError
-    if isinstance(fill_value, ndarray):
-        fill_value = fill_value.get()
-
+    fill_value = asarray(fill_value).get()
     if dtype is None:
-        torch_dtype = asarray(fill_value).get().dtype
-    else:
-        torch_dtype = _dtypes.torch_dtype_from(dtype)
-    return asarray(torch.full(shape, fill_value, dtype=torch_dtype))
+        dtype = fill_value.dtype
Comment on lines +322 to +324

Collaborator: A better way to implement this would be a function that, given a tensor or a number, returns its torch dtype. The reason is that we don't want to wrap scalars into tensors if we don't need to. Scalars are easier to treat in a compiler than tensors, and they often get extra optimisations (no memory allocations, no synchronisations when accessed if we are working on CUDA, you can specialise on them, ...).

Note that this applies in general, so we should be mindful about it across the board.

Author: Let's move this to #56, to not conflate a small cleanup of this PR with a rather large rework.


+    if not isinstance(shape, (tuple, list)):
+        shape = (shape,)
+
+    result = torch.full(shape, fill_value, dtype=dtype)
+
+    return asarray(result)


@asarray_replacer()
@@ -335,12 +343,15 @@ def full_like(a, fill_value, dtype=None, order="K", subok=False, shape=None):
return result


+@_decorators.dtype_to_torch
 def ones(shape, dtype=None, order="C", *, like=None):
     _util.subok_not_ok(like)
     if order != "C":
         raise NotImplementedError
-    torch_dtype = _dtypes.torch_dtype_from(dtype)
-    return asarray(torch.ones(shape, dtype=torch_dtype))
+    if dtype is None:
+        dtype = _dtypes_impl.default_float_dtype
+    result = torch.ones(shape, dtype=dtype)
+    return asarray(result)


@asarray_replacer()
@@ -362,7 +373,8 @@ def zeros(shape, dtype=None, order="C", *, like=None):
         raise NotImplementedError
     if dtype is None:
         dtype = _dtypes_impl.default_float_dtype
-    return asarray(torch.zeros(shape, dtype=dtype))
+    result = torch.zeros(shape, dtype=dtype)
+    return asarray(result)
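Taken together, the creation functions now share one rule: default to the float default dtype when `dtype` is omitted, except `full`, which infers it from `fill_value`. The NumPy behavior being mirrored can be checked directly:

```python
import numpy as np

# ones/zeros/empty default to the float default dtype;
# full infers its dtype from the fill value instead.
assert np.ones(3).dtype == np.float64
assert np.zeros(3).dtype == np.float64
assert np.empty(3).dtype == np.float64
assert np.full((2, 2), 2.0).dtype == np.float64
assert np.full((2, 2), True).dtype == np.bool_
```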


@asarray_replacer()