Commit e84e635 (parent 7a5b98c): General improvements

RFC.md (89 additions & 89 deletions)
This project has a main goal, as per the [initial design document](https://docs.google.com/document/d/1gdUDgZNbumFORRcUaZUVw790CtNYweAM20C1fbWMNd8):

1. Make TorchDynamo understand NumPy calls

The work is currently being done at [numpy_pytorch_interop](https://github.com/Quansight-Labs/numpy_pytorch_interop/).

## Motivation

### Introductory examples

Consider the following snippet:
```python
z = torch.matmul(x, y)
w = z.sum()
```

Here, we can already spot a couple of differences between NumPy and PyTorch. The most obvious one is that the default dtype in NumPy is `float64` rather than `float32`. The less obvious one is very sneakily hiding in the last line.

Reductions and similar operations in NumPy return the infamous NumPy scalars. We'll discuss these and other NumPy quirks, and how we dealt with them, in the [design decisions section](#design-decisions).

Let's now have a look at a toy example of how this layer would be used:

```python
import torch
import numpy as np

def fn(x, y):
    return np.multiply(x, y).sum()

# assume x and y are NumPy arrays defined earlier
result = fn(x, y)

t_results = torch.empty(5, dtype=torch.float64)
t_results[0] = result  # store the result in a torch.Tensor
```

Note that this code mixing NumPy and PyTorch already works, as `torch.Tensor` implements the `__array__` method. Now, the compatibility layer allows us to trace through this code: no changes are necessary, other than simply asking `torch.compile` to do so:

```python
@compile
def fn(x, y):
    return np.multiply(x, y).sum()
```

### Design decisions

The two main ideas driving the design of this compatibility layer are the following:

1. The behavior of the layer should be as close to that of NumPy as possible
2. The layer follows NumPy master

The following design decisions follow from these:

**Default dtypes**. One of the most common issues that bites people when migrating their codebases from NumPy to JAX is the default dtype changing from `float64` to `float32`. So much so that this is noted as one of [JAX's sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision). Following the spirit of making everything match NumPy by default, we choose the NumPy defaults whenever the `dtype` was not made explicit in a factory function.
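
This difference in defaults is easy to check in plain NumPy; the snippet below is purely illustrative (PyTorch's factory functions default to `float32` instead):

```python
import numpy as np

# NumPy factory functions default to float64 when no dtype is given
# (torch.ones(3) would instead produce a float32 tensor).
assert np.ones(3).dtype == np.float64
assert np.arange(3.0).dtype == np.float64
# Reductions preserve the default dtype:
assert np.ones(3).sum().dtype == np.float64
```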

**TODO(Lezcano)**: I just realized that we do not have a clean way to change the default dtype of `torch_np` to those from PyTorch. We should implement
to be able to set their own int/float/complex defaults.

use them anywhere else -> Check

**NumPy scalars**. NumPy's type system is tricky. At first sight, it looks like PyTorch's, but with a few more dtypes, like `np.uint16` or `np.longdouble`. Upon closer inspection, one finds that it also has [NumPy scalar](https://numpy.org/doc/stable/reference/arrays.scalars.html) objects. NumPy scalars are similar to Python scalars, but with a set width. NumPy scalars are NumPy's preferred return class for reductions and other operations that return just one element. NumPy scalars do not play particularly well with computations on devices like GPUs, as they live on the CPU, so implementing NumPy scalars would mean that we need to synchronize after every `sum()` call, which would be terrible performance-wise. In this implementation, we choose to represent NumPy scalars as 0-D arrays. This may cause small divergences in some cases, like

```python
>>> np.int32(2) * [1, 2, 3]  # scalar decays to a Python int
[1, 2, 3, 1, 2, 3]

>>> np.asarray(2) * [1, 2, 3]  # zero-dim array is an array-like
array([2, 4, 6])
```

but we don't expect these to pose a big issue in practice. Note that in this implementation `torch_np.int32(2)` would return the same as `torch_np.asarray(2)`.
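
The scalar-versus-0-D-array distinction can be seen in plain NumPy; this snippet is for illustration only (`torch_np` always returns the 0-D flavor):

```python
import numpy as np

# A reduction returns a NumPy scalar, not an ndarray:
s = np.arange(3).sum()
assert isinstance(s, np.generic) and not isinstance(s, np.ndarray)

# torch_np instead returns the equivalent of a 0-D array:
z = np.asarray(2)
assert z.ndim == 0 and z.shape == ()
# A 0-D array behaves as an array-like under arithmetic:
assert (z * [1, 2, 3]).tolist() == [2, 4, 6]
```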

**Type promotion**. Another not-so-well-known fact of NumPy's casting system is that it is data-dependent. Python scalars can be used in pretty much any NumPy operation: any operation that accepts a 0-D array can also be called with a Python scalar. If you provide an operation with a Python scalar, it will be cast to the smallest dtype that can represent it, and then it will participate in type promotion. This allows for some rather interesting behaviour:

```python
>>> np.asarray([1], dtype=np.int8) + 127
array([-128], dtype=int8)
>>> np.asarray([1], dtype=np.int8) + 128
array([129], dtype=int16)
```

This data-dependent type promotion will be deprecated in NumPy 2.0 and replaced with [NEP 50](https://numpy.org/neps/nep-0050-scalar-promotion.html). For simplicity and to be forward-looking, we chose to implement the type promotion behaviour proposed in NEP 50, which is much closer to that of PyTorch.

Note that the decision of going with NEP 50 complements the previous one of returning 0-D arrays in place of NumPy scalars since, currently, 0-D arrays do not participate in type promotion in NumPy (but they will in NumPy 2.0 under NEP 50):

```python
int64_0d_array = np.array(1, dtype=np.int64)
np.result_type(np.int8, int64_0d_array) == np.int8
```
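
Under NEP 50, promotion depends only on dtypes, never on runtime values. Promotion between plain dtypes already behaves this way in current NumPy, and it is the behaviour that 0-D arrays will adopt (illustrative):

```python
import numpy as np

# dtype-only promotion: the result type is determined without inspecting values
assert np.result_type(np.int8, np.int64) == np.int64
assert np.result_type(np.float32, np.float64) == np.float64
```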

**Versioning**. It should be clear from the previous points that NumPy has a fair amount of questionable and legacy pain points. It is for this reason that we decided that, rather than fighting these, we would declare that the compat layer follows the behavior of NumPy's master branch (even, in some cases, that of NumPy 2.0). Given the stability of NumPy's API and how battle-tested its main functions are, we do not expect this to become a big maintenance burden. If anything, it should make our lives easier, as some parts of NumPy will soon be simplified, saving us the pain of having to implement all the pre-existing corner cases.

For reference, NumPy 2.0 is expected to land at the end of this year.

## The `torch_np` module

The bulk of the work went into implementing a system that allows us to implement NumPy operations in terms of those of PyTorch. The main design goals here were:

1. Implement *most* of NumPy's API
2. Preserve NumPy semantics as much as possible

We say *most* of NumPy's API because NumPy's API is not only massive, but there are also parts of it that cannot be implemented in PyTorch. For example, NumPy has support for arrays of string, datetime, structured, and other dtypes. Negative strides are another example of a feature that is simply not supported in PyTorch. We put together a list of things that are out of the scope of this project in the [following issue](https://github.com/Quansight-Labs/numpy_pytorch_interop/issues/73).

For the bulk of the functions, we started by prioritizing the most common operations. Then, when bringing in tests from the NumPy test suite, we would triage and prioritize how important it was to fix each failure we found. Iterating this process, we ended up with a small list of differences between the NumPy and PyTorch APIs, which we prioritized by hand. That list and the prioritization discussion can be found in [this issue](https://github.com/Quansight-Labs/numpy_pytorch_interop/issues/87).

**Visibility of the module**. For simplicity, this RFC assumes that the `torch_np` module will not be public, as the decision for it to be made public was met with different opinions. We discuss these in the [unresolved questions](#unresolved-questions) section.
### Annotation-based preprocessing

NumPy accepts virtually anything that smells like an array as an input:

```python
>>> np.add(1, 3)
4
>>> np.add([1, 2, 3], 5.0)
array([6., 7., 8.])
>>> np.concatenate([[1, 2, 3], [4, 5, 6]])
array([1, 2, 3, 4, 5, 6])
```

To implement NumPy in terms of PyTorch, for any operation we would need to map the inputs into tensors, perform the operation, and then wrap the tensor into a `torch_np.ndarray` (more on this class later).

To avoid all this code repetition, we implement the functions in two steps.

First, we implement functions with the NumPy signature, but assuming that, in place of NumPy-land elements (`np.array`s, array-likes, `np.dtype`s, etc.), they simply accept `torch.Tensor`s and PyTorch-land objects, and return `torch.Tensor`s. For example, we would implement `np.diag` as

```python
def diag(v, k=0):
    return torch.diag(v, k)
```

In this layer, if a NumPy function is composite (calls other NumPy functions internally), we can simply vendor its implementation and have it call our PyTorch-land implementations of these functions. In other words, at this level, functions are composable, as they are simply regular PyTorch functions. All these implementations are internal, and are not meant to be seen or used by the final user.
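
As a toy sketch of this composability (made-up helper functions, with plain Python lists standing in for tensors), a composite function such as `trace` can simply call our own `diag`:

```python
# Tensor-level implementation of diag (stand-in: extract the k-th diagonal
# of a list-of-lists "matrix" instead of calling torch.diag).
def diag(v, k=0):
    n = len(v)
    return [v[i][i + k] for i in range(n - abs(k))]

# A composite function: its vendored implementation just calls our diag.
def trace(v):
    return sum(diag(v))

assert diag([[1, 2], [3, 4]]) == [1, 4]
assert trace([[1, 2], [3, 4]]) == 5
```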

The second step is then done via type annotations and a decorator. Each type annotation has an associated function from NumPy-land into PyTorch-land. This function converts the set of inputs accepted by NumPy for that argument into a PyTorch-land object (think a `torch.Tensor` or a PyTorch dtype). For example, for `np.diag` we would write

```python
def diag(v: ArrayLike, k=0):
    return torch.diag(v, k)
```

Then, we wrap these Python-land functions with a `normalizer` decorator and expose them in the `torch_np` module. This decorator is in charge of gathering all the inputs at runtime and normalizing them according to their annotations.

We currently have four annotations (and small variations of them):

- `ArrayLike`: The input can be a `torch_np.array`, a list of lists, a
- `OutArray`: Asserts that the input is a `torch_np.ndarray`. This is used to implement the `out` arg.
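
A minimal sketch of how the annotation-driven `normalizer` could work (hypothetical code for illustration; the real decorator in the `torch_np` repo handles far more cases, and `to_tensor` here is a stand-in for the actual NumPy-land to PyTorch-land conversion):

```python
import inspect
import typing

class ArrayLike:
    """Marker annotation: convert this argument to a tensor."""

def to_tensor(x):
    # Stand-in for the real conversion into a torch.Tensor.
    return ("tensor", x)

def normalizer(fn):
    sig = inspect.signature(fn)
    hints = typing.get_type_hints(fn)

    def wrapper(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)
        # Normalize each argument according to its annotation.
        for name, value in list(bound.arguments.items()):
            if hints.get(name) is ArrayLike:
                bound.arguments[name] = to_tensor(value)
        return fn(*bound.args, **bound.kwargs)

    return wrapper

@normalizer
def diag(v: ArrayLike, k=0):
    return (v, k)

assert diag([1, 2], k=1) == (("tensor", [1, 2]), 1)
```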

Note that none of the code in this implementation makes use of NumPy. We are writing `torch_np.ndarray` above to make our intent more explicit, but there shouldn't be any ambiguity.

**OBS(Lezcano)**: `DTypeLike` should be `Optional[DTypeLike]`
**OBS(Lezcano)**: Should we have a `NotImplementedType` to mark the args that
implementation, or mark explicitly those that we don't use.

**Implementing out**: In PyTorch, the `out` kwarg is, as the name says, a keyword-only argument. It is for this reason that, in PrimTorch, we were able to implement it as [a decorator](https://github.com/pytorch/pytorch/blob/ce4df4cc596aa10534ac6d54912f960238264dfd/torch/_prims_common/wrappers.py#L187-L282). This is not the case in NumPy. In NumPy, `out` is a positional arg that is often interleaved with other parameters. This is the reason why we use the `OutArray` annotation to mark these. We then implement the `out` semantics in the `@normalizer` wrapper in a generic way.
280280

281281
**Ufuncs and reductions**: Ufuncs (unary and binary) and reductions are two
282282
sets of functions that are particularly regular. For these functions, we
283-
implement (some of) their args in a generic way. We then simply forward the
284-
computations to PyTorch, perhaps working around some PyTorch limitations.
285-
286-
### The `ndarray` class
283+
implement their args in a generic way as a preprocessing or postprocessing.
287284

288-
Once we have all the free functions implemented, implementing an `ndarray`
289-
class is rather simple. We simply register all the free functions as methods or
290-
dunder methods appropriately. We also forward the properties to the properties
291-
within the PyTorch tensor and we are done.
285+
**The ndarray class** Once we have all the free functions implemented as
286+
functions form `torch_np.ndarray`s to `torch_np.ndarray`s, implementing the
287+
methods from the `ndarray` class is rather simple. We simply register all the
288+
free functions as methods or dunder methods appropriately. We also forward the
289+
properties to the properties within the PyTorch tensor and we are done.
290+
This creates a circular dependency which we break with a local import.
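
The registration itself can be sketched in a few lines (illustrative stand-ins; the real class wraps a `torch.Tensor` rather than a list):

```python
class ndarray:
    """Toy stand-in for torch_np.ndarray, wrapping a 'tensor' (here, a list)."""
    def __init__(self, tensor):
        self.tensor = tensor

    @property
    def size(self):
        # Properties forward to the wrapped object.
        return len(self.tensor)

# A free function, written tensor-in / tensor-out:
def _sum(a):
    return ndarray([sum(a.tensor)])

# Register the free function as a method on the class:
ndarray.sum = _sum

arr = ndarray([1, 2, 3])
assert arr.sum().tensor == [6]
assert arr.size == 3
```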

### Testing

The testing of the framework was done via ~~copying~~ vendoring tests from the NumPy test suite. Then, we would replace the NumPy imports with `torch_np` imports. The failures in these tests were then triaged, and we discussed the priority of fixing each of them.

In the (near) future, we plan to get some real-world examples and run them
through the library, to test its coverage and correctness.

A number of known limitations are tracked in the second part of the [OP of this issue](https://github.com/Quansight-Labs/numpy_pytorch_interop/issues/73). There are some more in [this issue](https://github.com/Quansight-Labs/numpy_pytorch_interop/issues/86). When landing this RFC, we will create a comprehensive document with the differences between NumPy and `torch_np`.

### Beyond Plain NumPy

The bindings for NumPy at the TorchDynamo level are currently being developed at

## Unresolved Questions

A question was left open in the initial discussion: should the module `torch_np` be publicly exposed as `torch.numpy` or not?

A few arguments in favor of making it public:

* People could use it in their NumPy programs just by changing the import to
