Commit 7a5b98c: Fix some review comments (parent 9803d45)

RFC.md: 1 file changed, +150 -121 lines

@@ -1,29 +1,27 @@
# A PyTorch - NumPy compatibility layer

**Authors:**
* @ev-br
* @lezcano
* @rgommers

## Summary
This RFC describes a proposal for a translation layer from NumPy into PyTorch.
In simple terms, this amounts to implementing most of NumPy's API (`ndarray`,
the `numpy`, `numpy.linalg`, `numpy.fft` modules, etc.) using `torch.Tensor`
and PyTorch ops as a backend.

This project has one main goal, as per the
[initial design document](https://docs.google.com/document/d/1gdUDgZNbumFORRcUaZUVw790CtNYweAM20C1fbWMNd8):
make TorchDynamo understand NumPy calls.

The work is being done at [numpy_pytorch_interop](https://github.com/Quansight-Labs/numpy_pytorch_interop/).

## Motivation

### An introductory example

Let's start with some examples.

@@ -37,7 +35,7 @@ z = np.dot(x, y)
w = z.sum()
```

When we trace this program with the compat layer, the semantics of the
program would stay the same, but the implementation would be equivalent to

```python
@@ -78,7 +76,7 @@ t_results[0] = result # store the result in a torch.Tensor
```

This code mixing NumPy and PyTorch already works, as `torch.Tensor` implements
the `__array__` method. For it to work with the compatibility layer manually,
we would need to wrap and unwrap the inputs / outputs. This could be done by
modifying `fn` as

@@ -90,13 +88,87 @@ def fn(x, y):
    return ret.tensor.numpy()
```

This process would be done automatically by TorchDynamo, so we would simply need to write
```python
@compile
def fn(x, y):
    return np.multiply(x, y).sum()
```

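For intuition, the wrapping and unwrapping that TorchDynamo automates could look roughly like the following decorator. This is a minimal sketch: the `numpy_compat` helper and the `tnp.asarray` call are illustrative assumptions, not the actual integration; only the `ret.tensor` attribute access appears in this RFC.

```python
import numpy as np
import torch_np as tnp  # the compat layer described in this RFC

def numpy_compat(fn):
    # Hypothetical decorator: wrap NumPy inputs into the compat layer,
    # call fn, and unwrap the result back into a NumPy array.
    def wrapper(*args):
        args = tuple(
            tnp.asarray(a) if isinstance(a, np.ndarray) else a for a in args
        )
        ret = fn(*args)
        # A torch_np.ndarray wraps a torch.Tensor in its `tensor` attribute
        return ret.tensor.numpy() if isinstance(ret, tnp.ndarray) else ret
    return wrapper
```
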
### The observable behavior

The two main ideas driving the design of this compatibility layer were the following:

1. The behavior of the layer should be as close to that of NumPy as possible
2. The layer follows NumPy master

The following design decisions follow from these:

**Default dtypes**. One of the issues that users most often run into when moving
their codebase from NumPy to JAX is the default dtype changing from `float64` to
`float32`. So much so that this is noted as one of
[JAX's sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision).
Following the spirit of making everything match NumPy by default, we choose the
NumPy defaults whenever `dtype` is not specified in a factory function.
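
As a concrete illustration (the `torch_np` behavior shown is the intended design; the `tnp.ones` call is a sketch, not a confirmed API):

```python
import numpy as np
import torch
import torch_np as tnp  # the compat layer described in this RFC

np.ones(3).dtype     # dtype('float64'): NumPy defaults to float64
torch.ones(3).dtype  # torch.float32: PyTorch defaults to float32
tnp.ones(3)          # should follow NumPy and produce a float64 array
```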

**TODO(Lezcano)**: I just realized that we do not have a clean way to change
the default dtypes of `torch_np` to the PyTorch defaults. We should implement
that utility flag, similar to
[`torch.set_default_dtype`](https://pytorch.org/docs/stable/generated/torch.set_default_dtype.html).
Perhaps call it `torch_np.use_torch_defaults()`, and then add a way for users
to set their own int/float/complex defaults.

**TODO(Lezcano)**: Do we use these defaults just in factory functions, or also
anywhere else? -> Check

**NumPy scalars**. NumPy's type system is tricky. At first sight, it looks
quite a bit like PyTorch's, but with a few more dtypes, like `np.uint16` or
`np.longdouble`. Upon closer inspection, one finds that it also has
[NumPy scalar](https://numpy.org/doc/stable/reference/arrays.scalars.html) objects.
NumPy scalars are similar to Python scalars, but with a fixed width. NumPy scalars
are NumPy's preferred return class for reductions and other operations that
return just one element. NumPy scalars do not play particularly well with
computations on devices like GPUs, as they live on the CPU. Implementing NumPy
scalars would mean that we would need to synchronize after every `sum()` call,
which is less than ideal. Instead, whenever a NumPy scalar would be returned, we
return a 0-D tensor, as PyTorch does.
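
To make the difference concrete:

```python
import numpy as np
import torch

type(np.ones(3).sum())  # <class 'numpy.float64'>: a NumPy scalar, living on CPU
torch.ones(3).sum()     # tensor(3.): a 0-D tensor, which may live on any device
```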

**Type promotion**. Another not-so-well-known fact about NumPy's casting system
is that it is data-dependent. Python scalars can be used in pretty much any
NumPy operation: any operation that accepts a 0-D array will also accept a
Python scalar. If you provide an operation with a Python scalar, it will be
cast to the smallest dtype that can represent it, and it will then participate
in type promotion, allowing for some rather interesting behaviour:
```python
>>> np.asarray([1], dtype=np.int8) + 127
array([-128], dtype=int8)
>>> np.asarray([1], dtype=np.int8) + 128
array([129], dtype=int16)
```
This data-dependent type promotion will be deprecated in NumPy 2.0 and will be
replaced with [NEP 50](https://numpy.org/neps/nep-0050-scalar-promotion.html).
As such, to be forward-looking and for simplicity, we chose to implement the
type promotion behaviour proposed in NEP 50, which is much closer to that of
PyTorch.
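
For reference, this is the behaviour PyTorch already implements: a Python scalar
does not upcast the tensor dtype:

```python
>>> import torch
>>> torch.tensor([1], dtype=torch.int8) + 1
tensor([2], dtype=torch.int8)
```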

Note that the decision of going with NEP 50 complements the previous one of
returning 0-D arrays in place of NumPy scalars since, currently, 0-D arrays do
not participate in type promotion in NumPy (but they will in NumPy 2.0):
```python
int64_0d_array = np.array(1, dtype=np.int64)
np.result_type(np.int8, int64_0d_array) == np.int8
```

**Versioning**. It should be clear from the previous points that NumPy has a
fair amount of questionable and legacy pain points. As such, we decided that,
rather than trying to fight these, we would declare that the compat layer
follows the behavior of NumPy's master branch. Given the stability of NumPy's
API and how battle-tested its main functions are, we do not expect this to
become a big maintenance burden. If anything, it should make our lives easier,
as some parts of NumPy will soon be simplified and we will not need to
implement them, as described above.
## The `torch_np` module

The bulk of the work went into implementing a system that allows us to
implement NumPy operations in terms of those of PyTorch. The main design goals
@@ -107,9 +179,9 @@ were

We say *most* of NumPy's API, because NumPy's API is not only massive, but also
there are parts of it which cannot be implemented in PyTorch. For example,
NumPy has support for arrays of string, datetime, structured, and other dtypes.
Negative strides are another example of a feature that is out of scope.
We put together a list of things that are out of the scope of this project in the
[following issue](https://github.com/Quansight-Labs/numpy_pytorch_interop/issues/73).

For the bulk of the functions, we started by prioritizing the most common
@@ -124,6 +196,11 @@ The second point of preserving NumPy semantics as much as possible will be used
in the sequel to discuss some points like the default dtypes that are used
throughout the implementation.

**Visibility of the module**. For simplicity, this RFC assumes that the
`torch_np` module will not be public, as the decision for it to be made public
was met with differing opinions. We discuss these in the "Unresolved Questions"
section.

### Annotation-based preprocessing

NumPy accepts virtually anything that smells like an array as input to its operators
@@ -138,7 +215,7 @@ array([1, 2, 3, 4, 5, 6])

To implement NumPy in terms of PyTorch, for any operation we would need to put
the inputs into tensors, perform the operations, and then wrap the tensor into
a `torch_np.ndarray` (more on this class later).

To avoid all this code repetition, we implement the functions in two steps.

@@ -173,16 +250,16 @@ gathering all the inputs at runtime and normalizing them according to their
annotations.

We currently have four annotations (and small variations of them):
- `ArrayLike`: The input can be a `torch_np.array`, a list of lists, a
  scalar, or anything that NumPy would accept. It returns a `torch.Tensor`.
- `DTypeLike`: Takes a `torch_np` dtype and returns a PyTorch dtype.
- `AxisLike`: Takes anything that can be accepted as an axis (e.g. a tuple or
  an `ndarray`) and returns a tuple.
- `OutArray`: Asserts that the input is a `torch_np.ndarray`. This is used
  to implement the `out` arg.

Note that none of the code here makes use of NumPy. We are writing
`torch_np.ndarray` above to make our intent more explicit, but there
shouldn't be any ambiguity here.

**OBS(Lezcano)**: `DTypeLike` should be `Optional[DTypeLike]`
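
To make the mechanism concrete, here is a minimal sketch of what such
annotation-driven normalization could look like. The `normalizer` name, the
`ArrayLike` stand-in class, and the dispatch logic are illustrative, not the
actual implementation:

```python
import inspect
import torch

class ArrayLike:
    """Stand-in for the RFC's ArrayLike annotation (illustrative only)."""

def normalizer(func):
    # Gather the annotated signature once, at decoration time
    sig = inspect.signature(func)

    def wrapper(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)
        for name, value in bound.arguments.items():
            # Normalize each input according to its annotation
            if sig.parameters[name].annotation is ArrayLike:
                bound.arguments[name] = torch.as_tensor(value)
        return func(*bound.args, **bound.kwargs)

    return wrapper

@normalizer
def multiply(x: ArrayLike, y: ArrayLike):
    # The implementation can now assume its inputs are torch.Tensors
    return torch.multiply(x, y)

multiply([1, 2], 3)  # tensor([3, 6])
```
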
@@ -213,114 +290,66 @@ class is rather simple. We simply register all the free functions as methods or
dunder methods appropriately. We also forward the `ndarray` properties to the
corresponding properties of the underlying PyTorch tensor, and we are done.

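A condensed, illustrative sketch of this wrapper class (the real class
registers many more methods; the names follow this RFC, the details are
assumed):

```python
import torch

class ndarray:
    """Condensed, illustrative sketch of the torch_np.ndarray wrapper."""

    def __init__(self, tensor: torch.Tensor):
        self.tensor = tensor  # the torch.Tensor backing this array

    # Properties simply forward to the underlying tensor
    @property
    def shape(self):
        return tuple(self.tensor.shape)

    @property
    def ndim(self):
        return self.tensor.ndim

    # Free functions are registered as methods / dunder methods
    def __mul__(self, other):
        return multiply(self, other)

    def sum(self, axis=None):
        return sum_(self, axis)

# The free functions are implemented on tensors, then wrap the result
def multiply(x, y):
    return ndarray(torch.multiply(x.tensor, y.tensor))

def sum_(x, axis=None):
    t = x.tensor.sum() if axis is None else x.tensor.sum(axis)
    return ndarray(t)
```
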
### Testing

The testing of the framework was done via ~~copying~~ vendoring tests from the
NumPy test suite. Then, we would replace the NumPy imports with imports from
`torch_np`. The failures in these tests were then triaged, and we discussed the
priority of fixing each of them.
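
Concretely, the import swap in a vendored test file looks along these lines:

```python
# Original import in the vendored NumPy test suite:
# import numpy as np

# Replaced so the same tests exercise the compat layer instead:
import torch_np as np
```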

In the (near) future, we plan to get some real-world examples and run them
through the library, to test its coverage and correctness.

### Limitations

A number of known limitations are tracked in the second part of the
[OP of this issue](https://github.com/Quansight-Labs/numpy_pytorch_interop/issues/73).
There are some more in [this issue](https://github.com/Quansight-Labs/numpy_pytorch_interop/issues/86).
When landing all this, we will create a comprehensive document with the differences
between NumPy and `torch_np`.

### Beyond Plain NumPy

**GPU**. The current implementation has only been implemented and tested on
CPU. We expect GPU coverage to be as good as our CPU coverage, wherever
PyTorch's GPU support matches its CPU support. If the original tensors are on
the GPU, the whole execution should be performed on the GPU.
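
Under the TorchDynamo bindings discussed below, this could look as follows (a
sketch; whether `torch.compile` traces NumPy calls like this depends on the
in-progress integration):

```python
import numpy as np
import torch

def fn(x, y):
    return np.multiply(x, y).sum()  # plain NumPy code

compiled = torch.compile(fn)

# Tensors created on the GPU: the traced torch ops should all run on the GPU
x = torch.randn(3, device="cuda")
y = torch.randn(3, device="cuda")
out = compiled(x, y)
```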

**TODO(Lezcano)**. We should probably exercise CUDA in the tests.

**Gradients**. We have not tested gradient tracking either, as we have yet to
find some good examples on which to test it, but it should be a simple
corollary of all this effort. If the original tensors fed into the function
have `requires_grad=True`, the tensors will track the gradients of the internal
implementation, and the user can then differentiate through the NumPy code.
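
Continuing the sketch from the GPU section above, gradient tracking would then
amount to feeding tensors with `requires_grad=True` (again illustrative and,
per the TODO below, untested):

```python
x = torch.randn(3, requires_grad=True)
y = torch.randn(3)
out = compiled(x, y)  # NumPy calls traced into differentiable torch ops
out.backward()        # gradients accumulate in x.grad
```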

**TODO(Lezcano)**. Picking up simple NumPy programs from the internet would be good for these autograd tests.

### Bindings to TorchDynamo

The bindings for NumPy at the TorchDynamo level are currently being developed in [#95849](https://github.com/pytorch/pytorch/pull/95849).


## Unresolved Questions

A question was left open in the initial discussion: should the `torch_np` module be publicly exposed as `torch.numpy` or not?

A few arguments in favor of making it public:
* People could use it in their NumPy programs just by changing the import to
  `import torch.numpy as np`. This could be a selling point similar to JAX's
  `jax.numpy`, which could incentivize adoption.
* People would not need to use the whole PyTorch 2.0 stack to start using
  PyTorch in their codebases.
  * See [this experiment in scikit-learn](https://github.com/scikit-learn/scikit-learn/pull/25956),
    where they got a 7x speed-up on CPU in a layer just by using `torch.linalg`.
* Since the layer is rather thin and in pure Python, if there are bugs,
  external contributors could easily help fix them or extend the supported
  functionality.

A few arguments against:
* The compat layer introduces a number of type conversions that may produce
  somewhat slow code when used in eager mode.
  * [Note] Keeping this in mind, we tried to use as few operators as possible
    in the implementations, to keep it reasonably fast in eager mode.
* Exposing `torch.numpy` would create a less performant secondary entry point
  to many of the functions in PyTorch. This could be a trap for new users.
