Skip to content

Commit 0a75be4

Browse files
committed
A full copy-edit and fixes for a few minor inaccuracies
1 parent 332a2d0 commit 0a75be4

File tree

1 file changed

+71
-66
lines changed

1 file changed

+71
-66
lines changed

RFC.md

Lines changed: 71 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -92,23 +92,24 @@ def fn(x, y):
9292

9393
Then, TorchDynamo would will cast `x` and `y` to our internal implementation of `ndarray`,
9494
and will dispatch `np.multiply` and `sum` to our implementations in terms of `torch`
95-
functions effectively turning this function into a pure PyTorch function.
95+
functions, effectively turning this function into a pure PyTorch function.
9696

9797
### Design decisions
9898

9999
The main ideas driving the design of this compatibility layer are the following:
100100

101-
1. The goal is to transform valid NumPy programs into their equivalent PyTorch
101+
1. The goal is to transform valid NumPy and mixed PyTorch-NumPy programs into
102+
their equivalent PyTorch-only execution.
102103
2. The behavior of the layer should be as close to that of NumPy as possible
103104
3. The layer follows the most recent NumPy release
104105

105106
The following design decisions follow from these:
106107

107-
**A superset of NumPy**. Same as PyTorch has spotty support for `float16` on
108-
CPU, and less-than-good support for `complex32`, NumPy has a number of
109-
well-known edge-cases. The decision of translating just valid NumPy programs,
110-
often allows us to implement a superset of the functionality of NumPy with more
111-
predictable and consistent behavior than NumPy itself.
108+
**A superset of NumPy**. NumPy has a number of well-known edge-cases (as does
109+
PyTorch, like spotty support for `float16` on CPU and `complex32` in general).
110+
The decision to translate only valid NumPy programs, often allows us to
111+
implement a superset of the functionality of NumPy with more predictable and
112+
consistent behavior than NumPy itself has.
112113

