Commit e3c492b ("Address Evgeni's review")
1 parent e84e635

RFC.md: 1 file changed, 26 additions & 15 deletions
@@ -73,10 +73,11 @@ t_results = torch.empty(5, dtype=torch.float64)
 t_results[0] = result # store the result in a torch.Tensor
 ```
 
-Note that this code mixing NumPy and PyTorch already works, as `torch.Tensor`
-implements the `__array__` method. Now, the compatibility layer allows us to
-trace through it. In order to do that, there would be no necessary changes,
-other than simply ask `torch.compile` to trace through it:
+Note that this code mixing NumPy and PyTorch already works in eager mode
+with CPU tensors, as `torch.Tensor` implements the `__array__` method. Now,
+the compatibility layer allows us to trace through it. To do that, no
+changes are necessary other than simply asking `torch.compile` to trace
+through it:
 
 ```python
 @compile
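
(Editor's note: the RFC's own example is split across the two hunks here. As a hedged illustration of the claim above, a minimal sketch of tracing mixed NumPy/PyTorch code; the function body is an assumption, and it presumes a PyTorch build where `torch.compile` understands NumPy calls.)

```python
import numpy as np
import torch

# A function written purely against the NumPy API. With the compatibility
# layer in place, torch.compile can trace through the NumPy calls and run
# them as PyTorch operations.
@torch.compile
def fn(x, y):
    return np.multiply(x, y).sum()

print(fn(np.arange(3.0), np.ones(3)))  # 3.0
```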
@@ -89,7 +90,7 @@ def fn(x, y):
 The two main ideas driving the design of this compatibility layer are the following:
 
 1. The behavior of the layer should be as close to that of NumPy as possible
-2. The layer follows NumPy master
+2. The layer follows the most recent NumPy release
 
 The following design decisions follow from these:
 
@@ -129,8 +130,8 @@ NumPy scalars as 0-D arrays. This may cause small divergences in some cases like
 array([2, 4, 6])
 ```
 
-but we don't expect these to pose a big issue in practice. Note that in this
-implementation `torch_np.int32(2)` would return the same as `torch_np.asarray(2)`.
+but we don't expect these to pose a big issue in practice. Note that in the
+proposed implementation `np.int32(2)` would return the same as `np.asarray(2)`.
 
 **Type promotion**. Another not-so-well-known fact of NumPy's cast system is
 that it is data-dependent. Python scalars can be used in pretty much any NumPy
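
(Editor's note: a short illustration of the value-based promotion described above, using stock NumPy 1.x semantics; not part of the commit.)

```python
import numpy as np

a = np.ones(3, dtype=np.int8)

# A Python scalar does not widen the result dtype: promotion inspects its value.
print((a + 1).dtype)  # int8

# Under NumPy 1.x rules, 0-D arrays also promote like scalars, so even an
# int64 0-D array leaves int8 untouched:
int64_0d_array = np.array(1, dtype=np.int64)
print(np.result_type(np.int8, int64_0d_array))  # int8
```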
@@ -161,15 +162,23 @@ np.result_type(np.int8, int64_0d_array) == np.int8
 **Versioning**. It should be clear from the previous points that NumPy has a
 fair amount of questionable and legacy pain points. It is for this reason that
 we decided that rather than fighting these, we would declare that the compat
-layer follows the behavior of Numpy's master (even, in some cases, of NumPy
-2.0). Given the stability of NumPy's API and how battle-tested its main
-functions are, we do not expect this to become a big maintenance burden. If
-anything, it should make our lives easier, as some parts of NumPy will soon be
-simplified, saving us the pain of having to implement all the pre-existing
+layer follows the behavior of NumPy's most recent release (even, in some cases,
+of NumPy 2.0). Given the stability of NumPy's API and how battle-tested its
+main functions are, we do not expect this to become a big maintenance burden.
+If anything, it should make our lives easier, as some parts of NumPy will soon
+be simplified, saving us the pain of having to implement all the pre-existing
 corner-cases.
 
 For reference, NumPy 2.0 is expected to land at the end of this year.
 
+**Randomness**. PyTorch and NumPy use different random number generation methods.
+In particular, NumPy recently moved to a [new API](https://numpy.org/doc/stable/reference/random/index.html)
+with a `Generator` object which has sampling methods on it. The current compat
+layer does not implement this new API, as the default bit generator in NumPy is
+a `PCG64`, while on PyTorch we use an `MT19937` on CPU and a `Philox`. From this,
+it follows that this API will not give any reproducibility guarantees when it
+comes to randomness.
+
 
 ## The `torch_np` module
 
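
(Editor's note: a minimal sketch of the reproducibility caveat in the **Randomness** paragraph above, using plain NumPy and PyTorch APIs; not part of the commit.)

```python
import numpy as np
import torch

# Same seed, different bit generators: NumPy's default Generator is backed
# by PCG64, while torch.manual_seed seeds PyTorch's MT19937 CPU generator.
rng = np.random.default_rng(42)
torch.manual_seed(42)

print(rng.random(3))  # PCG64 stream
print(torch.rand(3))  # MT19937 stream; the two sequences will not match
```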
@@ -188,7 +197,7 @@ We put together a list of things that are out of the scope of this project in th
 [following issue](https://github.com/Quansight-Labs/numpy_pytorch_interop/issues/73).
 
 For the bulk of the functions, we started by prioritizing the most common
-operations. Then, when bringing tests from the NumPy test suit, we would triage
+operations. Then, when bringing tests from the NumPy test suite, we would triage
 and prioritize how important it was to fix each failure we found. Iterating this
 process, we ended up with a small list of differences between the NumPy and the
 PyTorch API which we prioritized by hand. That list and the prioritization
@@ -201,7 +210,7 @@ We discuss these in the section [unresolved questions](#unresolved-questions).
 
 ### Annotation-based preprocessing
 
-NumPy accepts virtually anything that smells like an array as an input
+NumPy accepts virtually anything that smells like an array as an input.
 
 ```python
 >>> np.add(1, 3)
@@ -212,6 +221,7 @@ array([6., 7., 8.])
 array([1, 2, 3, 4, 5, 6])
 ```
 
+NumPy calls all these objects `array_like` objects.
 To implement NumPy in terms of PyTorch, for any operation we would need to map
 inputs into tensors, perform the operations, and then wrap the tensor into
 a `torch_np.ndarray` (more on this class later).
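
(Editor's note: a hedged sketch of the map-to-tensors, compute, wrap-the-result pattern just described. The `ndarray` stand-in and helper are illustrative, not the actual `torch_np` code.)

```python
import torch

class ndarray:
    """Illustrative stand-in for torch_np.ndarray: a thin wrapper over a tensor."""
    def __init__(self, tensor: torch.Tensor):
        self.tensor = tensor

def _normalize(x):
    # Map anything array-like (wrapped ndarray, list, scalar, ...) to a tensor.
    return x.tensor if isinstance(x, ndarray) else torch.as_tensor(x)

def add(a, b):
    # Compute on tensors, then wrap the result back into the ndarray wrapper.
    return ndarray(torch.add(_normalize(a), _normalize(b)))

print(add(1, 3).tensor)           # tensor(4)
print(add([3.0, 4.0], 5).tensor)  # tensor([8., 9.])
```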
@@ -248,7 +258,8 @@ def diag(v: ArrayLike, k=0):
 
 Then, we wrap these Python-land functions with a `normalizer` decorator and
 expose them in the `torch_np` module. This decorator is in charge of gathering
-all the inputs at runtime and normalizing them according to their annotations.
+all the inputs at runtime and normalizing them (i.e., converting `torch_np`
+objects to PyTorch counterparts) according to their annotations.
 
 We currently have four annotations (and small variations of them):
 - `ArrayLike`: The input can be a `torch_np.array`, a list of lists, a
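
(Editor's note: a hedged sketch of what an annotation-driven `normalizer` decorator could look like; the names and mechanics are assumptions for illustration, and the real `torch_np` implementation may differ.)

```python
import functools
import inspect
import torch

class ArrayLike:
    """Marker annotation: 'convert this argument to a torch.Tensor'."""

def normalizer(func):
    sig = inspect.signature(func)

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)
        for name, value in bound.arguments.items():
            # Normalize every parameter annotated as ArrayLike into a tensor.
            if sig.parameters[name].annotation is ArrayLike:
                bound.arguments[name] = torch.as_tensor(value)
        return func(*bound.args, **bound.kwargs)

    return wrapper

@normalizer
def diag(v: ArrayLike, k=0):
    return torch.diag(v, k)

print(diag([1, 2, 3]))  # the list is normalized to a tensor before the call
```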
