# A PyTorch - NumPy compatibility layer

**Authors:**
* @ev-br
* @lezcano
* @rgommers

## Summary
This RFC describes a proposal for a translation layer from NumPy into PyTorch.
In simple terms, this amounts to implementing most of NumPy's API (`ndarray`,
the `numpy`, `numpy.linalg`, `numpy.fft` modules, etc.) using `torch.Tensor`
and PyTorch ops as a backend.

The main goal is: **make TorchDynamo understand NumPy calls**.
This should enable an end user to combine code that uses the PyTorch API with
code that uses the NumPy API, in a way that allows TorchDynamo to understand
those function calls and build up an execution graph. To enable this, it is key
that there is a translation layer from NumPy to PyTorch function calls, which
TorchDynamo can use to build up its execution graph from PyTorch
functions/primitives only. For niche functions in NumPy that don't have a
PyTorch equivalent, it's okay to graph break and still call NumPy to execute
the function call.

The work is currently being done at [numpy_pytorch_interop](https://github.com/Quansight-Labs/numpy_pytorch_interop/).

## Motivation

### Introductory examples

Consider the following snippet:
```python
import numpy as np

x = np.random.randn(3, 4)
y = np.random.randn(4, 3)
z = np.dot(x, y)
w = z.sum()
```

When we trace this program with the compat layer, the semantics of the
program would stay the same, but the implementation would be equivalent to

```python
import torch
x = torch.randn(3, 4, dtype=torch.float64)
y = torch.randn(4, 3, dtype=torch.float64)
z = torch.matmul(x, y)
w = z.sum()
```

Here, we can already spot a couple of differences between NumPy and PyTorch.
The most obvious one is that the default dtype in NumPy is `float64` rather than
`float32`. The less obvious one is hiding in the last line:

```python
>>> type(w)
<class 'numpy.float64'>
```

Reductions and similar operations in NumPy return the infamous NumPy scalars.
We'll discuss these and other NumPy quirks, and how we dealt with them, in the
[design decisions section](#design-decisions).


Let's now have a look at a toy example of how this layer would be used.
```python
import torch
import numpy as np
t1 = torch.tensor([1, 3, 5])
t2 = torch.exp(t1)
# Now say the user has some code lying around which uses NumPy:
def fn(x, y):
    return np.multiply(x, y).sum()

result = fn(t1, t2)
t_results = torch.empty(5, dtype=torch.float64)
t_results[0] = result  # store the result in a torch.Tensor
```

Note that this code mixing NumPy and PyTorch already works in eager mode with
CPU tensors, as `torch.Tensor` implements the `__array__` method. The
compatibility layer now allows us to trace through it. To do so, no changes
are needed other than simply asking `torch.compile` to trace through the
function:

```python
@torch.compile
def fn(x, y):
    return np.multiply(x, y).sum()
```

Then, TorchDynamo will cast `x` and `y` to our internal implementation of `ndarray`,
and will dispatch `np.multiply` and `sum` to our implementations in terms of `torch`
functions, effectively turning this function into a pure PyTorch function.

### Design decisions

The main ideas driving the design of this compatibility layer are the following:

1. The goal is to transform valid NumPy and mixed PyTorch-NumPy programs into
   their equivalent PyTorch-only execution.
2. The behavior of the layer should be as close to that of NumPy as possible.
3. The layer follows the most recent NumPy release.

The following design decisions follow from these:

**A superset of NumPy**. NumPy has a number of well-known edge cases (as does
PyTorch, like spotty support for `float16` on CPU and `complex32` in general).
The decision to translate only valid NumPy programs often allows us to
implement a superset of the functionality of NumPy, with more predictable and
consistent behavior than NumPy itself has.

**Exceptions may be different**. We avoid modeling NumPy's exception
system entirely. As seen in the implementation of PrimTorch, modeling the
error cases of a given system is terribly difficult. We avoid this altogether
and choose not to offer any guarantees here.

**Default dtypes**. One of the most common issues that bites people when migrating their
codebases from NumPy to JAX is the change of the default dtype from `float64` to
`float32`. So much so that it is noted as one of
[JAX's sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision).
Following the spirit of making everything match NumPy by default, we choose the
NumPy default dtype whenever the `dtype` was not made explicit in a factory function.
We also provide a function `set_default_dtype` that allows changing this behavior
dynamically.
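
For illustration, here is a minimal sketch of the intended behavior. The import
name `tnp` and the exact signature of `set_default_dtype` are assumptions for
the example:

```python
import torch
import torch_np as tnp  # assumed import name for the compat layer

x = tnp.ones(3)
print(x.dtype)  # float64: NumPy's default, not PyTorch's float32

# hypothetical call: switch factory functions to PyTorch-style defaults
tnp.set_default_dtype(torch.float32)
print(tnp.ones(3).dtype)  # now float32
```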

**NumPy scalars**. NumPy's type system is tricky. At first sight, it looks
like PyTorch's, but with a few more dtypes like `np.uint16` or `np.longdouble`.
Upon closer inspection, one finds that it also has
[NumPy scalar](https://numpy.org/doc/stable/reference/arrays.scalars.html) objects.
NumPy scalars are similar to Python scalars, but with a fixed precision and
array-like methods attached. NumPy scalars are NumPy's preferred return class
for reductions and other operations that return just one element.
NumPy scalars do not play particularly well with
computations on devices like GPUs, as they live on CPU. Implementing NumPy
scalars would mean that we need to synchronize after every `sum()` call, which
would be terrible performance-wise. In this implementation, we choose to represent
NumPy scalars as 0-D arrays. This may cause small divergences in some cases. For example,
consider the following NumPy behavior:

```python
>>> np.int32(2) * [1, 2, 3]     # scalar decays to a python int
[1, 2, 3, 1, 2, 3]

>>> np.asarray(2) * [1, 2, 3]   # zero-dim array is an array-like
array([2, 4, 6])
```

We don't expect these to pose a big issue in practice. Note that in the
proposed implementation `np.int32(2)` would return the same as `np.asarray(2)`.
In general, we try to avoid unnecessary graph breaks whenever we can. For
example, we may choose to return a tensor of shape `(2, *)` rather than a list
of pairs, to avoid a graph break.
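
Concretely, under this design the first expression above would follow the 0-D
array semantics. A sketch of the expected behavior, again using the assumed
`tnp` import name:

```python
import torch_np as tnp  # assumed import name

w = tnp.int32(2)      # represented internally as a 0-D array, not a scalar
print(w * [1, 2, 3])  # array([2, 4, 6]): array semantics, not list repetition
```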

**Type promotion**. Another not-so-well-known fact about NumPy's dtype system and casting rules
is that they are data-dependent. Python scalars can be used in pretty much any NumPy
operation: any operation that accepts a 0-D array also accepts a
Python scalar. If you pass a Python scalar to an operation, it will be
cast to the smallest dtype in which it can be represented, and only then will it
participate in type promotion. This allows for some rather interesting behavior:
```python
>>> np.asarray([1], dtype=np.int8) + 127  # 127 fits in int8: stays int8, overflow wraps
array([-128], dtype=int8)
>>> np.asarray([1], dtype=np.int8) + 128  # 128 needs int16: promotes to int16
array([129], dtype=int16)
```
This data-dependent type promotion will be removed in NumPy 2.0 (planned for Dec'23) and
replaced with [NEP 50](https://numpy.org/neps/nep-0050-scalar-promotion.html)
(already implemented in NumPy, currently gated behind a private global switch).
For simplicity and to be forward-looking, we chose to implement the
type promotion behavior proposed in NEP 50, which is much closer to that of
PyTorch.

Note that the decision to go with NEP 50 complements the previous one of
returning 0-D arrays in place of NumPy scalars as, currently, 0-D arrays do not
participate in type promotion in NumPy (but will in NumPy 2.0 under NEP 50):
```python
int64_0d_array = np.array(1, dtype=np.int64)
np.result_type(np.int8, int64_0d_array) == np.int8
```
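
Under NEP 50, 0-D arrays participate in promotion like any other array, so the
same expression yields the higher dtype. A sketch of the NEP 50 outcome:

```python
# With NEP 50 semantics (and hence in this compat layer):
int64_0d_array = np.array(1, dtype=np.int64)
np.result_type(np.int8, int64_0d_array) == np.int64
```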

**Versioning**. It should be clear from the previous points that NumPy has a
fair amount of questionable behavior and legacy pain points. It is for this reason that
we decided that, rather than fighting these, we would declare that the compat
layer follows the behavior of NumPy's most recent release (even, in some cases,
of NumPy 2.0). Given the stability of NumPy's API and how battle-tested its
main functions are, we do not expect this to become a big maintenance burden.
If anything, it should make our lives easier, as some parts of NumPy will soon
be simplified, saving us the pain of having to implement all the pre-existing
corner cases.

**Randomness**. PyTorch and NumPy use different random number generation methods.
In particular, NumPy recently moved to a [new API](https://numpy.org/doc/stable/reference/random/index.html)
with a `Generator` object which has sampling methods on it. The current compat
layer does not implement this new API, as the default bit generator in NumPy is
`PCG64`, while on PyTorch we use `MT19937` on CPU and `Philox` on non-CPU devices.
From this, it follows that this API will not give any reproducibility
guarantees when it comes to randomness.


## The `torch_np` module

The bulk of the work went into implementing a system that allows us to
implement NumPy operations in terms of those of PyTorch. The main design goals
here were:

1. Implement *most* of NumPy's API
2. Preserve NumPy semantics as much as possible

We say *most* of NumPy's API because NumPy's API is not only massive, but
there are also parts of it which cannot be implemented in PyTorch. For example,
NumPy has support for arrays of strings, datetimes, structured and other dtypes.
Negative strides are another example of a feature that is not supported in PyTorch.
We put together a list of things that are out of the scope of this project in the
[following issue](https://github.com/Quansight-Labs/numpy_pytorch_interop/issues/73).

For the bulk of the functions, we started by prioritizing the most common
operations. Then, when bringing tests over from the NumPy test suite, we triaged
and prioritized how important it was to fix each failure we found. Doing this
iteratively, we ended up with a small list of differences between the NumPy and
PyTorch APIs, which we prioritized by hand. That list and the prioritization
discussion can be found in [this issue](https://github.com/Quansight-Labs/numpy_pytorch_interop/issues/87).

**Visibility of the module**. For simplicity, this RFC assumes that the
`torch_np` module will not be public, as the initial suggestion for it to be
made public was met with mixed opinions. This topic can be revisited in the
future if desired; we discuss it further in the
[unresolved questions](#unresolved-questions) section.

### Annotation-based preprocessing

NumPy accepts virtually anything that smells like an array as an input.

```python
>>> np.add(1, 3)
4
>>> np.add([1., 2., 3.], 5)
array([6., 7., 8.])
>>> np.concatenate([[1, 2, 3], [4, 5, 6]])
array([1, 2, 3, 4, 5, 6])
```

NumPy calls all these objects `array_like` objects.
To implement NumPy in terms of PyTorch, for every operation we would need to map
inputs into tensors, perform the operation, and then wrap the tensor into
a `torch_np.ndarray` (more on this class later).

To avoid all this code repetition, we implement the functions in two steps.

First, we implement functions with the NumPy signature, but assuming that in
place of NumPy-land elements (`np.ndarray`s, array-like objects, `np.dtype`s, etc.)
they simply accept `torch.Tensor`s and PyTorch-land objects, and return
`torch.Tensor`s. For example, we would implement `np.diag` as

```python
def diag(v, k=0):
    return torch.diag(v, k)
```

In this layer, if a NumPy function is composite (calls other NumPy functions
internally), we can simply vendor its implementation and have it call our
PyTorch-land implementations of those functions. In other words, at this level,
functions are composable, as they are simply regular PyTorch functions.
All these implementations are internal, and are not meant to be seen or used
by the end user.
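
For instance, a composite function like `np.vstack` could be vendored almost
verbatim, with its internal calls going to our PyTorch-land helpers. In the
sketch below, `atleast_2d` refers to our internal tensor-level implementation,
not NumPy's:

```python
import torch

def vstack(tup):
    # vendored from NumPy, but atleast_2d is our tensor-level version,
    # so everything here operates on torch.Tensors
    arrs = [atleast_2d(arr) for arr in tup]
    return torch.cat(arrs, 0)
```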

The second step is then done via type annotations and a decorator. Each type
annotation has an associated function from NumPy-land into PyTorch-land. This
function converts the set of inputs accepted by NumPy for that argument into a
PyTorch-land object (think a `torch.Tensor` or a PyTorch dtype). For example,
for `np.diag` we would write

```python
def diag(v: ArrayLike, k=0):
    return torch.diag(v, k)
```

Then, we wrap these Python-land functions with a `normalizer` decorator and
expose them in the `torch_np` module. This decorator is in charge of gathering
all the inputs at runtime and normalizing them (i.e., converting `torch_np`
objects to PyTorch counterparts) according to their annotations.
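
A simplified sketch of what such a decorator could look like. All names other
than `normalizer` are hypothetical, and the real implementation also deals with
`*args`, `out=`, and other details:

```python
import functools
import inspect
import typing

def normalizer(func):
    """Convert NumPy-land arguments into PyTorch-land ones, per the annotations."""
    sig = inspect.signature(func)
    hints = typing.get_type_hints(func)

    @functools.wraps(func)
    def wrapped(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)
        bound.apply_defaults()
        for name, value in bound.arguments.items():
            if name in hints:
                # each annotation (ArrayLike, DTypeLike, ...) has an associated
                # conversion; `normalize` dispatches on the annotation type
                bound.arguments[name] = normalize(value, hints[name])
        result = func(*bound.args, **bound.kwargs)
        # wrap the resulting torch.Tensor(s) back into torch_np.ndarray(s)
        return wrap_tensors(result)

    return wrapped
```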

We currently have four annotations (and small variations of them):
- `ArrayLike`: The input can be a `torch_np.ndarray`, a list of lists, a
  scalar, or anything that NumPy would accept. It returns a `torch.Tensor`
  (see the sketch after this list).
- `DTypeLike`: Takes a `torch_np` dtype, or any other object that NumPy dtypes
  accept (strings, typecodes, ...), and returns a PyTorch dtype.
- `AxisLike`: Takes anything that can be accepted as an axis (e.g. a tuple or
  an `ndarray`) and returns a tuple.
- `OutArray`: Asserts that the input is a `torch_np.ndarray`. This is used
  to implement the `out` keyword.
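
As an illustration, the conversion behind `ArrayLike` could be as simple as the
following sketch, where `ndarray` is `torch_np.ndarray` and its `.tensor`
attribute is an assumed implementation detail:

```python
import torch

def normalize_array_like(x):
    if isinstance(x, ndarray):
        return x.tensor    # unwrap the torch.Tensor held internally
    # lists of lists, Python scalars, objects implementing __array__, etc.
    return torch.as_tensor(x)
```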

Note that none of the code in this implementation makes use of NumPy. We
write `torch_np.ndarray` above to make our intent more explicit, but there
shouldn't be any ambiguity.

**Implementing `out`**: In PyTorch, the `out` kwarg is a keyword-only argument.
It is for this reason that, in PrimTorch, we were able to implement it as [a
decorator](https://github.com/pytorch/pytorch/blob/ce4df4cc596aa10534ac6d54912f960238264dfd/torch/_prims_common/wrappers.py#L187-L282).
This is not the case in NumPy, where `out` can be used both as a positional
and a keyword argument, and is often interleaved with other parameters. This is
why we use the `OutArray` annotation to mark these arguments. We then
implement the `out` semantics in the `@normalizer` wrapper in a generic way.
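
Conceptually, that generic handling might look like the sketch below, under the
same assumed `.tensor` attribute; the real logic also validates shapes and
broadcasting:

```python
def copy_to_out(result, out):
    # `result` is the torch.Tensor computed by the inner function;
    # `out`, if given, is the torch_np.ndarray the user passed in
    if out is None:
        return result
    out.tensor.copy_(result)  # copy_ casts to out's dtype, as NumPy does
    return out
```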

**Ufuncs and reductions**: Ufuncs (unary and binary) and reductions are two
sets of functions that are particularly regular. For these functions, we
implement support for their arguments in a generic way as a preprocessing or
postprocessing step.
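
For example, the shared preprocessing for reductions could normalize the `axis`
argument once for all of them. A sketch, where `torch_red` stands in for a
torch reduction that accepts tuple dims, such as `torch.sum` or `torch.amax`:

```python
def run_reduction(torch_red, t, axis=None, keepdims=False):
    # NumPy reduces over all axes when axis=None, and accepts both a single
    # int and a tuple; normalize once so every reduction shares this logic
    if axis is None:
        axis = tuple(range(t.ndim))
    elif not isinstance(axis, tuple):
        axis = (axis,)
    return torch_red(t, dim=axis, keepdim=keepdims)
```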

**The `ndarray` class**. Once we have all the free functions implemented as
functions from `torch_np.ndarray`s to `torch_np.ndarray`s, implementing the
methods of the `ndarray` class is rather simple. We simply register all the
free functions as methods or dunder methods as appropriate. We also forward the
properties of `ndarray` to the corresponding properties of `torch.Tensor`, and we
are done. This creates a circular dependency, which we break with a local
import.
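
Schematically, this could look as follows; the module layout and the `.tensor`
attribute are, again, assumptions for illustration:

```python
class ndarray:
    """Wraps a torch.Tensor and re-exposes the free functions as methods."""

    def __init__(self, tensor):
        self.tensor = tensor

    @property
    def shape(self):
        # properties forward to the underlying torch.Tensor
        return tuple(self.tensor.shape)

    def sum(self, axis=None):
        # local import of the free function breaks the circular dependency
        from . import _funcs
        return _funcs.sum(self, axis=axis)

    def __mul__(self, other):
        from . import _ufuncs
        return _ufuncs.multiply(self, other)
```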

### Testing

The testing of the framework was done by ~~copying~~ vendoring tests from the
NumPy test suite and replacing the NumPy imports with `torch_np`
imports. The failures in these tests were then triaged, and either fixed or marked
`xfail` depending on our assessment of the priority of implementing a fix.

In the end, as a final check that this tool is sound, we pulled five
examples of NumPy code from different sources and ran them with this library (eager mode execution).
We were able to run all five examples successfully with close to no code changes.
You can read about these in the [README](https://github.com/Quansight-Labs/numpy_pytorch_interop).

### Limitations

A number of known limitations are tracked in the second part of the
[OP of this issue](https://github.com/Quansight-Labs/numpy_pytorch_interop/issues/73).
When landing this RFC, we will create a comprehensive document with the differences
between NumPy and `torch_np`.

### Beyond plain NumPy

**GPU**. The current implementation has so far only been tested on
CPU. We expect GPU coverage to be as good as the coverage we have with CPU-GPU
matching tests in the PyTorch test suite. If the original tensors are on GPU,
the execution should be performed fully on GPU.

**Gradients**. We have not tested gradient tracking either, as we have yet to
find good examples on which to test it, but it should be a simple
corollary of all this effort. If the original tensors fed into a function
have `requires_grad=True`, the tensors will track the gradients of the internal
implementation, and the user can then differentiate through their NumPy code.
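
For instance, reusing the compiled `fn` from the introductory example,
differentiating through NumPy code might look like the following sketch of the
expected behavior (not yet validated):

```python
import torch

x = torch.randn(3, requires_grad=True)
y = torch.randn(3)

out = fn(x, y)  # np.multiply(x, y).sum(), traced into pure torch ops
out.backward()  # gradients flow through the torch-level implementation

print(x.grad)   # equals y, since d/dx of sum(x * y) is y
```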

### Bindings to TorchDynamo

The bindings for NumPy at the TorchDynamo level are currently being developed in
[pytorch#95849](https://github.com/pytorch/pytorch/pull/95849).


## Unresolved questions

A question was left open in the initial discussion: should the module
`torch_np` be publicly exposed as `torch.numpy` or not?

A few arguments in favor of making it public:
* People could use it in their NumPy programs just by changing the import to
  `import torch.numpy as np`. This could be a selling point similar to JAX's
  `jax.numpy`, which could incentivize adoption.
* People would not need to use the whole PyTorch 2.0 stack to start using
  PyTorch in their codebases.
  * See [this experiment in scikit-learn](https://github.com/scikit-learn/scikit-learn/pull/25956)
    where they got a 7x speed-up on CPU on a layer just by using `torch.linalg`.
* Since the layer is rather thin and in pure Python, if there are bugs,
  external contributors could easily help fix them or extend the supported
  functionality.

A few arguments against:
* The compat layer introduces a number of type conversions that may produce somewhat
  slow code when used in eager mode.
  * [Note] Keeping this in mind, we tried to use as few operators as possible
    in the implementation, to make it reasonably fast in eager mode.
* Exposing `torch.numpy` would create a less performant secondary entry point
  to many of the functions in PyTorch. This could be a trap for new users.