Commit 44ee780

Address review comments
1 parent b85a999

File tree: 1 file changed

RFC.md: 17 additions & 8 deletions
@@ -85,6 +85,10 @@ def fn(x, y):
     return np.multiply(x, y).sum()
 ```
 
+Then, TorchDynamo will cast `x` and `y` to our internal implementation of `ndarray`,
+and will dispatch `np.multiply` and `sum` to our implementations in terms of `torch`
+functions, effectively turning this function into a pure PyTorch function.
+
 ### Design decisions
 
 The main ideas driving the design of this compatibility layer are the following:
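
To make the added paragraph concrete, here is a minimal sketch of what this dispatch amounts to, assuming the `torch_np` module name used throughout this RFC:

```python
import torch_np as np  # the compatibility layer proposed in this RFC

def fn(x, y):
    return np.multiply(x, y).sum()

# The NumPy-looking code now executes as torch operations under the hood:
x = np.asarray([1.0, 2.0])
y = np.asarray([3.0, 4.0])
print(fn(x, y))  # a 0-D array backed by a torch.Tensor (value 11.0)
```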
@@ -112,6 +116,8 @@ codebases from NumPy to JAX is the default dtype changing from `float64` to
 [JAX's sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision).
 Following the spirit of making everything match NumPy by default, we choose the
 NumPy defaults whenever the `dtype` was not made explicit in a factory function.
+We also provide a function `set_default_dtype` that allows changing this behavior
+dynamically.
 
 **NumPy scalars**. NumPy's type system is tricky. At first sight, it looks
 like PyTorch's, but with a few more dtypes like `np.uint16` or `np.longdouble`.
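
The default-dtype behavior added above might look as follows. This is a sketch only: the exact argument accepted by `set_default_dtype` is an assumption, not spelled out in this RFC:

```python
import torch_np as np

# Factory functions follow NumPy's defaults rather than PyTorch's:
x = np.ones(3)
print(x.dtype)  # float64, matching NumPy (plain PyTorch would default to float32)

# Hypothetical call; the signature of set_default_dtype is assumed here.
np.set_default_dtype(np.float32)
y = np.ones(3)  # would now default to float32
```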
@@ -123,7 +129,8 @@ return just one element. NumPy scalars do not play particularly well with
 computations on devices like GPUs, as they live on CPU. Implementing NumPy
 scalars would mean that we need to synchronize after every `sum()` call, which
 would be terrible performance-wise. In this implementation, we choose to represent
-NumPy scalars as 0-D arrays. This may cause small divergences in some cases like
+NumPy scalars as 0-D arrays. This may cause small divergences in some cases. For example,
+consider the following NumPy behavior:
 
 ```python
 >>> np.int32(2) * [1, 2, 3] # scalar decays to a python int
@@ -133,7 +140,7 @@ NumPy scalars as 0-D arrays. This may cause small divergences in some cases like
 array([2, 4, 6])
 ```
 
-but we don't expect these to pose a big issue in practice. Note that in the
+We don't expect these to pose a big issue in practice. Note that in the
 proposed implementation `np.int32(2)` would return the same as `np.asarray(2)`.
 In general, we try to avoid unnecessary graph breaks whenever we can. For
 example, we may choose to return a tensor of shape `(2, *)` rather than a list
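
The 0-D representation discussed above can be illustrated directly with plain `torch` calls; a conceptual sketch, not the layer's actual internals:

```python
import torch

# A "NumPy scalar" is represented as a 0-D tensor instead of a separate scalar
# type that lives on the CPU, so no device synchronization is needed:
s = torch.asarray(2, dtype=torch.int32)  # conceptually, what np.int32(2) maps to
print(s.ndim)                            # 0
print(s * torch.asarray([1, 2, 3]))      # tensor([2, 4, 6]) -- no decay to a Python int
```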
@@ -151,7 +158,7 @@ array([128], dtype=int8)
 >>> np.asarray([1], dtype=np.int8) + 128
 array([129], dtype=int16)
 ```
-This dependent type promotion will be deprecated NumPy 2.0, and will be
+This data-dependent type promotion will be deprecated in NumPy 2.0, and will be
 replaced with [NEP 50](https://numpy.org/neps/nep-0050-scalar-promotion.html).
 For simplicity and to be forward-looking, we chose to implement the
 type promotion behaviour proposed in NEP 50, which is much closer to that of
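
For contrast with the legacy behavior quoted above, here is the value-independent promotion NEP 50 proposes, shown with plain `torch`, which already behaves this way:

```python
import torch

# Under NEP 50-style promotion, a Python scalar is "weak": it never changes
# the result dtype, regardless of its value.
x = torch.asarray([1], dtype=torch.int8)
print((x + 1).dtype)    # torch.int8
print((x + 100).dtype)  # torch.int8 -- the value plays no role, unlike legacy NumPy
```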
@@ -270,7 +277,8 @@ objects to PyTorch counterparts) according to their annotations.
 We currently have four annotations (and small variations of them):
 - `ArrayLike`: The input can be a `torch_np.array`, a list of lists, a
   scalar, or anything that NumPy would accept. It returns a `torch.Tensor`.
-- `DTypeLike`: Takes a `torch_np` dtype and returns a PyTorch dtype.
+- `DTypeLike`: Takes a `torch_np` dtype, or any other object that NumPy dtypes
+  accept (strings, typecodes...), and returns a PyTorch dtype.
 - `AxisLike`: Takes anything that can be accepted as an axis (e.g. a tuple or
   an `ndarray`) and returns a tuple.
 - `OutArray`: Asserts that the input is a `torch_np.ndarray`. This is used
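
A rough sketch of how annotation-driven normalization like this can work; the decorator and marker class names here are illustrative, not the actual `torch_np` internals:

```python
import functools
from typing import get_type_hints

import torch

class ArrayLike:  # illustrative marker type, not torch_np's real annotation
    pass

def normalize(func):
    """Convert arguments annotated as ArrayLike into torch.Tensors."""
    hints = get_type_hints(func)
    names = func.__code__.co_varnames[: func.__code__.co_argcount]

    @functools.wraps(func)
    def wrapper(*args):
        converted = [
            torch.asarray(a) if hints.get(n) is ArrayLike else a
            for n, a in zip(names, args)
        ]
        return func(*converted)

    return wrapper

@normalize
def multiply(x: ArrayLike, y: ArrayLike):
    return torch.mul(x, y)

print(multiply([1, 2], 3))  # tensor([3, 6]): list and scalar were normalized
```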
@@ -302,18 +310,19 @@ This creates a circular dependency which we break with a local import.
 ### Testing
 
 The testing of the framework was done via ~~copying~~ vendoring tests from the
-NumPy test suit. Then, we would replace the NumPy imports with `torch_np`
+NumPy test suite. Then, we would replace the NumPy imports with `torch_np`
 imports. The failures on these tests were then triaged, and we discussed the
 priority of fixing each of them.
 
-In the (near) future, we plan to get some real world examples and run them
-through the library, to test its coverage and correctness.
+In the end, as a final check that this tool was sound, we pulled five
+examples of NumPy code from different sources and ran them with this library.
+We were able to run all five examples successfully with close to no code changes.
+You can read about these in the [README](https://github.com/Quansight-Labs/numpy_pytorch_interop).
 
 ### Limitations
 
 A number of known limitations are tracked in the second part of the
 [OP of this issue](https://github.com/Quansight-Labs/numpy_pytorch_interop/issues/73).
-There are some more in [this issue](https://github.com/Quansight-Labs/numpy_pytorch_interop/issues/86).
 When landing this RFC, we will create a comprehensive document with the differences
 between NumPy and `torch_np`.
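
The import swap described in the Testing section is mechanical; for a vendored test file it amounts to something like the following sketch:

```python
# Before (in the vendored NumPy test file):
#   import numpy as np
# After:
import torch_np as np  # the rest of the test body runs unchanged

def test_multiply_matches_numpy():
    x = np.asarray([1.0, 2.0])
    result = np.multiply(x, x)
    assert (result == np.asarray([1.0, 4.0])).all()
```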