```python
def fn(x, y):
    return np.multiply(x, y).sum()
```

Then, TorchDynamo will cast `x` and `y` to our internal implementation of `ndarray`,
and will dispatch `np.multiply` and `sum` to our implementations in terms of `torch`
functions, effectively turning this function into a pure PyTorch function.
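
As a rough sketch of what this amounts to (`fn_traced` is a hypothetical name; this is
not the code TorchDynamo actually emits), the traced function behaves like this
pure-PyTorch equivalent:

```python
import torch

# Hypothetical pure-PyTorch equivalent of fn after tracing: the ndarray
# inputs are backed by torch.Tensors, and np.multiply / sum dispatch to
# their torch counterparts.
def fn_traced(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return torch.mul(x, y).sum()
```
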
### Design decisions
The main ideas driving the design of this compatibility layer are the following:

**Default dtypes**. A common issue when migrating codebases from NumPy to JAX is the
default dtype changing from `float64` to `float32`; this is one of
[JAX's sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision).
Following the spirit of making everything match NumPy by default, we choose the
NumPy defaults whenever the `dtype` was not made explicit in a factory function.
We also provide a function `set_default_dtype` that allows changing this behavior
dynamically.
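For illustration, here is a hedged sketch of the intended behavior; the import name
`tnp` and the exact signature of `set_default_dtype` are assumptions:

```python
import torch_np as tnp  # assumed import name for this layer

a = tnp.ones(3)
print(a.dtype)  # float64, matching NumPy's default

# Assumption: set_default_dtype takes the dtype to use from then on.
tnp.set_default_dtype(tnp.float32)
b = tnp.ones(3)
print(b.dtype)  # now float32
```
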
**NumPy scalars**. NumPy's type system is tricky. At first sight, it looks
like PyTorch's, but with a few more dtypes like `np.uint16` or `np.longdouble`.

NumPy scalars do not play particularly well with
computations on devices like GPUs, as they live on CPU. Implementing NumPy
scalars would mean that we need to synchronize after every `sum()` call, which
would be terrible performance-wise. In this implementation, we choose to represent
NumPy scalars as 0-D arrays. This may cause small divergences in some cases. For example,
consider the following NumPy behavior:

```python
>>> np.int32(2) * [1, 2, 3]    # scalar decays to a python int
[1, 2, 3, 1, 2, 3]
>>> np.asarray(2) * [1, 2, 3]  # 0-D array behaves as an array
array([2, 4, 6])
```

We don't expect these to pose a big issue in practice. Note that in the
proposed implementation `np.int32(2)` would return the same as `np.asarray(2)`.
In general, we try to avoid unnecessary graph breaks whenever we can. For
example, we may choose to return a tensor of shape `(2, *)` rather than a list
of arrays.
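
A minimal sketch of the 0-D representation described above (assuming the layer is
importable as `tnp`; the exact attribute names are assumptions):

```python
import torch_np as tnp  # assumed import name

# Under the proposed representation, a "NumPy scalar" is just a 0-D array,
# so both of these produce equivalent objects.
s = tnp.int32(2)
a = tnp.asarray(2)
print(s.shape, a.shape)  # () ()
```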

NumPy's type promotion can depend on the runtime values involved:

```python
>>> np.asarray([1], dtype=np.int8) + 128
array([129], dtype=int16)
```

This data-dependent type promotion will be deprecated in NumPy 2.0, and will be
replaced with [NEP 50](https://numpy.org/neps/nep-0050-scalar-promotion.html).
For simplicity and to be forward-looking, we chose to implement the
type promotion behaviour proposed in NEP 50, which is much closer to that of
PyTorch.
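
As a quick illustration of the direction NEP 50 moves in, this snippet exercises
stock PyTorch (not this compatibility layer):

```python
import torch

# Under PyTorch-style promotion, a Python scalar adopts the tensor's dtype
# rather than upcasting it based on its value.
t = torch.tensor([1], dtype=torch.int8)
print((t + 100).dtype)  # torch.int8, not int16
```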

Inputs are normalized (mapping NumPy objects to PyTorch counterparts) according to
their annotations; a minimal sketch of this mechanism follows the list.
We currently have four annotations (and small variations of them):

- `ArrayLike`: The input can be a `torch_np.array`, a list of lists, a
  scalar, or anything that NumPy would accept. It returns a `torch.Tensor`.
- `DTypeLike`: Takes a `torch_np` dtype, or any other object that NumPy dtypes
  accept (strings, typecodes...), and returns a PyTorch dtype.
- `AxisLike`: Takes anything that can be accepted as an axis (e.g. a tuple or
  an `ndarray`) and returns a tuple.
- `OutArray`: Asserts that the input is a `torch_np.ndarray`. This is used
  to implement the `out` keyword argument.
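
The sketch below illustrates the idea only; `normalizer` and the string-valued
`ArrayLike` stand-in are hypothetical, not the library's actual implementation:

```python
import inspect

import torch

ArrayLike = "ArrayLike"  # stand-in for the real annotation type


def normalizer(func):
    # Convert arguments annotated as ArrayLike into torch.Tensors
    # before calling the torch-based implementation.
    sig = inspect.signature(func)

    def wrapper(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)
        for name, value in bound.arguments.items():
            if sig.parameters[name].annotation == ArrayLike:
                bound.arguments[name] = torch.as_tensor(value)
        return func(*bound.args, **bound.kwargs)

    return wrapper


@normalizer
def multiply(x: ArrayLike, y: ArrayLike):
    return torch.mul(x, y)


print(multiply([1, 2, 3], 2))  # tensor([2, 4, 6])
```
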
### Testing
The testing of the framework was done via ~~copying~~ vendoring tests from the
NumPy test suite. Then, we would replace the NumPy imports with `torch_np`
imports. The failures on these tests were then triaged, and we discussed the
priority of fixing each of them.
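
For instance, a vendored test file would change only its import line; a schematic
example (not an actual test from the suite):

```python
# import numpy as np         # what the vendored NumPy test originally used
import torch_np as np        # the only change made when vendoring

def test_multiply_sum():
    x = np.arange(6).reshape(2, 3)
    assert np.multiply(x, x).sum() == 55  # 0 + 1 + 4 + 9 + 16 + 25
```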

In the end, to have a last check that the tool was sound, we pulled five
examples of NumPy code from different sources and ran them with this library.
We were able to run all five examples successfully with close to no code changes.
You can read about these in the [README](https://github.com/Quansight-Labs/numpy_pytorch_interop).
### Limitations
A number of known limitations are tracked in the second part of the
[OP of this issue](https://github.com/Quansight-Labs/numpy_pytorch_interop/issues/73).
When landing this RFC, we will create a comprehensive document with the differences
between NumPy and `torch_np`.