# A PyTorch - NumPy compatibility layer

**Authors:**
* @ev-br
* @lezcano
* @rgommers

## Summary

This RFC describes a proposal for a translation layer from NumPy into PyTorch.
In simple terms, this amounts to implementing most of NumPy's API (`ndarray`;
the `numpy`, `numpy.linalg`, and `numpy.fft` modules; etc.) using `torch.Tensor`
and PyTorch ops as the backend.

The main goal is: **make TorchDynamo understand NumPy calls**.
This should enable an end user to combine code that uses the PyTorch API with
code that uses the NumPy API, in a way that allows TorchDynamo to understand
those function calls and build up an execution graph. To enable this, it is key
that there is a translation layer from NumPy to PyTorch function calls, which
TorchDynamo can use in order to build up its execution graph from PyTorch
functions/primitives only. For niche functions in NumPy that don't have a
PyTorch equivalent, it's okay to graph break and still call NumPy to execute
the function call.

The work is currently being done at [numpy_pytorch_interop](https://github.com/Quansight-Labs/numpy_pytorch_interop/).

## Motivation

### Introductory examples

Consider the following snippet:

```python
import numpy as np

x = np.random.randn(3, 4)
y = np.random.randn(4, 3)
z = np.dot(x, y)
w = z.sum()
```

When we trace this program with the compat layer, the semantics of the
program would stay the same, but the implementation would be equivalent to

```python
import torch

x = torch.randn(3, 4, dtype=torch.float64)
y = torch.randn(4, 3, dtype=torch.float64)
z = torch.matmul(x, y)
w = z.sum()
```

Here, we can already spot a couple of differences between NumPy and PyTorch.
The most obvious one is that the default dtype in NumPy is `float64` rather than
`float32`. The less obvious one is sneakily hiding in the last line.

```python
>>> type(w)
<class 'numpy.float64'>
```

Reductions and similar operations in NumPy return the infamous NumPy scalars.
We'll discuss these and other NumPy quirks, and how we dealt with them, in the
[design decisions section](#design-decisions).

Let's now have a look at a toy example of how this layer would be used.

```python
import torch
import numpy as np

t1 = torch.tensor([1, 3, 5])
t2 = torch.exp(t1)
# Now say the user has some code lying around which uses NumPy:
def fn(x, y):
    return np.multiply(x, y).sum()

result = fn(t1, t2)
t_results = torch.empty(5, dtype=torch.float64)
t_results[0] = result  # store the result in a torch.Tensor
```

Note that this code mixing NumPy and PyTorch already works in eager mode with
CPU tensors, as `torch.Tensor` implements the `__array__` method.
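
For illustration, this mechanism can be seen with plain NumPy and PyTorch, no
compat layer involved:

```python
import numpy as np
import torch

t = torch.tensor([1.0, 2.0, 3.0])
# NumPy routes the tensor through __array__ and computes with its own ufunc:
a = np.multiply(t, t)
print(type(a))  # <class 'numpy.ndarray'>
```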

Now, the compatibility layer allows us to trace through it. To do that, no
changes are necessary other than simply asking `torch.compile` to trace
through it:

```python
@torch.compile
def fn(x, y):
    return np.multiply(x, y).sum()
```

Then, TorchDynamo will cast `x` and `y` to our internal implementation of `ndarray`,
and will dispatch `np.multiply` and `sum` to our implementations in terms of `torch`
functions, effectively turning this function into a pure PyTorch function.
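
Conceptually, the traced function would then behave as if it had been written
directly in PyTorch, along the lines of this illustrative equivalent:

```python
import torch

def fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # np.multiply maps to torch.mul; ndarray.sum maps to torch.Tensor.sum
    return torch.mul(x, y).sum()
```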

### Design decisions

The main ideas driving the design of this compatibility layer are the following:

1. The goal is to transform valid NumPy and mixed PyTorch-NumPy programs into
   their equivalent PyTorch-only execution.
2. The behavior of the layer should be as close to that of NumPy as possible.
3. The layer follows the most recent NumPy release.

The following design decisions follow from these:

**A superset of NumPy**. NumPy has a number of well-known edge-cases (as does
PyTorch, like spotty support for `float16` on CPU and `complex32` in general).
The decision to translate only valid NumPy programs often allows us to
implement a superset of the functionality of NumPy with more predictable and
consistent behavior than NumPy itself has.

**Exceptions may be different**. We avoid modelling NumPy's exception system
entirely. As seen in the implementation of PrimTorch, modelling the
error cases of a given system is terribly difficult. We avoid this altogether
and choose not to offer any guarantees here.

**Default dtypes**. One of the most common issues that bites people when migrating their
codebases from NumPy to JAX is the default dtype changing from `float64` to
`float32`. So much so that this is noted as one of
[JAX's sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision).
Following the spirit of making everything match NumPy by default, we choose the
NumPy default dtype whenever the `dtype` was not made explicit in a factory function.
We also provide a function `set_default_dtype` that allows changing this behavior
dynamically.
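
A sketch of the intended usage (the exact signature of `set_default_dtype`
shown here is an assumption; only the function's existence is stated above):

```python
import torch_np as np  # the compat layer proposed in this RFC

x = np.ones(3)
# x.dtype is float64, matching NumPy's default

# Illustrative call: switch factory-function defaults to float32.
np.set_default_dtype(np.float32)
y = np.ones(3)
# y.dtype is now float32
```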

**NumPy scalars**. NumPy's type system is tricky. At first sight, it looks
like PyTorch's, but with a few more dtypes like `np.uint16` or `np.longdouble`.
Upon closer inspection, one finds that it also has
[NumPy scalar](https://numpy.org/doc/stable/reference/arrays.scalars.html) objects.
NumPy scalars are similar to Python scalars, but with a fixed precision and
array-like methods attached. They are NumPy's preferred return class
for reductions and other operations that return just one element.
NumPy scalars do not play particularly well with
computations on devices like GPUs, as they live on the CPU. Implementing NumPy
scalars would mean that we need to synchronize after every `sum()` call, which
would be terrible performance-wise. In this implementation, we choose to represent
NumPy scalars as 0-D arrays. This may cause small divergences in some cases. For example,
consider the following NumPy behavior:

```python
>>> np.int32(2) * [1, 2, 3]     # scalar decays to a Python int
[1, 2, 3, 1, 2, 3]

>>> np.asarray(2) * [1, 2, 3]   # zero-dim array is an array-like
array([2, 4, 6])
```

We don't expect these to pose a big issue in practice. Note that in the
proposed implementation `np.int32(2)` would return the same as `np.asarray(2)`.
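
For illustration, a hedged sketch of what this means under the proposed
representation (the `torch_np` import and the exact behavior shown are
assumptions based on the description above):

```python
import torch_np as np  # the compat layer: scalars are represented as 0-D arrays

s = np.int32(2)        # equivalent to np.asarray(2, dtype=np.int32)
# s.shape == () : a 0-D array backed by a 0-D torch.Tensor
# s * [1, 2, 3] -> array([2, 4, 6]); array-like behavior, no decay to
# a Python int as NumPy's scalar exhibits above
```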

In general, we try to avoid unnecessary graph breaks whenever we can. For
example, we may choose to return a tensor of shape `(2, *)` rather than a list
of pairs, to avoid a graph break.

**Type promotion**. Another not-so-well-known fact about NumPy's dtype system and casting rules
is that they are data-dependent. Python scalars can be used in pretty much any NumPy
operation: any operation that accepts a 0-D array will also accept a Python
scalar. When an operation is given a Python scalar, the scalar is cast to the
smallest dtype in which it can be represented, and only then does it
participate in type promotion. This allows for some rather interesting behaviour:

```python
>>> np.asarray([1], dtype=np.int8) + 127
array([-128], dtype=int8)
>>> np.asarray([1], dtype=np.int8) + 128
array([129], dtype=int16)
```

This data-dependent type promotion will be removed in NumPy 2.0 (planned for Dec'23), and will be
replaced with [NEP 50](https://numpy.org/neps/nep-0050-scalar-promotion.html)
(already implemented in NumPy, though for now it needs to be enabled via a private global switch).
For simplicity, and to be forward-looking, we chose to implement the
type promotion behaviour proposed in NEP 50, which is much closer to that of
PyTorch.

Note that the decision of going with NEP 50 complements the previous one of
returning 0-D arrays in place of NumPy scalars since, currently, 0-D arrays do not
participate in type promotion in NumPy (but they will in NumPy 2.0 under NEP 50):

```python
int64_0d_array = np.array(1, dtype=np.int64)
np.result_type(np.int8, int64_0d_array) == np.int8
```

**Versioning**. It should be clear from the previous points that NumPy has a
fair amount of questionable behavior and legacy pain points. It is for this reason that
we decided that, rather than fighting these, we would declare that the compat
layer follows the behavior of NumPy's most recent release (even, in some cases,
of NumPy 2.0). Given the stability of NumPy's API and how battle-tested its
main functions are, we do not expect this to become a big maintenance burden.
If anything, it should make our lives easier, as some parts of NumPy will soon
be simplified, saving us the pain of having to implement all the pre-existing
corner-cases.

**Randomness**. PyTorch and NumPy use different random number generation methods.
In particular, NumPy recently moved to a [new API](https://numpy.org/doc/stable/reference/random/index.html)
with a `Generator` object, which has sampling methods on it. The current compat
layer does not implement this new API, as the default bit generator in NumPy is
`PCG64`, while PyTorch uses `MT19937` on CPU and `Philox` on non-CPU devices.
It follows that the random functions in this layer will not offer any
reproducibility guarantees with respect to NumPy.

## The `torch_np` module

The bulk of the work went into implementing a system that allows us to
implement NumPy operations in terms of those of PyTorch. The main design goals
here were

1. Implement *most* of NumPy's API
2. Preserve NumPy semantics as much as possible

We say *most* of NumPy's API because NumPy's API is not only massive, but
there are also parts of it which cannot be implemented in PyTorch. For example,
NumPy has support for arrays of strings, datetimes, structured and other dtypes.
Negative strides are another example of a feature that is not supported in PyTorch.
We put together a list of things that are out of the scope of this project in the
[following issue](https://github.com/Quansight-Labs/numpy_pytorch_interop/issues/73).

For the bulk of the functions, we started by prioritizing the most common
operations. Then, when bringing in tests from the NumPy test suite, we triaged
and prioritized how important it was to fix each failure we found. Iterating
this way, we ended up with a small list of differences between the NumPy and
PyTorch APIs, which we prioritized by hand. That list and the prioritization
discussion can be found in [this issue](https://github.com/Quansight-Labs/numpy_pytorch_interop/issues/87).

**Visibility of the module**. For simplicity, this RFC assumes that the
`torch_np` module will not be public, as the initial suggestion for it to be
made public was met with mixed opinions. This topic can be revisited in the
future if desired.
We discuss this further in the [unresolved questions](#unresolved-questions) section.

### Annotation-based preprocessing

NumPy accepts virtually anything that smells like an array as an input.

```python
>>> np.add(1, 3)
4
>>> np.add([1., 2., 3.], 5)
array([6., 7., 8.])
>>> np.concatenate([[1, 2, 3], [4, 5, 6]])
array([1, 2, 3, 4, 5, 6])
```

NumPy calls all these objects `array_like` objects.
To implement NumPy in terms of PyTorch, for any operation we would need to map
its inputs into tensors, perform the operation, and then wrap the resulting tensor into
a `torch_np.ndarray` (more on this class later).

To avoid all this code repetition, we implement the functions in two steps.

First, we implement functions with the NumPy signature, but assuming that in
place of NumPy-land elements (`np.ndarray`s, array-like objects, `np.dtype`s, etc.)
they simply accept `torch.Tensor`s and PyTorch-land objects and return
`torch.Tensor`s. For example, we would implement `np.diag` as

```python
def diag(v, k=0):
    return torch.diag(v, k)
```

In this layer, if a NumPy function is composite (calls other NumPy functions
internally), we can simply vendor its implementation and have it call our
PyTorch-land implementations of those functions. In other words, at this level,
functions are composable, as they are simply regular PyTorch functions.
All these implementations are internal, and are not meant to be seen or used
by the end user.
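
For instance, a composite function like `diagflat`, whose NumPy implementation
essentially amounts to `diag(ravel(v), k)`, could vendor that implementation on
top of our tensor-level functions (an illustrative sketch, not the actual
`torch_np` code):

```python
import torch

# Step-one implementations: torch.Tensor in, torch.Tensor out.
def ravel(a):
    return torch.ravel(a)

def diag(v, k=0):
    return torch.diag(v, k)

# A composite NumPy function simply calls our PyTorch-land versions
# of the NumPy functions it is built from:
def diagflat(v, k=0):
    return diag(ravel(v), k)
```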

The second step is then done via type annotations and a decorator. Each type
annotation has an associated function from NumPy-land into PyTorch-land. This
function converts the set of inputs accepted by NumPy for that argument into a
PyTorch-land object (think a `torch.Tensor` or a PyTorch dtype). For example,
for `np.diag` we would write

```python
def diag(v: ArrayLike, k=0):
    return torch.diag(v, k)
```

Then, we wrap these functions with a `normalizer` decorator and
expose them in the `torch_np` module. This decorator is in charge of gathering
all the inputs at runtime and normalizing them (i.e., converting `torch_np`
objects to their PyTorch counterparts) according to their annotations; a
minimal sketch of such a decorator is given after the list below.

We currently have four annotations (and small variations of them):
- `ArrayLike`: The input can be a `torch_np.array`, a list of lists, a
  scalar, or anything that NumPy would accept. It returns a `torch.Tensor`.
- `DTypeLike`: Takes a `torch_np` dtype, or any other object that NumPy dtypes
  accept (strings, typecodes, ...), and returns a PyTorch dtype.
- `AxisLike`: Takes anything that can be accepted as an axis (e.g. a tuple or
  an `ndarray`) and returns a tuple.
- `OutArray`: Asserts that the input is a `torch_np.ndarray`. This is used
  to implement the `out` keyword.
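
The following is a minimal sketch of such a decorator, assuming only the
`ArrayLike` annotation; the real `normalizer` in `torch_np` handles all four
annotations and many more edge cases:

```python
import functools
import inspect
import torch

class ArrayLike:
    """Marker annotation: normalize this argument to a torch.Tensor."""

def normalizer(func):
    sig = inspect.signature(func)

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)
        bound.apply_defaults()
        for name, value in bound.arguments.items():
            # Convert annotated array-likes (lists, scalars, ...) to tensors.
            if sig.parameters[name].annotation is ArrayLike:
                bound.arguments[name] = torch.as_tensor(value)
        return func(*bound.args, **bound.kwargs)

    return wrapper

@normalizer
def diag(v: ArrayLike, k=0):
    return torch.diag(v, k)

print(diag([[1, 2], [3, 4]]))  # accepts a list of lists, like np.diag
```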

Note that none of the code in this implementation makes use of NumPy. We are
writing `torch_np.ndarray` above to make our intent more explicit, but there
shouldn't be any ambiguity.

**Implementing `out`**: In PyTorch, the `out` kwarg is a keyword-only argument.
It is for this reason that, in PrimTorch, we were able to implement it as [a
decorator](https://github.com/pytorch/pytorch/blob/ce4df4cc596aa10534ac6d54912f960238264dfd/torch/_prims_common/wrappers.py#L187-L282).
This is not the case in NumPy, where `out` can be used both as a positional
and a keyword argument, and is often interleaved with other parameters. This is
why we use the `OutArray` annotation to mark these arguments. We then
implement the `out` semantics in the `@normalizer` wrapper in a generic way.
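
A hedged sketch of what that generic handling might look like (the helper name
and the `.tensor` attribute are assumptions, not the actual `torch_np`
internals):

```python
import torch

def copy_to_out(result: torch.Tensor, out):
    """Generic `out=` semantics: write the result into `out` and return it."""
    if out is None:
        return result
    # Assumes torch_np.ndarray wraps its data in a `.tensor` attribute.
    out.tensor.copy_(result)
    return out
```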

**Ufuncs and reductions**: Ufuncs (unary and binary) and reductions are two
sets of functions that are particularly regular. For these functions, we
implement support for their arguments in a generic way, as a preprocessing or
postprocessing step.
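
For example, a single generic wrapper can give every binary ufunc a shared set
of NumPy arguments. The sketch below is illustrative and works at the tensor
level; the real preprocessing supports more ufunc arguments, such as `where=`
and `casting=`:

```python
import torch

def binary_ufunc(torch_op):
    # Shared pre/post-processing for all binary ufuncs.
    def impl(x1, x2, /, out=None, *, dtype=None):
        result = torch_op(x1, x2)
        if dtype is not None:
            result = result.to(dtype)  # postprocess: honor dtype=
        if out is not None:
            out.copy_(result)          # postprocess: honor out=
            return out
        return result
    return impl

add = binary_ufunc(torch.add)
multiply = binary_ufunc(torch.mul)
```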

**The `ndarray` class**. Once we have all the free functions implemented as
functions from `torch_np.ndarray`s to `torch_np.ndarray`s, implementing the
methods of the `ndarray` class is rather simple. We simply register all the
free functions as methods or dunder methods as appropriate. We also forward the
properties of `ndarray` to the corresponding properties of `torch.Tensor`, and we
are done. This creates a circular dependency, which we break with a local
import.
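
In miniature, the registration pattern might look as follows (an illustrative
sketch; the actual class wraps a `torch.Tensor` and registers many more
functions):

```python
import torch

# Free functions, ndarray in / ndarray out (illustrative).
def _add(a, b):
    return ndarray(a.tensor + b.tensor)

def _sum(a):
    return ndarray(a.tensor.sum())

class ndarray:
    def __init__(self, tensor):
        self.tensor = tensor

    @property
    def shape(self):
        # Properties are forwarded to the underlying torch.Tensor.
        return tuple(self.tensor.shape)

# Register free functions as (dunder) methods after the fact.
ndarray.__add__ = _add
ndarray.sum = _sum

a = ndarray(torch.tensor([1, 2, 3]))
print((a + a).sum().tensor)  # tensor(12)
```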
318+
319+
### Testing
320+
321+
The testing of the framework was done via ~~copying~~ vendoring tests from the
322+
NumPy test suite. Then, we would replace the NumPy imports with `torch_np`
323+
imports. The failures on these tests were then triaged, and either fixed or marked
324+
`xfail` depending on our assessment of the priority of implementing a fix.
325+
326+
In the end, to have a last check that this tool was sound, we pulled five
327+
examples of NumPy code from different sources and ran it with this library (eager mode execution).
328+
We were able to run the five examples successfully with close to no code changes.
329+
You can read about these in the [README](https://github.com/Quansight-Labs/numpy_pytorch_interop).

### Limitations

A number of known limitations are tracked in the second part of the
[OP of this issue](https://github.com/Quansight-Labs/numpy_pytorch_interop/issues/73).
When landing this RFC, we will create a comprehensive document listing the differences
between NumPy and `torch_np`.

### Beyond plain NumPy

**GPU**. The implementation has so far only been developed and tested on
CPU. We expect GPU coverage to be as good as the coverage we have with CPU-GPU
matching tests in the PyTorch test suite. If the original tensors are on GPU,
the execution should be performed fully on GPU.

**Gradients**. We have not tested gradient tracking either, as we have yet to
find good examples on which to test it, but it should be a simple
corollary of all this effort. If the original tensors fed into a function
have `requires_grad=True`, the tensors will track the gradients of the internal
implementation, and the user can then differentiate through their NumPy code.
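
A hedged sketch of what this should look like once the TorchDynamo bindings
land (untested, as noted above):

```python
import numpy as np
import torch

@torch.compile
def fn(x, y):
    return np.multiply(x, y).sum()

x = torch.randn(3, requires_grad=True)
y = torch.randn(3)
loss = fn(x, y)   # traced into pure PyTorch ops, so autograd can see them
loss.backward()
print(torch.allclose(x.grad, y))  # d(sum(x * y))/dx == y
```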

### Bindings to TorchDynamo

The bindings for NumPy at the TorchDynamo level are currently being developed in
[pytorch#95849](https://github.com/pytorch/pytorch/pull/95849).

## Unresolved questions

A question was left open in the initial discussion: should the module
`torch_np` be publicly exposed as `torch.numpy` or not?

A few arguments in favor of making it public:
* People could use it in their NumPy programs just by changing the import to
  `import torch.numpy as np`. This could be a selling point similar to JAX's
  `jax.numpy`, which could incentivize adoption.
* People would not need to use the whole PyTorch 2.0 stack to start using
  PyTorch in their codebases.
  * See [this experiment in scikit-learn](https://github.com/scikit-learn/scikit-learn/pull/25956),
    where they got a 7x speed-up on CPU on a layer just by using `torch.linalg`.
* Since the layer is rather thin and in pure Python, if there are bugs,
  external contributors could easily help fix them or extend the supported
  functionality.

A few arguments against:
* The compat layer introduces a number of type conversions that may produce somewhat
  slow code when used in eager mode.
  * [Note] Keeping this in mind, we tried to use as few operators as possible
    in the implementation, to make it reasonably fast in eager mode.
* Exposing `torch.numpy` would create a less performant secondary entry point
  to many of the functions in PyTorch. This could be a trap for new users.
