Make ones/zeros/empty/full dtype handling more uniform #51
@@ -288,12 +288,15 @@ def arange(start=None, stop=None, step=1, dtype=None, *, like=None):
         raise ValueError("Maximum allowed size exceeded")


 @_decorators.dtype_to_torch
 def empty(shape, dtype=float, order="C", *, like=None):
     _util.subok_not_ok(like)
     if order != "C":
         raise NotImplementedError
-    torch_dtype = _dtypes.torch_dtype_from(dtype)
-    return asarray(torch.empty(shape, dtype=torch_dtype))
+    if dtype is None:
+        dtype = _dtypes_impl.default_float_dtype
+    result = torch.empty(shape, dtype=dtype)
+    return asarray(result)


 # NB: *_like function deliberately deviate from numpy: it has subok=True
@@ -310,17 +313,22 @@ def empty_like(prototype, dtype=None, order="K", subok=False, shape=None):
     return result


 @_decorators.dtype_to_torch
 def full(shape, fill_value, dtype=None, order="C", *, like=None):
     _util.subok_not_ok(like)
     if order != "C":
         raise NotImplementedError
-    if isinstance(fill_value, ndarray):
-        fill_value = fill_value.get()
+    fill_value = asarray(fill_value).get()
     if dtype is None:
-        torch_dtype = asarray(fill_value).get().dtype
-    else:
-        torch_dtype = _dtypes.torch_dtype_from(dtype)
-    return asarray(torch.full(shape, fill_value, dtype=torch_dtype))
+        dtype = fill_value.dtype
+
+    if not isinstance(shape, (tuple, list)):
+        shape = (shape,)
+
+    result = torch.full(shape, fill_value, dtype=dtype)
+
+    return asarray(result)


 @asarray_replacer()

Comment on lines +322 to +324:

A better way to implement this would be to have a function that, given a tensor or a number, returns its torch dtype. The reason is that we don't want to wrap scalars into tensors if we don't need to. Scalars are easier for a compiler to handle than tensors, and often get extra optimisations (no memory allocations, no synchronisations when accessed if we are working on CUDA, you can specialise on them...). Note that this applies in general, so we should be mindful of it throughout the codebase.

Let's move this to #56, so as not to conflate a small cleanup of this PR with a rather large rework.
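The helper the reviewer suggests does not exist in this PR. A rough sketch of the idea, with a hypothetical name `_dtype_of` and the standard torch dtype mapping assumed for Python scalars, might look like this:

```python
import torch

def _dtype_of(value):
    """Hypothetical helper: return the torch dtype of a tensor *or* a Python
    scalar, without first wrapping the scalar into a tensor."""
    if isinstance(value, torch.Tensor):
        return value.dtype
    # bool must be checked before int, since bool is a subclass of int
    if isinstance(value, bool):
        return torch.bool
    if isinstance(value, int):
        return torch.int64
    if isinstance(value, float):
        return torch.float64  # or whatever the configured default float dtype is
    if isinstance(value, complex):
        return torch.complex128
    raise TypeError(f"cannot infer a torch dtype for {type(value)}")
```

With something like this, `full` could keep `fill_value` as a plain scalar and pass it straight to `torch.full`, avoiding the extra tensor allocation the comment is concerned about.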
@@ -335,12 +343,15 @@ def full_like(a, fill_value, dtype=None, order="K", subok=False, shape=None):
     return result


 @_decorators.dtype_to_torch
 def ones(shape, dtype=None, order="C", *, like=None):
     _util.subok_not_ok(like)
     if order != "C":
         raise NotImplementedError
-    torch_dtype = _dtypes.torch_dtype_from(dtype)
-    return asarray(torch.ones(shape, dtype=torch_dtype))
+    if dtype is None:
+        dtype = _dtypes_impl.default_float_dtype
+    result = torch.ones(shape, dtype=dtype)
+    return asarray(result)


 @asarray_replacer()
@@ -362,7 +373,8 @@ def zeros(shape, dtype=None, order="C", *, like=None):
         raise NotImplementedError
     if dtype is None:
         dtype = _dtypes_impl.default_float_dtype
-    return asarray(torch.zeros(shape, dtype=dtype))
+    result = torch.zeros(shape, dtype=dtype)
+    return asarray(result)


 @asarray_replacer()
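After this change the four creation functions follow one pattern: a missing dtype falls back to `_dtypes_impl.default_float_dtype`, and `full` infers its dtype from `fill_value`. A rough usage sketch of the resulting behaviour (assuming the repo's `torch_np` package exposes these functions at the top level and that the returned arrays have a `dtype` attribute; the concrete default float dtype is whatever `_dtypes_impl.default_float_dtype` is set to):

```python
import torch_np as np  # assumption: package name used by this repo

# dtype=None falls back to the default float dtype
a = np.ones(3)
b = np.zeros((2, 2))
c = np.empty(4)

# full infers its dtype from fill_value when dtype is not given
d = np.full((2, 3), 5)    # integer fill value -> integer dtype
e = np.full((2, 3), 5.0)  # float fill value -> float dtype

print(a.dtype, b.dtype, c.dtype, d.dtype, e.dtype)
```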
A separate review thread on the `_util.subok_not_ok(like)` calls:

There's a bit of a mismatch here?

What mismatch?

Shouldn't we have a function that's `_util.like_not_implemented(like)` or so? Why are we passing `like` to `subok`?

Ah, so it's the utility's name that is confusing. Conceptually, I'd say that `like` and `subok` are similar: one is specifically for ndarray subclasses, the other is for a newer `__array_wrap__` protocol (I think). Most functions which work with subclasses or the protocol have both arguments; some have one or the other. We are not going to support either of them. The utility is literally two lines: https://github.com/Quansight-Labs/numpy_pytorch_interop/blob/main/torch_np/_detail/_util.py#L22

So if we go with utility functions I'd keep it as is, and can rename it if really wanted. If we switch to the normalization approach of gh-32, these two would likely get their own type annotations. So overall I'd rather not touch it now.
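For context, a minimal sketch of what a "literally two lines" utility like `subok_not_ok` plausibly does, based on the description above rather than the linked source (the exact exception types and messages are assumptions):

```python
def subok_not_ok(like=None, subok=False):
    # Neither the like= protocol nor ndarray subclasses are supported by this
    # wrapper, so reject both explicitly instead of silently ignoring them.
    if like is not None:
        raise ValueError("like=... parameter is not supported.")
    if subok:
        raise ValueError("subok parameter is not supported.")
```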