@@ -16,14 +16,12 @@ This project has a main goal, as per the
[initial design document](https://docs.google.com/document/d/1gdUDgZNbumFORRcUaZUVw790CtNYweAM20C1fbWMNd8):
1. Make TorchDynamo understand NumPy calls

- The work is being done at [numpy_pytorch_interop](https://github.com/Quansight-Labs/numpy_pytorch_interop/).
+ The work is currently being done at [numpy_pytorch_interop](https://github.com/Quansight-Labs/numpy_pytorch_interop/).

## Motivation

- ### An introductory example
-
- Let's start with some examples.
+ ### Introductory examples

Consider the following snippet:
```python
@@ -46,8 +44,8 @@ z = torch.matmul(x, y)
w = z.sum()
```

- Here we already see a couple differences between NumPy and PyTorch. The most
- obvious one is that the default dtype in NumPy is `float64` rather than
+ Here, we can already spot a couple of differences between NumPy and PyTorch.
+ The most obvious one is that the default dtype in NumPy is `float64` rather than
`float32`. The less obvious one is very sneakily hiding in the last line.

```python
@@ -57,10 +55,10 @@ obvious one is that the default dtype in NumPy is `float64` rather than

Reductions and similar operations in NumPy return the infamous NumPy scalars.
We'll discuss these and other NumPy quirks and how we dealt with them in the
- sequel.
+ [design decisions section](#design-decisions).

- As expected, this layer also allows for combining NumPy code and PyTorch code.
+ Let's now have a look at a toy example of how this layer would be used.
```python
import torch
import numpy as np
@@ -75,41 +73,32 @@ t_results = torch.empty(5, dtype=torch.float64)
t_results[0] = result  # store the result in a torch.Tensor
```

- This code mixing NumPy and PyTorch already works, as `torch.Tensor` implements
- the `__array__` method. For it to work manually with the compatibility layer, we would
- need to wrap and unwrap the inputs / outputs. This could be done modifying `fn`
- as
-
- ```python
- def fn(x, y):
-     x = np.asarray(x)
-     y = np.asarray(y)
-     ret = np.multiply(x, y).sum()
-     return ret.tensor.numpy()
- ```
+ Note that this code mixing NumPy and PyTorch already works, as `torch.Tensor`
+ implements the `__array__` method. The compatibility layer additionally allows
+ us to trace through it. No changes are needed, other than simply asking
+ `torch.compile` to trace through it:

- This process would be done automatically by TorchDynamo, so we would simply need to write
```python
@compile
def fn(x, y):
    return np.multiply(x, y).sum()
```

- ### The observable behavior
+ ### Design decisions

- The two main idea driving the design of this compatibility layer were the following:
+ The two main ideas driving the design of this compatibility layer are the following:

1. The behavior of the layer should be as close to that of NumPy as possible
2. The layer follows NumPy master

The following design decisions follow from these:

- **Default dtypes**. One of the issues that most often user when moving their
- codebase from NumPy to JAX was the default dtype changing from `float64` to
- `float32`. So much so, that this is one noted as one of
+ **Default dtypes**. One of the most common issues that bites people when migrating their
+ codebases from NumPy to JAX is the default dtype changing from `float64` to
+ `float32`. So much so that this is noted as one of
[JAX's sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision).
Following the spirit of making everything match NumPy by default, we choose the
- NumPy defaults whenever the `dtype` was not chosen in a factory function.
+ NumPy defaults whenever the `dtype` is not made explicit in a factory function.
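+
+ As a quick illustration of this difference (both dtypes below are the actual
+ library defaults):
+
+ ```python
+ import numpy as np
+ import torch
+
+ np.ones(3).dtype     # dtype('float64'): NumPy defaults to float64
+ torch.ones(3).dtype  # torch.float32: PyTorch defaults to float32
+ ```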

**TODO(Lezcano)**: I just realized that we do not have a clean way to change
the default dtype of `torch_np` to those from PyTorch. We should implement
@@ -121,23 +110,34 @@ to be able to set their own int/float/complex defaults.
use them anywhere else -> Check

**NumPy scalars**. NumPy's type system is tricky. At first sight, it looks
- quite a bit like PyTorch's, but having a few more dtypes like `np.uint16` or
- `np.longdouble`. Upon closer inspection, one finds that it also has
+ like PyTorch's, but with a few more dtypes like `np.uint16` or `np.longdouble`.
+ Upon closer inspection, one finds that it also has
[NumPy scalar](https://numpy.org/doc/stable/reference/arrays.scalars.html) objects.
NumPy scalars are similar to Python scalars but with a set width. NumPy scalars
are NumPy's preferred return class for reductions and other operations that
return just one element. NumPy scalars do not play particularly well with
computations on devices like GPUs, as they live on the CPU. Implementing NumPy
scalars would mean that we need to synchronize after every `sum()` call, which
- is less-than-good. Instead, whenever a NumPy scalar would be returned, we will
- return a 0-D tensor, as PyTorch does.
+ would be terrible performance-wise. In this implementation, we choose to represent
+ NumPy scalars as 0-D arrays. This may cause small divergences in some cases like
+
+ ```python
+ >>> np.int32(2) * [1, 2, 3]    # scalar decays to a Python int
+ [1, 2, 3, 1, 2, 3]
+
+ >>> np.asarray(2) * [1, 2, 3]  # zero-dim array is an array-like
+ array([2, 4, 6])
+ ```
+
+ but we don't expect these to pose a big issue in practice. Note that in this
+ implementation `torch_np.int32(2)` would return the same as `torch_np.asarray(2)`.
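+
+ For instance, a reduction in this layer is expected to return a 0-D array
+ rather than a NumPy scalar (a sketch; the exact repr is illustrative):
+
+ ```python
+ import torch_np as tnp
+
+ tnp.asarray([1, 2, 3]).sum()  # a 0-D array, e.g. array(6), not np.int64(6)
+ ```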

**Type promotion**. Another not-so-well-known fact about NumPy's casting system is
that it is data-dependent. Python scalars can be used in pretty much any NumPy
operation, being able to call any operation that accepts a 0-D array with a
Python scalar. If you provide an operation with a Python scalar, these will be
- casted to the smallest dtype that can represent them, and then they will
- participate in type promotion, allowing for some rather interesting behaviour
+ cast to the smallest dtype they can be represented in, and then they will
+ participate in type promotion. This allows for some rather interesting behaviour
```python
>>> np.asarray([1], dtype=np.int8) + 127
array([128], dtype=int8)
@@ -146,64 +146,63 @@ array([129], dtype=int16)
```
This data-dependent type promotion will be deprecated in NumPy 2.0, and will be
replaced with [NEP 50](https://numpy.org/neps/nep-0050-scalar-promotion.html).
- As such, to be forward-looking and for simplicity, we chose to implement the
+ For simplicity and to be forward-looking, we chose to implement the
type promotion behaviour proposed in NEP 50, which is much closer to that of
PyTorch.

Note that the decision of going with NEP 50 complements the previous one of
returning 0-D arrays in place of NumPy scalars as, currently, 0-D arrays do not
- participate in type promotion in NumPy (but will do in NumPy 2.0):
+ participate in type promotion in NumPy (but they will in NumPy 2.0 under NEP 50):
```python
int64_0d_array = np.array(1, dtype=np.int64)
np.result_type(np.int8, int64_0d_array) == np.int8
```
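+
+ Under NEP 50 (and hence in this layer), the 0-D array's dtype does participate
+ in promotion, so the same expression is expected to promote to the wider type
+ (a sketch of the expected behaviour):
+
+ ```python
+ np.result_type(np.int8, int64_0d_array) == np.int64  # under NEP 50 semantics
+ ```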

**Versioning**. It should be clear from the previous points that NumPy has a
- fair amount of questionable and legacy pain points. As such, we decided that
- rather than trying to fight these, we would declare that the compat layer
- follows the behavior of Numpy's master. Given the stability of NumPy's API and
- how battle-tested its main functions are, we do not expect this to become a big
- maintenance burden. If anything, it should make our lives easier, as some parts
- of NumPy will soon be simplified and we will not need to implement them, as
- described above.
+ fair amount of questionable and legacy pain points. It is for this reason that
+ we decided that, rather than fighting these, we would declare that the compat
+ layer follows the behavior of NumPy's master branch (and even, in some cases,
+ that of NumPy 2.0). Given the stability of NumPy's API and how battle-tested
+ its main functions are, we do not expect this to become a big maintenance
+ burden. If anything, it should make our lives easier, as some parts of NumPy
+ will soon be simplified, saving us the pain of having to implement all the
+ pre-existing corner cases.
+
+ For reference, NumPy 2.0 is expected to land at the end of this year.

## The `torch_np` module

The bulk of the work went into implementing a system that allows us to
implement NumPy operations in terms of those of PyTorch. The main design goals
- were
+ here were

1. Implement *most* of NumPy's API
2. Preserve NumPy semantics as much as possible

We say *most* of NumPy's API, because NumPy's API is not only massive, but also
there are parts of it which cannot be implemented in PyTorch. For example,
NumPy has support for arrays of string, datetime, structured and other dtypes.
- Negative strides are other example of a feature that is just out of the scope.
+ Negative strides are another example of a feature that is just not supported in PyTorch.
We put together a list of things that are out of the scope of this project in the
[following issue](https://github.com/Quansight-Labs/numpy_pytorch_interop/issues/73).

- For the bulk of the functions, we started by prioritizing most common
- operations. Then, when bringing tests from the NumPy test suit and running
- them, we would triage and prioritize how important was to fix each failure we
- found. Iterating this process, we ended up with a small list of differences
- between the NumPy and the PyTorch API which we sorted out by hand and finished
- implementing. That list and the prioritization discussion can be found in
- [the first few posts of this issue](https://github.com/Quansight-Labs/numpy_pytorch_interop/issues/87).
-
- The second point of preserving NumPy semantics as much as possible will be used
- in the sequel to discuss some points like the default dtypes that are used
- throughout the implementation.
+ For the bulk of the functions, we started by prioritizing the most common
+ operations. Then, when bringing in tests from the NumPy test suite, we would
+ triage and prioritize how important it was to fix each failure we found.
+ Iterating this process, we ended up with a small list of differences between
+ the NumPy and the PyTorch APIs, which we prioritized by hand. That list and the
+ prioritization discussion can be found in
+ [this issue](https://github.com/Quansight-Labs/numpy_pytorch_interop/issues/87).

**Visibility of the module**. For simplicity, this RFC assumes that the
`torch_np` module will not be public, as the decision for it to be made public
- was met with different opinions. We discuss these in the "Unresolved Questions"
- section.
+ was met with different opinions.
+ We discuss these in the [unresolved questions](#unresolved-questions) section.

### Annotation-based preprocessing

- NumPy accepts virtually anything that smells like an array as input to its operators
+ NumPy accepts virtually anything that smells like an array as an input
+
```python
>>> np.add(1, 3)
4
@@ -213,8 +212,8 @@ array([6., 7., 8.])
array([1, 2, 3, 4, 5, 6])
```

- To implement NumPy in terms of PyTorch, for any operation we would need to put
- the inputs into tensors, perform the operations, and then wrap the tensor into
+ To implement NumPy in terms of PyTorch, for any operation we would need to map
+ its inputs into tensors, perform the operation, and then wrap the tensor into
a `torch_np.ndarray` (more on this class later).

To avoid all this code repetition, we implement the functions in two steps.
@@ -223,31 +222,33 @@ First, we implement functions with the NumPy signature, but assuming that in
place of NumPy-land elements (`np.array`, array-like objects, `np.dtype`s, etc.)
they simply accept `torch.Tensor`s and PyTorch-land objects and return
`torch.Tensor`s. For example, we would implement `np.diag` as
+
```python
def diag(v, k=0):
    return torch.diag(v, k)
```
+
In this layer, if a NumPy function is composite (calls other NumPy functions
internally), we can simply vendor its implementation, and have it call our
PyTorch-land implementations of these functions. In other words, at this level,
- functions are composable, as any set of functions implemented purely in
- PyTorch. All these implementations are internal, and are not meant to be seen
- or used by the final user.
+ functions are composable, as they are simply regular PyTorch functions.
+ All these implementations are internal, and are not meant to be seen or used
+ by the final user.

The second step is then done via type annotations and a decorator. Each type
- annotation has then a map NumPy-land -> PyTorch-land associated, that maps the
- set of inputs accepted by NumPy for that argument into a PyTorch-land object
- (think a `torch.Tensor` or a PyTorch dtype). For example, for `np.diag` we
- would write
+ annotation has an associated function from NumPy-land into PyTorch-land. This
+ function converts the set of inputs accepted by NumPy for that argument into a
+ PyTorch-land object (think a `torch.Tensor` or a PyTorch dtype). For example,
+ for `np.diag` we would write
+
```python
def diag(v: ArrayLike, k=0):
    return torch.diag(v, k)
```

- Then, we would wrap these Python-land functions in a `normalizer` decorator and
- expose them in the public `torch.np` module. This decorator is in charge of
- gathering all the inputs at runtime and normalizing them according to their
- annotations.
+ Then, we wrap these Python-land functions with a `normalizer` decorator and
+ expose them in the `torch_np` module. This decorator is in charge of gathering
+ all the inputs at runtime and normalizing them according to their annotations.
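+
+ To make this concrete, here is a minimal sketch of what such a decorator could
+ look like. The names `normalizer` and `ArrayLike` match the ones above, but the
+ body is illustrative rather than the actual implementation:
+
+ ```python
+ import functools
+ import inspect
+
+ import torch
+
+ class ArrayLike:
+     """Marker annotation: coerce this argument to a torch.Tensor."""
+
+ def normalizer(func):
+     sig = inspect.signature(func)
+
+     @functools.wraps(func)
+     def wrapper(*args, **kwargs):
+         bound = sig.bind(*args, **kwargs)
+         bound.apply_defaults()
+         for name, value in bound.arguments.items():
+             # Lists, Python scalars, arrays... all become tensors.
+             if sig.parameters[name].annotation is ArrayLike:
+                 bound.arguments[name] = torch.as_tensor(value)
+         # The real layer would also wrap the output into a torch_np.ndarray.
+         return func(*bound.args, **bound.kwargs)
+
+     return wrapper
+
+ @normalizer
+ def diag(v: ArrayLike, k=0):
+     return torch.diag(v, k)
+
+ diag([[1, 2], [3, 4]])  # works with a plain list of lists
+ ```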

We currently have four annotations (and small variations of them):
- `ArrayLike`: The input can be a `torch_np.array`, a list of lists, a
@@ -258,9 +259,9 @@ We currently have four annotations (and small variations of them):
- `OutArray`: Asserts that the input is a `torch_np.ndarray`. This is used
  to implement the `out` arg.

- Note that none of the code here makes use of NumPy. We are writing
- `torch_np.ndarray` above to make more explicit our intents, but there
- shouldn't be any ambiguity here.
+ Note that none of the code in this implementation makes use of NumPy. We are
+ writing `torch_np.ndarray` above to make our intent more explicit, but there
+ shouldn't be any ambiguity.

**OBS(Lezcano)**: `DTypeLike` should be `Optional[DTypeLike]`
**OBS(Lezcano)**: Should we have a `NotImplementedType` to mark the args that
@@ -271,30 +272,28 @@ implementation, or mark explicitly those that we don't use.

**Implementing out**: In PyTorch, the `out` kwarg is, as the name says, a
keyword-only argument. It is for this reason that, in PrimTorch, we were able
- to implement it as
- [a decorator](https://github.com/pytorch/pytorch/blob/ce4df4cc596aa10534ac6d54912f960238264dfd/torch/_prims_common/wrappers.py#L187-L282).
+ to implement it as [a decorator](https://github.com/pytorch/pytorch/blob/ce4df4cc596aa10534ac6d54912f960238264dfd/torch/_prims_common/wrappers.py#L187-L282).
This is not the case in NumPy. In NumPy, `out` is a positional arg that is often
interleaved with other parameters. This is the reason why we use the `OutArray`
- label to mark these. We then implement the `out` semantics in the `@normalizer`
+ annotation to mark these. We then implement the `out` semantics in the `@normalizer`
wrapper in a generic way.
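+
+ A sketch of what that generic handling could look like (illustrative; the real
+ code also validates dtypes and handles broadcasting):
+
+ ```python
+ from typing import Optional
+
+ import torch
+
+ def apply_out(result: torch.Tensor, out: Optional[torch.Tensor]):
+     """NumPy-style out= semantics on top of tensors (a sketch)."""
+     if out is None:
+         return result
+     if out.shape != result.shape:
+         raise ValueError("non-broadcastable output operand")
+     out.copy_(result)  # write the result into the provided buffer
+     return out
+ ```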

**Ufuncs and reductions**: Ufuncs (unary and binary) and reductions are two
sets of functions that are particularly regular. For these functions, we
- implement (some of) their args in a generic way. We then simply forward the
- computations to PyTorch, perhaps working around some PyTorch limitations.
-
- ### The `ndarray` class
+ implement their args in a generic way, as a pre-processing or post-processing step.
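+
+ For instance, a binary ufunc could be generated generically along these lines
+ (a sketch; the helper name `binary_ufunc` is illustrative, and `apply_out` is
+ the sketch from above):
+
+ ```python
+ import torch
+
+ def binary_ufunc(torch_fn):
+     """Build a NumPy-style binary ufunc from a PyTorch function (a sketch)."""
+     def ufunc(x1, x2, out=None, dtype=None):
+         x1, x2 = torch.as_tensor(x1), torch.as_tensor(x2)  # pre-processing
+         result = torch_fn(x1, x2)
+         if dtype is not None:
+             result = result.to(dtype)  # post-processing
+         return apply_out(result, out)  # out= handling, as sketched above
+     return ufunc
+
+ add = binary_ufunc(torch.add)
+ multiply = binary_ufunc(torch.mul)
+ ```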

- Once we have all the free functions implemented, implementing an `ndarray`
- class is rather simple. We simply register all the free functions as methods or
- dunder methods appropriately. We also forward the properties to the properties
- within the PyTorch tensor and we are done.
+ **The ndarray class**. Once we have all the free functions implemented as
+ functions from `torch_np.ndarray`s to `torch_np.ndarray`s, implementing the
+ methods of the `ndarray` class is rather simple. We simply register all the
+ free functions as methods or dunder methods as appropriate. We also forward the
+ properties of the `ndarray` class to the corresponding properties of the
+ underlying PyTorch tensor, and we are done. This creates a circular dependency,
+ which we break with a local import.
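+
+ A minimal sketch of that registration (illustrative; the real class also
+ registers dunders and many more methods):
+
+ ```python
+ import torch
+
+ # A free function of the layer (simplified): ndarray -> ndarray.
+ def _sum(arr):
+     return ndarray(arr.tensor.sum())
+
+ class ndarray:
+     def __init__(self, tensor: torch.Tensor):
+         self.tensor = tensor
+
+     # Properties forward to the underlying torch.Tensor.
+     @property
+     def shape(self):
+         return tuple(self.tensor.shape)
+
+     # Free functions are registered as methods directly.
+     sum = _sum
+
+ ndarray(torch.ones(3)).sum()  # a 0-D ndarray wrapping tensor(3.)
+ ```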

### Testing

The testing of the framework was done via ~~copying~~ vendoring tests from the
- NumPy test suit. Then, we would replace the NumPy imports for imports with
- `torch_np`. The failures on these tests were then triaged and discussed the
+ NumPy test suite. Then, we would replace the NumPy imports with `torch_np`
+ imports. The failures on these tests were then triaged, and we discussed the
priority of fixing each of them.
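+
+ In practice, the swap is a one-line change at the top of each vendored test
+ file, along these lines:
+
+ ```python
+ # Vendored NumPy test file. Before: `import numpy as np`. After:
+ import torch_np as np
+ ```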

In the (near) future, we plan to get some real-world examples and run them
@@ -305,7 +304,7 @@ through the library, to test its coverage and correctness.

A number of known limitations are tracked in the second part of the
[OP of this issue](https://github.com/Quansight-Labs/numpy_pytorch_interop/issues/73).
There are some more in [this issue](https://github.com/Quansight-Labs/numpy_pytorch_interop/issues/86).
- When landing all this, we will create a comprehensive document with the differences
+ When landing this RFC, we will create a comprehensive document with the differences
between NumPy and `torch_np`.

### Beyond Plain NumPy
@@ -332,7 +331,8 @@ The bindings for NumPy at the TorchDynamo level are currently being developed at

## Unresolved Questions

- A question was left open in the initial discussion. Should the module `torch_np` be publicly exposed as `torch.numpy` or not?
+ A question was left open in the initial discussion: should the module
+ `torch_np` be publicly exposed as `torch.numpy` or not?

A few arguments in favor of making it public:
* People could use it in their NumPy programs just by changing the import to