@@ -92,23 +92,24 @@ def fn(x, y):
Then, TorchDynamo will cast `x` and `y` to our internal implementation of `ndarray`,
and will dispatch `np.multiply` and `sum` to our implementations in terms of `torch`
- functions effectively turning this function into a pure PyTorch function.
+ functions, effectively turning this function into a pure PyTorch function.
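
For concreteness, here is a sketch of the kind of mixed program this enables, assuming the TorchDynamo bindings described later in this RFC; the function mirrors the example above and the shapes are illustrative only:

```python
import numpy as np
import torch

def fn(x, y):
    # NumPy calls as written by the user; when traced, they are dispatched to
    # the compat layer and therefore to torch operations.
    z = np.multiply(x, y)
    return z.sum()

# Illustrative usage: the inputs are PyTorch tensors and, once compiled,
# the whole function executes as PyTorch code.
x, y = torch.rand(4), torch.rand(4)
compiled_fn = torch.compile(fn)
print(compiled_fn(x, y))
```
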
### Design decisions

The main ideas driving the design of this compatibility layer are the following:

- 1. The goal is to transform valid NumPy programs into their equivalent PyTorch
+ 1. The goal is to transform valid NumPy and mixed PyTorch-NumPy programs into
+ their equivalent PyTorch-only execution.
2. The behavior of the layer should be as close to that of NumPy as possible
3. The layer follows the most recent NumPy release

The following design decisions follow from these:

- **A superset of NumPy**. Same as PyTorch has spotty support for `float16` on
- CPU, and less-than-good support for `complex32`, NumPy has a number of
- well-known edge-cases. The decision of translating just valid NumPy programs,
- often allows us to implement a superset of the functionality of NumPy with more
- predictable and consistent behavior than NumPy itself.
+ **A superset of NumPy**. NumPy has a number of well-known edge-cases (as does
+ PyTorch, like spotty support for `float16` on CPU and `complex32` in general).
+ The decision to translate only valid NumPy programs often allows us to
+ implement a superset of the functionality of NumPy with more predictable and
+ consistent behavior than NumPy itself has.

**Exceptions may be different**. We avoid entirely modelling the exception
system in NumPy. As seen in the implementation of PrimTorch, modelling the
@@ -118,19 +119,20 @@ and we choose not to offer any guarantee here.
**Default dtypes**. One of the most common issues that bite people when migrating their
codebases from NumPy to JAX is the default dtype changing from `float64` to
`float32`. So much so that this is noted as one of
- [JAX's shap edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision).
+ [JAX's sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision).
Following the spirit of making everything match NumPy by default, we choose the
- NumPy defaults whenever the `dtype` was not made explicit in a factory function.
+ NumPy default dtype whenever the `dtype` was not made explicit in a factory function.
We also provide a function `set_default_dtype` that allows changing this behavior
dynamically.
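
For illustration, the difference in defaults looks as follows; the `tnp` alias mentioned in the comments is a hypothetical name for the compat module:

```python
import numpy as np
import torch

# NumPy factory functions default to float64, while PyTorch defaults to float32.
print(np.ones(3).dtype)     # float64
print(torch.ones(3).dtype)  # torch.float32

# Assumed behavior of the compat layer: tnp.ones(3) follows the NumPy default
# (float64) unless set_default_dtype is used to switch to PyTorch-style defaults.
```
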
**NumPy scalars**. NumPy's type system is tricky. At first sight, it looks
like PyTorch's, but with a few more dtypes like `np.uint16` or `np.longdouble`.
Upon closer inspection, one finds that it also has
[NumPy scalar](https://numpy.org/doc/stable/reference/arrays.scalars.html) objects.
- NumPy scalars are similar to Python scalars but with a set width. NumPy scalars
- are NumPy's preferred return class for reductions and other operations that
- return just one element. NumPy scalars do not play particularly well with
+ NumPy scalars are similar to Python scalars but with a fixed precision and
+ array-like methods attached. NumPy scalars are NumPy's preferred return class
+ for reductions and other operations that return just one element.
+ NumPy scalars do not play particularly well with
computations on devices like GPUs, as they live on CPU. Implementing NumPy
scalars would mean that we need to synchronize after every `sum()` call, which
would be terrible performance-wise. In this implementation, we choose to represent
@@ -149,25 +151,26 @@ We don't expect these to pose a big issue in practice. Note that in the
proposed implementation `np.int32(2)` would return the same as `np.asarray(2)`.
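
A small example of the difference being papered over; the compat-layer behavior described in the last comment is the proposed one, not something NumPy does today:

```python
import numpy as np

# In NumPy, reductions return scalar objects, which always live on the CPU:
s = np.asarray([1, 2]).sum()
print(type(s))   # <class 'numpy.int64'> (platform-dependent integer width)
print(s.ndim)    # 0 -- scalars already expose array-like attributes

# In the proposed compat layer, the analogous reduction is assumed to return
# a 0-D array wrapping a tensor instead, so no device synchronization is needed.
```
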
In general, we try to avoid unnecessary graph breaks whenever we can. For
example, we may choose to return a tensor of shape `(2, *)` rather than a list
- of pairs, to avoid unnecessary graph breaks.
+ of pairs, to avoid a graph break.

- **Type promotion**. Another not-so-well-known fact of NumPy's cast system is
- that it is data-dependent. Python scalars can be used in pretty much any NumPy
+ **Type promotion**. Another not-so-well-known fact of NumPy's dtype system and casting rules
+ is that it is data-dependent. Python scalars can be used in pretty much any NumPy
operation: any operation that accepts a 0-D array can also be called with a
Python scalar. If you pass Python scalars to an operation, these will be
- casted to the smallest dtype they can be represented in, and then, they will
+ cast to the smallest dtype they can be represented in, and only then will they
participate in type promotion. This allows for some rather interesting behaviour:
```python
>>> np.asarray([1], dtype=np.int8) + 127
array([-128], dtype=int8)
>>> np.asarray([1], dtype=np.int8) + 128
array([129], dtype=int16)
```
- This data-dependent type promotion will be deprecated NumPy 2.0, and will be
- replaced with [NEP 50](https://numpy.org/neps/nep-0050-scalar-promotion.html).
+ This data-dependent type promotion will be removed in NumPy 2.0 (planned for Dec'23), and will be
+ replaced with [NEP 50](https://numpy.org/neps/nep-0050-scalar-promotion.html)
+ (already implemented in NumPy, but it currently needs to be enabled via a private global switch).
For simplicity and to be forward-looking, we chose to implement the
type promotion behaviour proposed in NEP 50, which is much closer to that of
- Pytorch.
+ PyTorch.
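
As a rough illustration of the weak-scalar rule that NEP 50 and PyTorch share, shown here with plain PyTorch since that is the behavior the layer maps onto:

```python
import torch

t = torch.ones(3, dtype=torch.int8)

# The Python scalar's value never changes the result dtype; only its kind does.
print((t + 5).dtype)    # torch.int8
print((t + 5.0).dtype)  # torch.float32 (a Python float promotes to the default float dtype)
```
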
Note that the decision of going with NEP 50 complements the previous one of
returning 0-D arrays in place of NumPy scalars as, currently, 0-D arrays do not
@@ -178,24 +181,22 @@ np.result_type(np.int8, int64_0d_array) == np.int8
```

**Versioning**. It should be clear from the previous points that NumPy has a
- fair amount of questionable and legacy pain points. It is for this reason that
+ fair amount of questionable behavior and legacy pain points. It is for this reason that
we decided that rather than fighting these, we would declare that the compat
- layer follows the behavior of Numpy's most recent release (even, in some cases,
+ layer follows the behavior of NumPy's most recent release (even, in some cases,
of NumPy 2.0). Given the stability of NumPy's API and how battle-tested its
main functions are, we do not expect this to become a big maintenance burden.
If anything, it should make our lives easier, as some parts of NumPy will soon
be simplified, saving us the pain of having to implement all the pre-existing
corner-cases.

- For reference NumPy 2.0 is expected to land at the end of this year.
-

**Randomness**. PyTorch and NumPy use different random number generation methods.
In particular, NumPy recently moved to a [new API](https://numpy.org/doc/stable/reference/random/index.html)
- with a `Generator` object which has sampling methods on it. The current compat.
- layer does not implement this new API, as the default bit generator in NumPy is a
- `PCG64`, while on PyTorch we use a `MT19937` on CPU and a `Philox`. From this, it
- follows that this API will not give any reproducibility guarantees when it comes
- to randomness.
+ with a `Generator` object which has sampling methods on it. The current compat
+ layer does not implement this new API, as the default bit generator in NumPy is
+ `PCG64`, while on PyTorch we use `MT19937` on CPU and `Philox` on non-CPU devices.
+ From this, it follows that this API will not give any reproducibility
+ guarantees when it comes to randomness.
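
To make the mismatch concrete, even with identical seeds the two libraries draw from different bit generators, so the streams cannot agree:

```python
import numpy as np
import torch

# NumPy's new-style API samples from a PCG64 bit generator by default...
rng = np.random.default_rng(1234)
print(rng.standard_normal(3))

# ...while PyTorch's CPU generator is a Mersenne Twister (MT19937),
# so the two sequences differ even for the same seed.
g = torch.Generator().manual_seed(1234)
print(torch.randn(3, generator=g))
```
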
## The `torch_np` module
@@ -210,20 +211,21 @@ here were
We say *most* of NumPy's API, because NumPy's API is not only massive, but also
there are parts of it which cannot be implemented in PyTorch. For example,
NumPy has support for arrays of string, datetime, structured and other dtypes.
- Negative strides are other example of a feature that is just not supported in PyTorch.
+ Negative strides are another example of a feature that is not supported in PyTorch.
We put together a list of things that are out of the scope of this project in the
[following issue](https://github.com/Quansight-Labs/numpy_pytorch_interop/issues/73).

For the bulk of the functions, we started by prioritizing the most common
- operations. Then, when bringing tests from the NumPy test suite, we would triage
- and prioritize how important was to fix each failure we found. Iterating this
- process, we ended up with a small list of differences between the NumPy and the
- PyTorch API which we prioritized by hand. That list and the prioritization
+ operations. Then, when bringing tests from the NumPy test suite, we triaged
+ and prioritized how important it was to fix each failure we found. Doing this
+ iteratively, we ended up with a small list of differences between the NumPy and
+ PyTorch APIs, which we prioritized by hand. That list and the prioritization
discussion can be found in [this issue](https://github.com/Quansight-Labs/numpy_pytorch_interop/issues/87).

**Visibility of the module**. For simplicity, this RFC assumes that the
- `torch_np` module will not be public, as the decision for it to be made public
- was met with different opinions.
+ `torch_np` module will not be public, as the initial suggestion for it to be
+ made public was met with mixed opinions. This topic can be revisited in the
+ future if desired.
We discuss these in the section [unresolved questions](#unresolved-questions).
### Annotation-based preprocessing
@@ -261,7 +263,7 @@ internally), we can simply vendor its implementation, and have it call our
PyTorch-land implementations of these functions. In other words, at this level,
functions are composable, as they are simply regular PyTorch functions.
All these implementations are internal, and are not meant to be seen or used
- by the final user.
+ by the end user.

The second step is then done via type annotations and a decorator. Each type
annotation has an associated function from NumPy-land into PyTorch-land. This
@@ -287,41 +289,43 @@ We currently have four annotations (and small variations of them):
- `AxisLike`: Takes anything that can be accepted as an axis (e.g. a tuple or
an `ndarray`) and returns a tuple.
- `OutArray`: Asserts that the input is a `torch_np.ndarray`. This is used
- to implement the `out` arg.
+ to implement the `out` keyword.

Note that none of the code in this implementation makes use of NumPy. We are
- writing `torch_np.ndarray` above to make more explicit our intents, but there
+ writing `torch_np.ndarray` above to make our intent more explicit, but there
shouldn't be any ambiguity.
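
A minimal sketch of how such an annotation-driven decorator can work; the names (`normalizer`, `ArrayLike`, the dispatch table) are simplified stand-ins rather than the actual implementation:

```python
import functools
import inspect
import torch

def to_tensor(x):
    # Stand-in for the real ArrayLike normalizer: array-likes become tensors.
    return torch.as_tensor(x)

NORMALIZERS = {"ArrayLike": to_tensor}

def normalizer(func):
    sig = inspect.signature(func)

    @functools.wraps(func)
    def wrapped(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)
        # Convert each argument from NumPy-land to PyTorch-land based on its annotation.
        for name, value in bound.arguments.items():
            convert = NORMALIZERS.get(sig.parameters[name].annotation)
            if convert is not None:
                bound.arguments[name] = convert(value)
        return func(*bound.args, **bound.kwargs)

    return wrapped

@normalizer
def multiply(x: "ArrayLike", y: "ArrayLike"):
    # The implementation only ever sees tensors.
    return torch.multiply(x, y)

print(multiply([1, 2], [3, 4]))  # tensor([3, 8])
```
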

- **Implmenting out**: In PyTorch, the `out` kwarg is, as the name says, a
- keyword-only argument. It is for this reason that, in PrimTorch, we were able
- to implement it as [a decorator](https://github.com/pytorch/pytorch/blob/ce4df4cc596aa10534ac6d54912f960238264dfd/torch/_prims_common/wrappers.py#L187-L282).
- This is not the case in NumPy. In NumPy `out` is a positional arg that is often
- interleaved with other parameters. This is the reason why we use the `OutArray`
- annotation to mark these. We then implement the `out` semantics in the `@normalizer`
- wrapper in a generic way.
+ **Implementing `out`**: In PyTorch, the `out` kwarg is a keyword-only argument.
+ It is for this reason that, in PrimTorch, we were able to implement it as [a
+ decorator](https://github.com/pytorch/pytorch/blob/ce4df4cc596aa10534ac6d54912f960238264dfd/torch/_prims_common/wrappers.py#L187-L282).
+ This is not the case in NumPy. In NumPy, `out` can be used both as a positional
+ and a keyword argument, and is often interleaved with other parameters. This is
+ the reason why we use the `OutArray` annotation to mark these. We then
+ implement the `out` semantics in the `@normalizer` wrapper in a generic way.
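
A rough sketch of what the generic `out=` handling amounts to once the arguments are already in PyTorch-land; this is a simplified stand-in, not the actual wrapper code:

```python
import torch

def apply_out(result, out=None):
    # If the caller passed `out`, write the result into it and return it,
    # mimicking NumPy's `out=` semantics; otherwise return the fresh result.
    if out is None:
        return result
    if out.shape != result.shape:
        raise ValueError("output array has the wrong shape")
    out.copy_(result)  # copy_ casts to out's dtype and writes in place
    return out

# Example: write a float32 result into a preallocated float64 buffer.
buf = torch.empty(3, dtype=torch.float64)
apply_out(torch.add(torch.ones(3), torch.ones(3)), out=buf)
print(buf)  # tensor([2., 2., 2.], dtype=torch.float64)
```
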

**Ufuncs and reductions**: Ufuncs (unary and binary) and reductions are two
sets of functions that are particularly regular. For these functions, we
- implement their args in a generic way as a preprocessing or postprocessing.
+ implement support for their arguments in a generic way as a preprocessing or
+ postprocessing step.
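
For instance, a generic reduction wrapper only needs to normalize the NumPy-style arguments (`axis`, `keepdims`) before delegating to the torch function; the sketch below is illustrative, not the actual code:

```python
import torch

def as_reduction(torch_func):
    # Generic preprocessing shared by reductions: normalize `axis` into the
    # tuple of dims that torch expects, and forward `keepdims`.
    def wrapped(a, axis=None, keepdims=False):
        if axis is None:
            dims = tuple(range(a.ndim))
        elif isinstance(axis, int):
            dims = (axis,)
        else:
            dims = tuple(axis)
        return torch_func(a, dim=dims, keepdim=keepdims)
    return wrapped

tnp_sum = as_reduction(torch.sum)
print(tnp_sum(torch.arange(6).reshape(2, 3), axis=1))  # tensor([ 3, 12])
```
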

- **The ndarray class** Once we have all the free functions implemented as
- functions form `torch_np.ndarray`s to `torch_np.ndarray`s, implementing the
+ **The `ndarray` class**. Once we have all the free functions implemented as
+ functions from `torch_np.ndarray`s to `torch_np.ndarray`s, implementing the
methods from the `ndarray` class is rather simple. We simply register all the
free functions as methods or dunder methods appropriately. We also forward the
- properties to the properties within the PyTorch tensor and we are done.
- This creates a circular dependency which we break with a local import.
+ properties of `ndarray` to the corresponding properties of `torch.Tensor` and we
+ are done. This creates a circular dependency which we break with a local
+ import.
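
A condensed sketch of that wiring, with hypothetical names and only a couple of methods shown:

```python
import torch

class ndarray:
    """Thin wrapper: methods and properties forward to free functions / the tensor."""

    def __init__(self, tensor):
        self.tensor = tensor

    # Properties forward to the underlying torch.Tensor.
    @property
    def shape(self):
        return tuple(self.tensor.shape)

    @property
    def ndim(self):
        return self.tensor.ndim

    # Methods and dunders simply reuse the free functions.
    def sum(self, axis=None):
        return sum_(self, axis)

    def __add__(self, other):
        return add(self, other)

# Free functions from ndarray-land to ndarray-land, implemented via torch.
def sum_(a, axis=None):
    dims = tuple(range(a.tensor.ndim)) if axis is None else (axis,)
    return ndarray(torch.sum(a.tensor, dim=dims))

def add(a, b):
    other = b.tensor if isinstance(b, ndarray) else b
    return ndarray(torch.add(a.tensor, other))

a = ndarray(torch.arange(6).reshape(2, 3))
print((a + a).sum(axis=1).tensor)  # tensor([ 6, 24])
```
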
### Testing

The testing of the framework was done via ~~copying~~ vendoring tests from the
NumPy test suite. Then, we would replace the NumPy imports with `torch_np`
- imports. The failures on these tests were then triaged and discussed the
- priority of fixing each of them.
+ imports. The failures on these tests were then triaged, and either fixed or marked
+ `xfail` depending on our assessment of the priority of implementing a fix.

In the end, to have a last check that this tool was sound, we pulled five
- examples of NumPy code from different sources and we run it with this library.
- We were able to successfully the five examples successfully with close to no code changes.
+ examples of NumPy code from different sources and ran them with this library (eager mode execution).
+ We were able to run the five examples successfully with close to no code changes.
You can read about these in the [README](https://github.com/Quansight-Labs/numpy_pytorch_interop).
### Limitations
@@ -331,25 +335,26 @@ A number of known limitations are tracked in the second part of the
When landing this RFC, we will create a comprehensive document with the differences
between NumPy and `torch_np`.

- ### Beyond Plain NumPy
+ ### Beyond plain NumPy

- **GPU**. The current implementation has just been implemented and tested on
- CPU. We expect GPU coverage to be as good as the coverage we have with CPU
- matching GPU. If the original tensors are on GPU, the whole execution should
- be performed on the GPU.
+ **GPU**. The current implementation has so far only been implemented and tested on
+ CPU. We expect GPU coverage to be as good as the coverage we have with CPU-GPU
+ matching tests in the PyTorch test suite. If the original tensors are on GPU,
+ the execution should be performed fully on GPU.
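
A hypothetical sketch of what this would look like from the user's side, again assuming the TorchDynamo bindings below; the behavior on GPU is the expectation stated above, not something we have validated:

```python
import numpy as np
import torch

def fn(x, y):
    return np.multiply(x, y).sum()

# If the inputs already live on the GPU, the traced torch operations are
# expected to run there as well, with no round-trips through NumPy.
if torch.cuda.is_available():
    x = torch.rand(1024, device="cuda")
    y = torch.rand(1024, device="cuda")
    out = torch.compile(fn)(x, y)
    print(out)
```
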

**Gradients**. We have not tested gradient tracking either, as we have yet to
find good examples on which to test it, but it should be a simple
- corollary of all this effort. If the original tensors fed into the function do
+ corollary of all this effort. If the original tensors fed into a function
have `requires_grad=True`, the tensors will track the gradients of the internal
- implementation and then the user could differentiate through the NumPy code.
+ implementation and then the user can differentiate through their NumPy code.

- ### Bindings to TorchDyamo
+ ### Bindings to TorchDynamo

- The bindings for NumPy at the TorchDynamo level are currently being developed at [#95849](https://github.com/pytorch/pytorch/pull/95849).
+ The bindings for NumPy at the TorchDynamo level are currently being developed in
+ [pytorch #95849](https://github.com/pytorch/pytorch/pull/95849).

- ## Unresolved Questions
+ ## Unresolved questions

A question was left open in the initial discussion. Should the module
`torch_np` be publicly exposed as `torch.numpy` or not?
@@ -369,7 +374,7 @@ A few arguments in favor of making it public:
A few arguments against:
* The compat layer introduces a number of type conversions that may produce somewhat
slow code when used in eager mode.
- * [Note] Keeping this in mind, we tried to use in the implementations as few
- operators as possible, to make it reasonably fast in eager mode.
+ * [Note] Keeping this in mind, we tried to use as few operators as possible
+ in the implementation, to make it reasonably fast in eager mode.
* Exposing `torch.numpy` would create a less performant secondary entry point
to many of the functions in PyTorch. This could be a trap for new users.