
RFC #112


Merged · merged 12 commits into from Apr 28, 2023
41 changes: 26 additions & 15 deletions RFC.md

```python
t_results = torch.empty(5, dtype=torch.float64)
t_results[0] = result # store the result in a torch.Tensor
```

Note that this code mixing NumPy and PyTorch already works in eager mode with
CPU tensors, as `torch.Tensor` implements the `__array__` method. The
compatibility layer now also lets us trace through it. To do so, no changes
are necessary other than simply asking `torch.compile` to trace through it:

```python
@compile
def fn(x, y):
    ...
```
The two main ideas driving the design of this compatibility layer are the following:

1. The behavior of the layer should be as close to that of NumPy as possible
2. The layer follows the most recent NumPy release

The following design decisions follow from these:

NumPy scalars as 0-D arrays. This may cause small divergences in some cases like

```python
# ... (the lines producing this output are elided here)
array([2, 4, 6])
```

but we don't expect these to pose a big issue in practice. Note that in the
proposed implementation `np.int32(2)` would return the same as `np.asarray(2)`.
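
As a small illustration of this design choice, a sketch using the `torch_np` module described later in this RFC (the exact calls are illustrative):

```python
import torch_np as np  # the compatibility layer's NumPy-like module

a = np.int32(2)                   # a 0-D array under the proposed design, not a scalar object
print(a.shape)                    # () -- zero-dimensional
print(a * np.asarray([1, 2, 3]))  # 0-D arrays broadcast like scalars: array([2, 4, 6])
print(a == np.asarray(2))         # both constructions produce equivalent 0-D arrays
```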

**Type promotion**. Another not-so-well-known fact of NumPy's cast system is
that it is data-dependent. Python scalars can be used in pretty much any NumPy

```python
int64_0d_array = np.array(1, dtype=np.int64)  # assumed definition; any 0-D int64 array illustrates the point
np.result_type(np.int8, int64_0d_array) == np.int8
```
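
For context, a small example of this legacy value-based ("data-dependent") promotion in stock NumPy (pre-NumPy 2.0 / NEP 50 behavior):

```python
import numpy as np

a = np.ones(3, dtype=np.int8)
print((a + 1).dtype)     # int8  -- the Python scalar 1 fits in int8, so no promotion happens
print((a + 1000).dtype)  # int16 -- 1000 does not fit in int8, so the result dtype depends on the value
```
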
**Versioning**. It should be clear from the previous points that NumPy has a
fair amount of questionable and legacy pain points. It is for this reason that
we decided that rather than fighting these, we would declare that the compat
layer follows the behavior of NumPy's most recent release (even, in some cases,
of NumPy 2.0). Given the stability of NumPy's API and how battle-tested its
main functions are, we do not expect this to become a big maintenance burden.
If anything, it should make our lives easier, as some parts of NumPy will soon
be simplified, saving us the pain of having to implement all the pre-existing
corner-cases.

For reference, NumPy 2.0 is expected to land at the end of this year.

**Randomness**. PyTorch and NumPy use different random number generation methods.
In particular, NumPy recently moved to a [new API](https://numpy.org/doc/stable/reference/random/index.html)
with a `Generator` object which has the sampling methods on it. The current
compatibility layer does not implement this new API, as the default bit generator
in NumPy is a `PCG64`, while PyTorch uses a `MT19937` on CPU and a `Philox` on
CUDA. From this, it follows that the layer will not give any reproducibility
guarantees when it comes to randomness.


## The `torch_np` module

We put together a list of things that are out of the scope of this project in the
[following issue](https://github.com/Quansight-Labs/numpy_pytorch_interop/issues/73).

For the bulk of the functions, we started by prioritizing the most common
operations. Then, when bringing tests from the NumPy test suite, we would triage
and prioritize how important it was to fix each failure we found. Iterating this
process, we ended up with a small list of differences between the NumPy and the
PyTorch API which we prioritized by hand. That list and the prioritization
We discuss these in the section [unresolved questions](#unresolved-questions).

### Annotation-based preprocessing

NumPy accepts virtually anything that smells like an array as an input.

```python
>>> np.add(1, 3)
4
>>> np.add([1., 2., 3.], 5)                 # input reconstructed for illustration
array([6., 7., 8.])
>>> np.concatenate([[1, 2, 3], [4, 5, 6]])  # input reconstructed for illustration
array([1, 2, 3, 4, 5, 6])
```

NumPy calls all these objects `array_like` objects.
To implement NumPy in terms of PyTorch, for any operation we would need to map
inputs into tensors, perform the operations, and then wrap the tensor into
a `torch_np.ndarray` (more on this class later).
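
Done by hand for a single function, that mapping would look roughly like this (a sketch; the wrapping of the result back into a `torch_np.ndarray` is omitted):

```python
import torch

def add(x, y):
    # 1. map array-like inputs (lists of lists, scalars, tensors, ...) into tensors
    tx = torch.as_tensor(x)
    ty = torch.as_tensor(y)
    # 2. perform the operation in PyTorch
    result = torch.add(tx, ty)
    # 3. the real layer would now wrap `result` into a torch_np.ndarray
    return result
```
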
For example, `diag` is declared with an annotated signature (its body is elided here):

```python
def diag(v: ArrayLike, k=0):
    ...
```

Then, we wrap these Python-land functions with a `normalizer` decorator and
expose them in the `torch_np` module. This decorator is in charge of gathering
all the inputs at runtime and normalizing them (i.e., converting `torch_np`
objects to PyTorch counterparts) according to their annotations.
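
To make the mechanism concrete, here is a minimal sketch of how such annotation-driven normalization could work. It is a simplification: apart from `ArrayLike`, `normalizer`, and the `diag` signature shown above, the details are assumptions, and the real implementation handles many more cases, including wrapping outputs back into `torch_np.ndarray`.

```python
import functools
import inspect

import torch


class ArrayLike:
    """Marker annotation: the argument should be normalized to a tensor."""


def normalizer(func):
    """Convert arguments annotated as ArrayLike into torch.Tensor before calling func."""
    sig = inspect.signature(func)

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)
        for name, value in bound.arguments.items():
            if sig.parameters[name].annotation is ArrayLike:
                # Accept tensors, lists of lists, scalars, ...
                bound.arguments[name] = torch.as_tensor(value)
        return func(*bound.args, **bound.kwargs)

    return wrapper


@normalizer
def diag(v: ArrayLike, k=0):
    return torch.diag(v, k)
```

With this in place, calling `diag([[1, 2], [3, 4]])` transparently converts the nested list into a tensor before dispatching to `torch.diag`.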

We currently have four annotations (and small variations of them):
- `ArrayLike`: The input can be a `torch_np.array`, a list of lists, a