# A PyTorch - NumPy compatibility layer

**Authors:**
* @ev-br
* @lezcano
* @rgommers

## Summary

This RFC describes a proposal for a translation layer from NumPy into PyTorch.
In simple terms, this accounts for implementing most of NumPy's API (`ndarray`,
the `numpy`, `numpy.linalg`, `numpy.fft` modules, etc.) using `torch.Tensor`
and PyTorch ops as backend.

This project has one main goal, as per the
[initial design document](https://docs.google.com/document/d/1gdUDgZNbumFORRcUaZUVw790CtNYweAM20C1fbWMNd8):
1. Make TorchDynamo understand NumPy calls

The work is being done at [numpy_pytorch_interop](https://github.com/Quansight-Labs/numpy_pytorch_interop/).

## Motivation

### An introductory example

Let's start with some examples.

Consider the following snippet:

```python
import numpy as np

x = np.random.randn(3, 4)
y = np.random.randn(4, 3)
z = np.dot(x, y)
w = z.sum()
```

When we trace this program with the compat layer, the semantics of the
program would stay the same, but the implementation would be equivalent to

```python
import torch

x = torch.randn(3, 4, dtype=torch.float64)
y = torch.randn(4, 3, dtype=torch.float64)
z = torch.matmul(x, y)
w = z.sum()
```

Now consider a toy function that mixes NumPy and PyTorch:

```python
import torch
import numpy as np

def fn(x, y):
    return np.multiply(x, y).sum()

x = torch.randn(5)
y = torch.randn(5)
result = fn(x, y)
t_results = torch.empty(5)
t_results[0] = result  # store the result in a torch.Tensor
```

This code mixing NumPy and PyTorch already works, as `torch.Tensor` implements
the `__array__` method. For it to work manually with the compatibility layer, we
would need to wrap and unwrap the inputs / outputs. This could be done by
modifying `fn` as

```python
import torch_np as tnp

def fn(x, y):
    x = tnp.asarray(x)
    y = tnp.asarray(y)
    ret = tnp.multiply(x, y).sum()
    return ret.tensor.numpy()
```

This process would be done automatically by TorchDynamo, so we would simply need to write

```python
@torch.compile
def fn(x, y):
    return np.multiply(x, y).sum()
```

### The observable behavior

The two main ideas driving the design of this compatibility layer were the following:

1. The behavior of the layer should be as close to that of NumPy as possible
2. The layer follows NumPy master

The following design decisions follow from these:

**Default dtypes**. One of the issues that most often bites users when moving their
codebase from NumPy to JAX is the default dtype changing from `float64` to
`float32`. So much so that this is noted as one of
[JAX's sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision).
Following the spirit of making everything match NumPy by default, we choose the
NumPy defaults whenever the `dtype` was not chosen in a factory function.
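
As a quick sketch of the intended default behavior (assuming the compat module is importable as `torch_np`; the commented dtypes are illustrative):

```python
import torch
import torch_np as tnp

tnp.ones(3)    # dtype defaults to float64, matching NumPy
torch.ones(3)  # dtype defaults to torch.float32, PyTorch's usual default
```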

**TODO(Lezcano)**: I just realized that we do not have a clean way to change
the default dtypes of `torch_np` to those from PyTorch. We should implement
that utility flag, similar to
[`torch.set_default_dtype`](https://pytorch.org/docs/stable/generated/torch.set_default_dtype.html).
Perhaps call it `torch_np.use_torch_defaults()`, and then add a way for users
to be able to set their own int/float/complex defaults.

**TODO(Lezcano)**: Do we use the defaults just in factory functions, or do we also
use them anywhere else? -> Check

**NumPy scalars**. NumPy's type system is tricky. At first sight, it looks
quite a bit like PyTorch's, but it has a few more dtypes, like `np.uint16` or
`np.longdouble`. Upon closer inspection, one finds that it also has
[NumPy scalar](https://numpy.org/doc/stable/reference/arrays.scalars.html) objects.
NumPy scalars are similar to Python scalars, but with a fixed width. NumPy scalars
are NumPy's preferred return class for reductions and other operations that
return just one element. NumPy scalars do not play particularly well with
computations on devices like GPUs, as they live on the CPU. Implementing NumPy
scalars would mean that we need to synchronize after every `sum()` call, which
is less than ideal. Instead, whenever a NumPy scalar would be returned, we
return a 0-D tensor, as PyTorch does.
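
For illustration, a sketch of the difference (the commented return types are the relevant part; the `torch_np` import name is the compat module):

```python
import numpy as np
import torch_np as tnp

type(np.asarray([1.0, 2.0]).sum())   # numpy.float64: a NumPy scalar, on the CPU
type(tnp.asarray([1.0, 2.0]).sum())  # a 0-D tnp.ndarray backed by a torch.Tensor
```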

**Type promotion**. Another not-so-well-known fact of NumPy's casting system is
that it is data-dependent. Python scalars can be used in pretty much any NumPy
operation: any operation that accepts a 0-D array also accepts a Python scalar.
If you provide an operation with a Python scalar, it will be cast to the
smallest dtype that can represent it, and it will then participate in type
promotion, allowing for some rather interesting behaviour:
```python
>>> np.asarray([1], dtype=np.int8) + 127
array([-128], dtype=int8)
>>> np.asarray([1], dtype=np.int8) + 128
array([129], dtype=int16)
```
This data-dependent type promotion will be deprecated in NumPy 2.0, and will be
replaced with [NEP 50](https://numpy.org/neps/nep-0050-scalar-promotion.html).
As such, to be forward-looking and for simplicity, we chose to implement the
type promotion behaviour proposed in NEP 50, which is much closer to that of
PyTorch.

Note that the decision to go with NEP 50 complements the previous one of
returning 0-D arrays in place of NumPy scalars since, currently, 0-D arrays do
not participate in type promotion in NumPy (but they will in NumPy 2.0):
```python
int64_0d_array = np.array(1, dtype=np.int64)
np.result_type(np.int8, int64_0d_array) == np.int8
```
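
PyTorch's scalar promotion is already value-independent, which is the behavior NEP 50 standardizes. A small sketch of the semantics we target:

```python
import torch

x = torch.tensor([1], dtype=torch.int8)
(x + 2).dtype    # torch.int8: the Python scalar adapts to the tensor's dtype
(x + 127).dtype  # torch.int8 as well; the result dtype does not depend on the value
```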

**Versioning**. It should be clear from the previous points that NumPy has a
fair amount of questionable and legacy pain points. As such, we decided that,
rather than trying to fight these, we would declare that the compat layer
follows the behavior of NumPy's master branch. Given the stability of NumPy's
API and how battle-tested its main functions are, we do not expect this to
become a big maintenance burden. If anything, it should make our lives easier,
as some parts of NumPy will soon be simplified, and we will not need to
implement them, as described above.

## The `torch_np` module

The bulk of the work went into implementing a system that allows us to
implement NumPy operations in terms of those of PyTorch. The main design goals
were:
1. Implement *most* of NumPy's API
2. Preserve NumPy semantics as much as possible

We say *most* of NumPy's API because NumPy's API is not only massive, but also
there are parts of it which cannot be implemented in PyTorch. For example,
NumPy has support for arrays of string, datetime, structured, and other dtypes.
Negative strides are another example of a feature that is simply out of scope.
We put together a list of things that are out of the scope of this project in the
[following issue](https://github.com/Quansight-Labs/numpy_pytorch_interop/issues/73).

For the bulk of the functions, we started by prioritizing the most common
NumPy operations.

The second point of preserving NumPy semantics as much as possible will be used
in the sequel to discuss some points like the default dtypes that are used
throughout the implementation.

**Visibility of the module**. For simplicity, this RFC assumes that the
`torch_np` module will not be public, as the decision for it to be made public
was met with different opinions. We discuss these in the "Unresolved Questions"
section.

### Annotation-based preprocessing

NumPy accepts virtually anything that smells like an array as input to its
operators:

```python
>>> np.concatenate([(1, 2), (3, 4), (5, 6)])
array([1, 2, 3, 4, 5, 6])
```

To implement NumPy in terms of PyTorch, for any operation we would need to put
the inputs into tensors, perform the operations, and then wrap the tensor into
a `torch_np.ndarray` (more on this class later).

To avoid all this code repetition, we implement the functions in two steps: we
first define the function in terms of `torch.Tensor`s, annotating its inputs,
and we then wrap it with a decorator that takes care of gathering all the
inputs at runtime and normalizing them according to their annotations.

We currently have four annotations (and small variations of them):
- `ArrayLike`: The input can be a `torch_np.ndarray`, a list of lists, a
  scalar, or anything that NumPy would accept. It returns a `torch.Tensor`.
- `DTypeLike`: Takes a `torch_np` dtype and returns a PyTorch dtype.
- `AxisLike`: Takes anything that can be accepted as an axis (e.g. a tuple or
  an `ndarray`) and returns a tuple.
- `OutArray`: Asserts that the input is a `torch_np.ndarray`. This is used
  to implement the `out` arg.

Note that none of the code here makes use of NumPy. We are writing
`torch_np.ndarray` above to make our intent more explicit, but there
shouldn't be any ambiguity here.
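
To make the two-step scheme concrete, here is a minimal sketch of how such a normalizing decorator could look. The names (`normalizer`, `normalize`) and the simplified positional-argument handling are illustrative, not the actual implementation:

```python
import functools
import typing

import torch

ArrayLike = typing.NewType("ArrayLike", object)  # marker annotation

class ndarray:
    """Thin wrapper around a torch.Tensor (see the next section)."""
    def __init__(self, tensor):
        self.tensor = tensor

def normalize(value, annotation):
    # ArrayLike inputs are moved into torch.Tensors; others pass through
    if annotation is ArrayLike:
        if isinstance(value, ndarray):
            return value.tensor
        return torch.as_tensor(value)
    return value

def normalizer(func):
    # Gather the annotated inputs at runtime and normalize each of them
    @functools.wraps(func)
    def wrapper(*args):
        hints = {k: v for k, v in func.__annotations__.items() if k != "return"}
        args = tuple(normalize(a, t) for a, t in zip(args, hints.values()))
        return ndarray(func(*args))  # wrap the resulting tensor back
    return wrapper

@normalizer
def diag(v: ArrayLike, k: int = 0):
    return torch.diag(v, k)

d = diag([[1, 2], [3, 4]])  # a list of lists is accepted, as in NumPy
```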

**OBS(Lezcano)**: `DTypeLike` should be `Optional[DTypeLike]`.

### The `ndarray` class

Once the free functions are implemented, implementing a `torch_np.ndarray`
class that wraps a `torch.Tensor` (exposed through its `.tensor` attribute)
is rather simple. We simply register all the free functions as methods or
dunder methods, as appropriate. We also forward the properties of the wrapper
to those of the underlying PyTorch tensor, and we are done.
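
A minimal sketch of this registration, under the same illustrative names as above (the real class covers many more methods and properties):

```python
import torch

class ndarray:
    def __init__(self, tensor):
        self.tensor = torch.as_tensor(tensor)

    @property
    def shape(self):
        # Properties are forwarded to the underlying torch.Tensor
        return tuple(self.tensor.shape)

# Free functions of the module, implemented in terms of PyTorch ops
def add(x, y):
    return ndarray(torch.add(x.tensor, y.tensor))

def sum(x):
    return ndarray(torch.sum(x.tensor))

# Register the free functions as methods / dunder methods on the class
ndarray.__add__ = add
ndarray.sum = sum

a = ndarray([1, 2]) + ndarray([3, 4])  # dispatches to the free function add
total = a.sum()                        # and reductions work as methods
```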

### Testing

274
295
The testing of the framework was done via ~~ copying~~ vendoring tests from the
275
296
NumPy test suit. Then, we would replace the NumPy imports for imports with
276
- ` torch.numpy ` . The failures on these tests were then triaged and discussed the
297
+ ` torch_np ` . The failures on these tests were then triaged and discussed the
277
298
priority of fixing each of them.
278
299
279
300
In the (near) future, we plan to get some real world examples and run them
280
301
through the library, to test its coverage and correctness.
281
302
282
- ## Limitations
283
-
284
- One of the known limitations of this approach is the efficiency in eager.
285
- Similar to PrimTorch, sometimes we needed to work around some limitations of
286
- PyTorch (e.g. support for some operations for ` float16 ` ) or some ways PyTorch
287
- deviates from NumPy by implementing things manually calling several ` torch `
288
- operations. This, when executed in eager mode and, in particular, on CUDA
289
- devices, will result on a perf-hit. To alleviate this, we tried to dispatch
290
- NumPy functions to PyTorch functions with as few indirections as possible, to
291
- alleviate the number of kernels called when executed on eager mode.
303
+ ### Limitations
292
304
293
- There are some known limitations. Some of them are tracked in the second part
294
- of the [ OP of this issue] ( https://github.com/Quansight-Labs/numpy_pytorch_interop/issues/73 ) .
305
+ A number of known limitations are tracked in the second part of the
306
+ [ OP of this issue] ( https://github.com/Quansight-Labs/numpy_pytorch_interop/issues/73 ) .
295
307
There are some more in [ this issue] ( https://github.com/Quansight-Labs/numpy_pytorch_interop/issues/86 ) .
296
308
When landing all this, we will create a comprehensive document with the differences
between NumPy and `torch_np`.

### Beyond Plain NumPy

**GPU**. The current implementation has only been implemented and tested on
CPU. We expect GPU coverage to be as good as the CPU coverage, to the extent
that PyTorch's GPU ops match their CPU counterparts. If the original tensors
are on the GPU, the whole execution should be performed on the GPU.

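For instance, once the TorchDynamo bindings described below land, something along these lines should run entirely on the GPU (a sketch; `torch.compile` is the assumed tracing entry point):

```python
import torch
import numpy as np

@torch.compile
def fn(x, y):
    return np.multiply(x, y).sum()

x = torch.randn(1024, device="cuda")
y = torch.randn(1024, device="cuda")
result = fn(x, y)  # the NumPy calls are traced into PyTorch ops on the GPU
```
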
**TODO(Lezcano)**. We should probably also run the tests on CUDA.

**Gradients**. We have not tested gradient tracking either, as we have yet to
find some good examples on which to test it, but it should be a simple
corollary of all this effort. If the original tensors fed into the function
have `requires_grad=True`, they will track the gradients of the internal
implementation, and the user can then differentiate through the NumPy code.

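A sketch of what differentiating through NumPy code would look like, under the same assumptions as the GPU example above:

```python
import torch
import numpy as np

@torch.compile
def fn(x, y):
    return np.multiply(x, y).sum()

x = torch.randn(3, requires_grad=True)
y = torch.randn(3)
loss = fn(x, y)
loss.backward()  # gradients flow through the traced NumPy code
print(x.grad)    # equals y, since d(sum(x*y))/dx = y
```
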
**TODO(Lezcano)**. Picking up simple NumPy programs from the internet would be good for these autograd tests.

### Bindings to TorchDynamo

The bindings for NumPy at the TorchDynamo level are currently being developed at [#95849](https://github.com/pytorch/pytorch/pull/95849).

## Unresolved Questions

A question was left open in the initial discussion: should the module `torch_np` be publicly exposed as `torch.numpy` or not?

A few arguments in favor of making it public:
* People could use it in their NumPy programs just by changing the import to
  `import torch.numpy as np`. This could be a selling point similar to JAX's
  `jax.numpy`, which could incentivize adoption.
* People would not need to use the whole PyTorch 2.0 stack to start using
  PyTorch in their codebases.
  * See [this experiment in scikit-learn](https://github.com/scikit-learn/scikit-learn/pull/25956),
    where they got a 7x speed-up on CPU on a layer just by using `torch.linalg`.
* Since the layer is rather thin and in pure Python, if there are bugs,
  external contributors could easily help fix them or extend the supported
  functionality.

A few arguments against:
* The compat layer introduces a number of type conversions that may produce somewhat
  slow code when used in eager mode.
  * [Note] Keeping this in mind, we tried to use as few operators as possible
    in the implementations, to make it reasonably fast in eager mode.
* Exposing `torch.numpy` would create a less performant secondary entry point
  to many of the functions in PyTorch. This could be a trap for new users.