# A PyTorch - NumPy compatibility layer

**Authors:**
* @ev-br
* @lezcano
* @rgommers

## Summary

This RFC describes a proposal for a translation layer from NumPy into PyTorch.
In simple terms, this accounts for implementing most of NumPy's API (`ndarray`,
the `numpy`, `numpy.linalg`, `numpy.fft` modules, etc.) using `torch.Tensor`
and PyTorch ops as backend.

This project has one main goal, as per the
[initial design document](https://docs.google.com/document/d/1gdUDgZNbumFORRcUaZUVw790CtNYweAM20C1fbWMNd8):
1. Make TorchDynamo understand NumPy calls

The work is currently being done at [numpy_pytorch_interop](https://github.com/Quansight-Labs/numpy_pytorch_interop/).

## Motivation

### Introductory examples

Consider the following snippet:
```python
import numpy as np

x = np.random.randn(3, 4)
y = np.random.randn(4, 3)
z = np.dot(x, y)
w = z.sum()
```

When we trace this program with the compat layer, the semantics of the
program would stay the same, but the implementation would be equivalent to

```python
import torch
x = torch.randn(3, 4, dtype=torch.float64)
y = torch.randn(4, 3, dtype=torch.float64)
z = torch.matmul(x, y)
w = z.sum()
```

Here, we can already spot a couple of differences between NumPy and PyTorch.
The most obvious one is that the default dtype in NumPy is `float64` rather than
`float32`. The less obvious one is hiding in the last line:

```python
>>> type(w)
<class 'numpy.float64'>
```

Reductions and similar operations in NumPy return the infamous NumPy scalars.
We'll discuss these and other NumPy quirks and how we dealt with them in the
[design decisions section](#design-decisions).

Let's now have a look at a toy example of how this layer would be used.
```python
import torch
import numpy as np
t1 = torch.tensor([1, 3, 5])
t2 = torch.exp(t1)
# Now say the user has some code lying around which uses NumPy:
def fn(x, y):
    return np.multiply(x, y).sum()

result = fn(t1, t2)
t_results = torch.empty(5, dtype=torch.float64)
t_results[0] = result  # store the result in a torch.Tensor
```

Note that this code mixing NumPy and PyTorch already works, as `torch.Tensor`
implements the `__array__` method. Now, the compatibility layer allows us to
trace through it. In order to do that, no changes are necessary other than
simply asking `torch.compile` to trace through it:

```python
@torch.compile
def fn(x, y):
    return np.multiply(x, y).sum()
```

### Design decisions

The two main ideas driving the design of this compatibility layer are the following:

1. The behavior of the layer should be as close to that of NumPy as possible
2. The layer follows NumPy master

The following design decisions follow from these:

**Default dtypes**. One of the most common issues that bites people when migrating their
codebases from NumPy to JAX is the default dtype changing from `float64` to
`float32`. So much so that this is noted as one of
[JAX's sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision).
Following the spirit of making everything match NumPy by default, we choose the
NumPy defaults whenever the `dtype` is not made explicit in a factory function.

**TODO(Lezcano)**: I just realized that we do not have a clean way to change
the default dtypes of `torch_np` to those from PyTorch. We should implement
that utility flag, similar to
[`torch.set_default_dtype`](https://pytorch.org/docs/stable/generated/torch.set_default_dtype.html).
Perhaps call it `torch_np.use_torch_defaults()` and then add a way for users
to be able to set their own int/float/complex defaults.

**TODO(Lezcano)**: Do we use these defaults just in factory functions, or do we
also use them anywhere else? -> Check
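
As a rough illustration, such a utility could look like the sketch below.
`use_torch_defaults` and `default_dtypes` are hypothetical names taken from the
TODO above, not an existing API:

```python
import torch

# hypothetical module-level defaults; NumPy-flavored out of the box
default_dtypes = {
    "float": torch.float64,
    "complex": torch.complex128,
    "int": torch.int64,
}

def use_torch_defaults():
    # switch factory functions over to PyTorch's default dtypes
    default_dtypes["float"] = torch.float32
    default_dtypes["complex"] = torch.complex64
    default_dtypes["int"] = torch.int64

# factory functions would then consult these defaults, e.g.:
def ones(*shape, dtype=None):
    dtype = dtype if dtype is not None else default_dtypes["float"]
    return torch.ones(shape, dtype=dtype)
```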

**NumPy scalars**. NumPy's type system is tricky. At first sight, it looks
like PyTorch's, but with a few more dtypes like `np.uint16` or `np.longdouble`.
Upon closer inspection, one finds that it also has
[NumPy scalar](https://numpy.org/doc/stable/reference/arrays.scalars.html) objects.
NumPy scalars are similar to Python scalars, but with a set width. NumPy scalars
are NumPy's preferred return class for reductions and other operations that
return just one element. NumPy scalars do not play particularly well with
computations on devices like GPUs, as they live on CPU. Implementing NumPy
scalars would mean that we need to synchronize after every `sum()` call, which
would be terrible performance-wise. In this implementation, we choose to represent
NumPy scalars as 0-D arrays. This may cause small divergences in some cases, like

```python
>>> np.int32(2) * [1, 2, 3]     # scalar decays to a python int
[1, 2, 3, 1, 2, 3]

>>> np.asarray(2) * [1, 2, 3]   # zero-dim array is an array-like
array([2, 4, 6])
```

but we don't expect these to pose a big issue in practice. Note that in this
implementation, `torch_np.int32(2)` would return the same as `torch_np.asarray(2)`.
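
A minimal sketch of this representation, using a hypothetical stripped-down
`ndarray` wrapper (the real class lives in `torch_np`):

```python
import torch

class ndarray:
    # stand-in for torch_np.ndarray: a thin wrapper over a torch.Tensor
    def __init__(self, tensor):
        self.tensor = tensor

def asarray(obj, dtype=None):
    return ndarray(torch.as_tensor(obj, dtype=dtype))

def int32(value):
    # no dedicated scalar class: return a 0-D array instead, so no
    # GPU -> CPU synchronization is needed to materialize a Python scalar
    return asarray(value, dtype=torch.int32)
```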

**Type promotion**. Another not-so-well-known fact of NumPy's casting system is
that it is data-dependent. Python scalars can be used in pretty much any NumPy
operation: any operation that accepts a 0-D array will also accept a Python
scalar. When you provide an operation with a Python scalar, it will be cast to
the smallest dtype in which it can be represented, and only then will it
participate in type promotion. This allows for some rather interesting behavior:
```python
>>> np.asarray([1], dtype=np.int8) + 127
array([-128], dtype=int8)
>>> np.asarray([1], dtype=np.int8) + 128
array([129], dtype=int16)
```
This data-dependent type promotion will be deprecated in NumPy 2.0, and will be
replaced with [NEP 50](https://numpy.org/neps/nep-0050-scalar-promotion.html).
For simplicity and to be forward-looking, we chose to implement the
type promotion behavior proposed in NEP 50, which is much closer to that of
PyTorch.
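
For illustration, under NEP 50 semantics Python scalars become "weak": they no
longer influence the result dtype based on their value. The examples above
would behave roughly as follows in NumPy 2.0 (illustrative output):

```python
>>> np.asarray([1], dtype=np.int8) + 100   # result dtype stays int8
array([101], dtype=int8)
>>> np.asarray([1], dtype=np.int8) + 128   # 128 does not fit in int8
OverflowError: Python integer 128 out of bounds for int8
```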

Note that the decision of going with NEP 50 complements the previous one of
returning 0-D arrays in place of NumPy scalars as, currently, 0-D arrays do not
participate in type promotion in NumPy (but will do in NumPy 2.0 under NEP 50):
```python
int64_0d_array = np.array(1, dtype=np.int64)
np.result_type(np.int8, int64_0d_array) == np.int8
```

**Versioning**. It should be clear from the previous points that NumPy has a
fair amount of questionable and legacy pain points. It is for this reason that
we decided that, rather than fighting these, we would declare that the compat
layer follows the behavior of NumPy's master (even, in some cases, of NumPy
2.0). Given the stability of NumPy's API and how battle-tested its main
functions are, we do not expect this to become a big maintenance burden. If
anything, it should make our lives easier, as some parts of NumPy will soon be
simplified, saving us the pain of having to implement all the pre-existing
corner cases.

For reference, NumPy 2.0 is expected to land at the end of this year.

## The `torch_np` module

The bulk of the work went into implementing a system that allows us to
implement NumPy operations in terms of those of PyTorch. The main design goals
here were:

1. Implement *most* of NumPy's API
2. Preserve NumPy semantics as much as possible

We say *most* of NumPy's API because NumPy's API is not only massive, but
there are also parts of it which cannot be implemented in PyTorch. For example,
NumPy has support for arrays of strings, datetimes, structured and other dtypes.
Negative strides are another example of a feature that is just not supported in PyTorch.
We put together a list of things that are out of the scope of this project in the
[following issue](https://github.com/Quansight-Labs/numpy_pytorch_interop/issues/73).

For the bulk of the functions, we started by prioritizing the most common
operations. Then, when bringing in tests from the NumPy test suite, we would
triage and prioritize how important it was to fix each failure we found.
Iterating this process, we ended up with a small list of differences between
the NumPy and the PyTorch API, which we prioritized by hand. That list and the
prioritization discussion can be found in [this issue](https://github.com/Quansight-Labs/numpy_pytorch_interop/issues/87).

**Visibility of the module**. For simplicity, this RFC assumes that the
`torch_np` module will not be public, as the decision for it to be made public
was met with differing opinions.
We discuss these in the [unresolved questions](#unresolved-questions) section.

### Annotation-based preprocessing

NumPy accepts virtually anything that smells like an array as an input.

```python
>>> np.add(1, 3)
4
>>> np.add([1., 2., 3.], 5)
array([6., 7., 8.])
>>> np.concatenate([[1, 2, 3], [4, 5, 6]])
array([1, 2, 3, 4, 5, 6])
```

To implement NumPy in terms of PyTorch, for any operation we would need to map
inputs into tensors, perform the operation, and then wrap the tensor into
a `torch_np.ndarray` (more on this class later).

To avoid all this code repetition, we implement the functions in two steps.

First, we implement functions with the NumPy signature, but assuming that, in
place of NumPy-land elements (`np.array`s, array-likes, `np.dtype`s, etc.),
they simply accept `torch.Tensor`s and PyTorch-land objects and return
`torch.Tensor`s. For example, we would implement `np.diag` as

```python
def diag(v, k=0):
    return torch.diag(v, k)
```

In this layer, if a NumPy function is composite (calls other NumPy functions
internally), we can simply vendor its implementation and have it call our
PyTorch-land implementations of these functions. In other words, at this level,
functions are composable, as they are simply regular PyTorch functions.
All these implementations are internal, and are not meant to be seen or used
by the final user.

The second step is then done via type annotations and a decorator. Each type
annotation has an associated function from NumPy-land into PyTorch-land. This
function converts the set of inputs accepted by NumPy for that argument into a
PyTorch-land object (think a `torch.Tensor` or a PyTorch dtype). For example,
for `np.diag` we would write

```python
def diag(v: ArrayLike, k=0):
    return torch.diag(v, k)
```

Then, we wrap these functions with a `normalizer` decorator and
expose them in the `torch_np` module. This decorator is in charge of gathering
all the inputs at runtime and normalizing them according to their annotations.
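
A minimal sketch of how such a decorator could work; this is illustrative, and
the actual implementation in `numpy_pytorch_interop` handles many more cases:

```python
import functools
import inspect
import torch

class ArrayLike:
    """Marker annotation: convert this argument to a torch.Tensor."""

def _to_tensor(x):
    # accepts tensors, (nested) lists, scalars: anything torch.as_tensor takes
    return x if isinstance(x, torch.Tensor) else torch.as_tensor(x)

# map each annotation to its NumPy-land -> PyTorch-land conversion
_CONVERTERS = {ArrayLike: _to_tensor}

def normalizer(func):
    sig = inspect.signature(func)

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)
        bound.apply_defaults()
        for name, value in bound.arguments.items():
            convert = _CONVERTERS.get(sig.parameters[name].annotation)
            if convert is not None:
                bound.arguments[name] = convert(value)
        return func(*bound.args, **bound.kwargs)

    return wrapper

@normalizer
def diag(v: ArrayLike, k=0):
    return torch.diag(v, k)

# diag now accepts anything array-like:
# diag([[1, 2], [3, 4]])  ->  tensor([1, 4])
```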

We currently have four annotations (and small variations of them):
- `ArrayLike`: The input can be a `torch_np.array`, a list of lists, a
  scalar, or anything that NumPy would accept. It returns a `torch.Tensor`.
- `DTypeLike`: Takes a `torch_np` dtype and returns a PyTorch dtype.
- `AxisLike`: Takes anything that can be accepted as an axis (e.g. a tuple or
  an `ndarray`) and returns a tuple.
- `OutArray`: Asserts that the input is a `torch_np.ndarray`. This is used
  to implement the `out` arg.

Note that none of the code in this implementation makes use of NumPy. We are
writing `torch_np.ndarray` above to make our intent more explicit, but there
shouldn't be any ambiguity.

**OBS(Lezcano)**: `DTypeLike` should be `Optional[DTypeLike]`.

**OBS(Lezcano)**: Should we have a `NotImplementedType` to mark the args that
are not being implemented? We could then assert that such a parameter either
has not been provided or, if it has, that it has the same value as the default.
The goal here would be to either use all the args of a function in its
implementation, or mark explicitly those that we don't use.

**Implementing `out`**: In PyTorch, the `out` kwarg is, as the name says, a
keyword-only argument. It is for this reason that, in PrimTorch, we were able
to implement it as [a decorator](https://github.com/pytorch/pytorch/blob/ce4df4cc596aa10534ac6d54912f960238264dfd/torch/_prims_common/wrappers.py#L187-L282).
This is not the case in NumPy, where `out` can be used both as a positional and
a keyword argument, and is often interleaved with other parameters. This is the
reason why we use the `OutArray` annotation to mark these. We then implement
the `out` semantics in the `@normalizer` wrapper in a generic way.
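
As a rough sketch (building on the hypothetical `normalizer` sketch above, and
assuming the wrapped tensor is stored in a `tensor` attribute), the generic
handling could look like:

```python
def _copy_to_out(result: torch.Tensor, out):
    # out is a torch_np.ndarray: copy the computed tensor into its buffer
    # and return the out array itself, matching NumPy's out= contract
    out.tensor.copy_(result)
    return out
```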

**Ufuncs and reductions**: Ufuncs (unary and binary) and reductions are two
sets of functions that are particularly regular. For these functions, we
implement their args in a generic way as a preprocessing or postprocessing step.
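
For instance, binary ufuncs could all be generated from one template like the
following sketch, reusing the hypothetical helpers from the sketches above:

```python
def _binary_ufunc(torch_fn):
    # the shared normalization and out-handling happen here, once,
    # so each individual ufunc becomes a one-liner
    @normalizer
    def ufunc(x1: ArrayLike, x2: ArrayLike, out=None):
        result = torch_fn(x1, x2)
        return result if out is None else _copy_to_out(result, out)
    return ufunc

add = _binary_ufunc(torch.add)
multiply = _binary_ufunc(torch.mul)
```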

**The `ndarray` class**: Once we have all the free functions implemented as
functions from `torch_np.ndarray`s to `torch_np.ndarray`s, implementing the
methods of the `ndarray` class is rather simple. We simply register all the
free functions as methods or dunder methods, as appropriate. We also forward the
properties to the properties of the underlying PyTorch tensor, and we are done.
This creates a circular dependency, which we break with a local import.
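
A stripped-down sketch of the registration idea; the names used here (the
`tensor` attribute, the `_funcs` module) are illustrative, not the actual layout:

```python
import torch

class ndarray:
    def __init__(self, tensor: torch.Tensor):
        self.tensor = tensor

    @property
    def shape(self):
        # properties are forwarded to the underlying torch.Tensor
        return tuple(self.tensor.shape)

    # free functions are registered as (dunder) methods; importing them
    # locally breaks the ndarray <-> functions circular dependency
    def sum(self, axis=None):
        from torch_np import _funcs
        return _funcs.sum(self, axis=axis)

    def __add__(self, other):
        from torch_np import _funcs
        return _funcs.add(self, other)
```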

### Testing

The testing of the framework was done via ~~copying~~ vendoring tests from the
NumPy test suite and replacing the NumPy imports with `torch_np` imports. The
failures in these tests were then triaged, and we discussed the priority of
fixing each of them.

In the (near) future, we plan to get some real-world examples and run them
through the library, to test its coverage and correctness.

### Limitations

A number of known limitations are tracked in the second part of the
[OP of this issue](https://github.com/Quansight-Labs/numpy_pytorch_interop/issues/73).
There are some more in [this issue](https://github.com/Quansight-Labs/numpy_pytorch_interop/issues/86).

When landing this RFC, we will create a comprehensive document with the differences
between NumPy and `torch_np`.

### Beyond Plain NumPy

**GPU**. The current implementation has been implemented and tested on CPU
only. Since the layer is written entirely in terms of PyTorch ops, which
support both devices, we expect GPU coverage to be on par with CPU coverage.
If the original tensors are on GPU, the whole execution should be performed
on the GPU.

**TODO(Lezcano)**. We should probably run the tests on CUDA as well.

**Gradients**. We have not tested gradient tracking either, as we have yet to
find good examples on which to test it, but it should be a simple corollary of
all this effort. If the original tensors fed into the function have
`requires_grad=True`, the tensors will track the gradients of the internal
implementation, and the user can then differentiate through the NumPy code.
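
A sketch of what such a test could look like, assuming tracing works as
described in the introductory example (illustrative and untested):

```python
import torch
import numpy as np

@torch.compile
def fn(x):
    # would be traced into torch.sin(x).sum() by the compat layer
    return np.sin(x).sum()

x = torch.randn(3, requires_grad=True)
fn(x).backward()
print(torch.allclose(x.grad, torch.cos(x)))  # expected: True
```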

**TODO(Lezcano)**. Picking up simple NumPy programs from the internet would be good for these autograd tests.

### Bindings to TorchDynamo

The bindings for NumPy at the TorchDynamo level are currently being developed at [#95849](https://github.com/pytorch/pytorch/pull/95849).

## Unresolved Questions

A question was left open in the initial discussion: should the module
`torch_np` be publicly exposed as `torch.numpy` or not?

A few arguments in favor of making it public:
* People could use it in their NumPy programs just by changing the import to
  `import torch.numpy as np`. This could be a selling point similar to JAX's
  `jax.numpy`, which could incentivize adoption.
* People would not need to use the whole PyTorch 2.0 stack to start using
  PyTorch in their codebases.
  * See [this experiment in scikit-learn](https://github.com/scikit-learn/scikit-learn/pull/25956),
    where they got a 7x speed-up on CPU in a layer just by using `torch.linalg`.
* Since the layer is rather thin and in pure Python, if there are bugs,
  external contributors could easily help fix them or extend the supported
  functionality.

A few arguments against:
* The compat layer introduces a number of type conversions that may produce somewhat
  slow code when used in eager mode.
  * [Note] Keeping this in mind, we tried to use as few operators as possible
    in the implementations, to make it reasonably fast in eager mode.
* Exposing `torch.numpy` would create a less performant secondary entry point
  to many of the functions in PyTorch. This could be a trap for new users.