113114
**Exceptions may be different**. We avoid entirely modelling the exception
114115
system in NumPy. As seen in the implementation of PrimTorch, modelling the
@@ -118,19 +119,20 @@ and we choose not to offer any guarantee here.
118119
**Default dtypes**. One of the most common issues that bites people when migrating their
119120
codebases from NumPy to JAX is the default dtype changing from `float64` to
120121
`float32`. So much so that this is noted as one of
121-
[JAX's shap edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision).
122+
[JAX's sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision).
122123
Following the spirit of making everything match NumPy by default, we choose the
123-
NumPy defaults whenever the `dtype` was not made explicit in a factory function.
124+
NumPy default dtype whenever the `dtype` was not made explicit in a factory function.
124125
We also provide a function `set_default_dtype` that allows to change this behavior
125126
dynamically.
126127

127128
**NumPy scalars**. NumPy's type system is tricky. At first sight, it looks
128129
like PyTorch's, but with few more dtypes like `np.uint16` or `np.longdouble`.
129130
Upon closer inspection, one finds that it also has
130131
[NumPy scalar](https://numpy.org/doc/stable/reference/arrays.scalars.html) objects.
131-
NumPy scalars are similar to Python scalars but with a set width. NumPy scalars
132-
are NumPy's preferred return class for reductions and other operations that
133-
return just one element. NumPy scalars do not play particularly well with
132+
NumPy scalars are similar to Python scalars but with a fixed precision and
133+
array-like methods attached. NumPy scalars are NumPy's preferred return class
134+
for reductions and other operations that return just one element.
135+
NumPy scalars do not play particularly well with
134136
computations on devices like GPUs, as they live on CPU. Implementing NumPy
135137
scalars would mean that we need to synchronize after every `sum()` call, which
136138
would be terrible performance-wise. In this implementation, we choose to represent
@@ -149,25 +151,26 @@ We don't expect these to pose a big issue in practice. Note that in the
149151
proposed implementation `np.int32(2)` would return the same as `np.asarray(2)`.
150152
In general, we try to avoid unnecessary graph breaks whenever we can. For
151153
example, we may choose to return a tensor of shape `(2, *)` rather than a list
152-
of pairs, to avoid unnecessary graph breaks.
154+
of pairs, to avoid a graph break.
153155

154-
**Type promotion**. Another not-so-well-known fact of NumPy's cast system is
155-
that it is data-dependent. Python scalars can be used in pretty much any NumPy
156+
**Type promotion**. Another not-so-well-known fact of NumPy's dtype system and casting rules
157+
is that it is data-dependent. Python scalars can be used in pretty much any NumPy
156158
operation, being able to call any operation that accepts a 0-D array with a
157159
Python scalar. If you provide an operation with a Python scalar, these will be
158-
casted to the smallest dtype they can be represented in, and then, they will
160+
cast to the smallest dtype they can be represented in, and only then will they
159161
participate in type promotion. This allows for for some rather interesting behaviour
160162
```python
161163
>>> np.asarray([1], dtype=np.int8) + 127
162164
array([128], dtype=int8)
163165
>>> np.asarray([1], dtype=np.int8) + 128
164166
array([129], dtype=int16)
165167
```
166-
This data-dependent type promotion will be deprecated NumPy 2.0, and will be
167-
replaced with [NEP 50](https://numpy.org/neps/nep-0050-scalar-promotion.html).
168+
This data-dependent type promotion will be removed in NumPy 2.0 (planned for Dec'23), and will be
169+
replaced with [NEP 50](https://numpy.org/neps/nep-0050-scalar-promotion.html)
170+
(already implemented in NumPy, it needs to be enabled via a private global switch now).
168171
For simplicity and to be forward-looking, we chose to implement the
169172
type promotion behaviour proposed in NEP 50, which is much closer to that of
170-
Pytorch.
173+
PyTorch.
171174

172175
Note that the decision of going with NEP 50 complements the previous one of
173176
returning 0-D arrays in place of NumPy scalars as, currently, 0-D arrays do not
@@ -178,24 +181,22 @@ np.result_type(np.int8, int64_0d_array) == np.int8
178181
```
179182

180183
**Versioning**. It should be clear from the previous points that NumPy has a
181-
fair amount of questionable and legacy pain points. It is for this reason that
184+
fair amount of questionable behavior and legacy pain points. It is for this reason that
182185
we decided that rather than fighting these, we would declare that the compat
183-
layer follows the behavior of Numpy's most recent release (even, in some cases,
186+
layer follows the behavior of NumPy's most recent release (even, in some cases,
184187
of NumPy 2.0). Given the stability of NumPy's API and how battle-tested its
185188
main functions are, we do not expect this to become a big maintenance burden.
186189
If anything, it should make our lives easier, as some parts of NumPy will soon
187190
be simplified, saving us the pain of having to implement all the pre-existing
188191
corner-cases.
189192

190-
For reference NumPy 2.0 is expected to land at the end of this year.
191-
192193
**Randomness**. PyTorch and NumPy use different random number generation methods.
193194
In particular, NumPy recently moved to a [new API](https://numpy.org/doc/stable/reference/random/index.html)
194-
with a `Generator` object which has sampling methods on it. The current compat.
195-
layer does not implement this new API, as the default bit generator in NumPy is a
196-
`PCG64`, while on PyTorch we use a `MT19937` on CPU and a `Philox`. From this, it
197-
follows that this API will not give any reproducibility guarantees when it comes
198-
to randomness.
195+
with a `Generator` object which has sampling methods on it. The current compat
196+
layer does not implement this new API, as the default bit generator in NumPy is
197+
`PCG64`, while on PyTorch we use `MT19937` on CPU and `Philox` on non-CPU devices.
198+
From this, it follows that this API will not give any reproducibility
199+
guarantees when it comes to randomness.
199200

200201

201202
## The `torch_np` module
@@ -210,20 +211,21 @@ here were
210211
We say *most* of NumPy's API, because NumPy's API is not only massive, but also
211212
there are parts of it which cannot be implemented in PyTorch. For example,
212213
NumPy has support for arrays of string, datetime, structured and other dtypes.
213-
Negative strides are other example of a feature that is just not supported in PyTorch.
214+
Negative strides are another example of a feature that is not supported in PyTorch.
214215
We put together a list of things that are out of the scope of this project in the
215216
[following issue](https://github.com/Quansight-Labs/numpy_pytorch_interop/issues/73).
216217

217218
For the bulk of the functions, we started by prioritizing the most common
218-
operations. Then, when bringing tests from the NumPy test suite, we would triage
219-
and prioritize how important was to fix each failure we found. Iterating this
220-
process, we ended up with a small list of differences between the NumPy and the
221-
PyTorch API which we prioritized by hand. That list and the prioritization
219+
operations. Then, when bringing tests from the NumPy test suite, we triaged
220+
and prioritized how important it was to fix each failure we found. Doing this
221+
iteratively, we ended up with a small list of differences between the NumPy and
222+
PyTorch APIs, which we prioritized by hand. That list and the prioritization
222223
discussion can be found in [this issue](https://github.com/Quansight-Labs/numpy_pytorch_interop/issues/87).
223224

224225
**Visibility of the module** For simplicity, this RFC assumes that the
225-
`torch_np` module will not be public, as the decision for it to be made public
226-
was met with different opinions.
226+
`torch_np` module will not be public, as the initial suggestion for it to be
227+
made public was met with mixed opinions. This topic can be revisited in the
228+
future if desired.
227229
We discuss these in the section [unresolved questions](#unresolved-questions).
228230

229231
### Annotation-based preprocessing
@@ -261,7 +263,7 @@ internally), we can simply vendor its implementation, and have it call our
261263
PyTorch-land implementations of these functions. In other words, at this level,
262264
functions are composable, as they are simply regular PyTorch functions.
263265
All these implementations are internal, and are not meant to be seen or used
264-
by the final user.
266+
by the end user.
265267

266268
The second step is then done via type annotations and a decorator. Each type
267269
annotation has an associated function from NumPy-land into PyTorch-land. This
@@ -287,41 +289,43 @@ We currently have four annotations (and small variations of them):
287289
- `AxisLike`: Takes anything that can be accepted as an axis (e.g. a tuple or
288290
an `ndarray`) and returns a tuple.
289291
- `OutArray`: Asserts that the input is a `torch_np.ndarray`. This is used
290-
to implement the `out` arg.
292+
to implement the `out` keyword.
291293

292294
Note that none of the code in this implementation makes use of NumPy. We are
293-
writing `torch_np.ndarray` above to make more explicit our intents, but there
295+
writing `torch_np.ndarray` above to make more explicit our intent, but there
294296
shouldn't be any ambiguity.
295297

296-
**Implmenting out**: In PyTorch, the `out` kwarg is, as the name says, a
297-
keyword-only argument. It is for this reason that, in PrimTorch, we were able
298-
to implement it as [a decorator](https://github.com/pytorch/pytorch/blob/ce4df4cc596aa10534ac6d54912f960238264dfd/torch/_prims_common/wrappers.py#L187-L282).
299-
This is not the case in NumPy. In NumPy `out` is a positional arg that is often
300-
interleaved with other parameters. This is the reason why we use the `OutArray`
301-
annotation to mark these. We then implement the `out` semantics in the `@normalizer`
302-
wrapper in a generic way.
298+
**Implementing `out`**: In PyTorch, the `out` kwarg is a keyword-only argument.
299+
It is for this reason that, in PrimTorch, we were able to implement it as [a
300+
decorator](https://github.com/pytorch/pytorch/blob/ce4df4cc596aa10534ac6d54912f960238264dfd/torch/_prims_common/wrappers.py#L187-L282).
301+
This is not the case in NumPy. In NumPy, `out` can be used both as a positional
302+
and a keyword argument, and is often interleaved with other parameters. This is
303+
the reason why we use the `OutArray` annotation to mark these. We then
304+
implement the `out` semantics in the `@normalizer` wrapper in a generic way.
303305

304306
**Ufuncs and reductions**: Ufuncs (unary and binary) and reductions are two
305307
sets of functions that are particularly regular. For these functions, we
306-
implement their args in a generic way as a preprocessing or postprocessing.
308+
implement support for their arguments in a generic way as a preprocessing or
309+
postprocessing step.
307310

308-
**The ndarray class** Once we have all the free functions implemented as
309-
functions form `torch_np.ndarray`s to `torch_np.ndarray`s, implementing the
311+
**The `ndarray` class** Once we have all the free functions implemented as
312+
functions from `torch_np.ndarray`s to `torch_np.ndarray`s, implementing the
310313
methods from the `ndarray` class is rather simple. We simply register all the
311314
free functions as methods or dunder methods appropriately. We also forward the
312-
properties to the properties within the PyTorch tensor and we are done.
313-
This creates a circular dependency which we break with a local import.
315+
properties of `ndarray to the corresponding properties of `torch.Tensor` and we
316+
are done. This creates a circular dependency which we break with a local
317+
import.
314318

315319
### Testing
316320

317321
The testing of the framework was done via ~~copying~~ vendoring tests from the
318322
NumPy test suite. Then, we would replace the NumPy imports with `torch_np`
319-
imports. The failures on these tests were then triaged and discussed the
320-
priority of fixing each of them.
323+
imports. The failures on these tests were then triaged, and either fixed or marked
324+
`xfail` depending on our assessment of the priority of implementing a fix.
321325

322326
In the end, to have a last check that this tool was sound, we pulled five
323-
examples of NumPy code from different sources and we run it with this library.
324-
We were able to successfully the five examples successfully with close to no code changes.
327+
examples of NumPy code from different sources and ran it with this library (eager mode execution).
328+
We were able to run the five examples successfully with close to no code changes.
325329
You can read about these in the [README](https://github.com/Quansight-Labs/numpy_pytorch_interop).
326330

327331
### Limitations
@@ -331,25 +335,26 @@ A number of known limitations are tracked in the second part of the
331335
When landing this RFC, we will create a comprehensive document with the differences
332336
between NumPy and `torch_np`.
333337

334-
### Beyond Plain NumPy
338+
### Beyond plain NumPy
335339

336-
**GPU**. The current implementation has just been implemented and tested on
337-
CPU. We expect GPU coverage to be as good as the coverage we have with CPU
338-
matching GPU. If the original tensors are on GPU, the whole execution should
339-
be performed on the GPU.
340+
**GPU**. The current implementation so far only been implemented and tested on
341+
CPU. We expect GPU coverage to be as good as the coverage we have with CPU-GPU
342+
matching tests in the PyTorch test suite. If the original tensors are on GPU,
343+
the execution should be performed fully on GPU.
340344

341345
**Gradients**. We have not tested gradient tracking either as we are still to
342346
find some good examples on which to test it, but it should be a simple
343-
corollary of all this effort. If the original tensors fed into the function do
347+
corollary of all this effort. If the original tensors fed into a function
344348
have `requires_grad=True`, the tensors will track the gradients of the internal
345-
implementation and then the user could differentiate through the NumPy code.
349+
implementation and then the user can differentiate through their NumPy code.
346350

347-
### Bindings to TorchDyamo
351+
### Bindings to TorchDynamo
348352

349-
The bindings for NumPy at the TorchDynamo level are currently being developed at [#95849](https://github.com/pytorch/pytorch/pull/95849).
353+
The bindings for NumPy at the TorchDynamo level are currently being developed in
354+
[pytorch#95849](https://github.com/pytorch/pytorch/pull/95849).
350355

351356

352-
## Unresolved Questions
357+
## Unresolved questions
353358

354359
A question was left open in the initial discussion. Should the module
355360
`torch_np` be publicly exposed as `torch.numpy` or not?
@@ -369,7 +374,7 @@ A few arguments in favor of making it public:
369374
A few arguments against:
370375
* The compat introduces a number of type conversions that may produce somewhat
371376
slow code when used in eager mode.
372-
* [Note] Keeping this in mind, we tried to use in the implementations as few
373-
operators as possible, to make it reasonably fast in eager mode.
377+
* [Note] Keeping this in mind, we tried to use as few operators as possible,
378+
in the implementation, to make it reasonably fast in eager mode.
374379
* Exposing `torch.numpy` would create a less performant secondary entry point
375380
to many of the functions in PyTorch. This could be a trap for new users.

0 commit comments

Comments
 (0)