From 4e9d23075cf6c653f662a4019d70be5e9c42170a Mon Sep 17 00:00:00 2001 From: lezcano Date: Tue, 11 Apr 2023 11:37:48 +0000 Subject: [PATCH 01/12] First RFC draft --- RFC.md | 182 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 182 insertions(+) create mode 100644 RFC.md diff --git a/RFC.md b/RFC.md new file mode 100644 index 00000000..8d96f31a --- /dev/null +++ b/RFC.md @@ -0,0 +1,182 @@ +# Summary +This RFC describes a proposal for a translation layer from NumPy into PyTorch. In simple terms, this accounts for implementing most of NumPy's API (`ndarray`, the `np`, `np.linalg`, `np.fft` modules, etc) using `torch.Tensor` and PyTorch ops as backend. + +The this project has two main goals: +1. Have a `torch.numpy` submodule, similar to `jax.numpy` that serves as a drop-in replacement for NumPy when imported as `import torch.numpy as np`. +2. Have TorchDynamo understand and use this layer to be able to trace through NumPy programs as if they were written in PyTorch + +Two corollaries of this work should be: +1. Given NumPy code, one should be able to differentiate through it using PyTorch's autograd engine +2. Given NumPy code, one should be able to execute it on CUDA + +The work is being done at [numpy_pytorch_interop](https://github.com/Quansight-Labs/numpy_pytorch_interop/). + +# The Translation Layer +In this section we discuss the ideas behind design and implementation of the translation layer from PyTorch to NumPy + +## Two examples of expected usage +Let's start with some examples. + +Consider the following snippet: +```python +import numpy as np + +x = np.random.randn(3, 4) +y = np.random.randn(4, 3) +z = np.dot(x, y) +w = z.sum() +``` + +By changing the first line to `import torch.numpy as np`, the semantics of the program would stay the same, but the implementation would be equivalent to + +```python +import torch +x = torch.randn(3, 4, dtype=torch.float64) +y = torch.randn(4, 3, dtype=torch.float64) +z = torch.matmul(x, y) +w = z.sum() +``` + +Here we already see a couple differences between NumPy and PyTorch. The most obvious one is that the default dtype in NumPy is `float64` rather than `float32`. The less obvious is very sneakily hiding in the last line. + +```python +>>> type(w) + +``` + +Reductions and similar operations in NumPy return the infamous NumPy scalars. We'll discuss these and other NumPy quirks and how we dealt with them in the sequel. + +As expected, this layer also allows for combining NumPy code and PyTorch code. + +```python +import torch +import numpy as np +t1 = torch.tensor([1, 3, 5]) +t2 = torch.exp(t) +# Now say the user has some code lying around which uses NumPy: +def fn(x, y): + return np.multiply(x, y).sum() + +result = fn(t1, t2) +t_results = torch.empty(5, dtype=torch.float64) +t_results[0] = result # store the result in a torch.Tensor +``` + +This code mixing NumPy and PyTorch already works, as `torch.Tensor` implements the `__array__` method. For it to work with the compatibility layer, we would need to wrap and unwrap the inputs / outputs. This could be done modifying `fn` as + +```python +def fn(x, y): + x = np.asarray(x) + y = np.asarray(y) + ret = np.multiply(x, y).sum() + return ret.tensor.numpy() +``` + +Note that this wrapping / unwrapping process can be easily automated via a decorator. +Even more, if a user wants to use PyTorch as a backend in a code that mixes PyTorch and NumPy, it will mostly be the case that it is because they want to trace through that code. In that setting, TorchDynamo will be able to automatically do the wrapping/unwrapping. + +## The `torch.numpy` module +The bulk of the work went into implementing a system that allows us to implement NumPy operations in terms of those of PyTorch. The main design goals were + +1. Implement *most* of NumPy's API +2. Preserve NumPy semantics as much as possible + +We say *most* of NumPy's API, because NumPy's API is not only massive, but also there are parts of it which cannot be implemented in PyTorch. For example, NumPy has support for arrays of strings, dates, and other `dtype`s that PyTorch does not consider. Negative strides are other example. We put together a list of things that are out of the scope of this project in the [following issue](https://github.com/Quansight-Labs/numpy_pytorch_interop/issues/73). + +For the bulk of the functions, we started by prioritizing most common operations. Then, when bringing tests from the NumPy test suit and running them, we would triage and prioritize how important was to fix each failure we found. Iterating this process, we ended up with a small list of differences between the NumPy and the PyTorch API which we sorted out by hand and finished implementing. That list and the prioritization discussion can be found in [the first few posts of this issue](https://github.com/Quansight-Labs/numpy_pytorch_interop/issues/87). + +The second point of preserving NumPy semantics as much as possible will be used in the sequel to discuss some points like the default dtypes that are used throughout the implementation. + +### Annotation-based preprocessing +NumPy accepts virtually anything that smells like an array as input to its operators +```python +>>> np.add(1, 3) +4 +>>> np.add([1., 2., 3.], 5) +array([6., 7., 8.]) +>>> np.concatenate([[1, 2, 3], [4, 5, 6]]) +array([1, 2, 3, 4, 5, 6]) +``` + +To implement NumPy in terms of PyTorch, for any operation we would need to put the inputs into tensors, perform the operations, and then wrap the tensor into a `torch.numpy.ndarray` (more on this class later). + +To avoid all this code repetition, we implement the functions in two steps. + +First, we implement functions with the NumPy signature, but assuming that in place of NumPy-land elements (`np.array`, array-like functions, `np.dtype`s, etc) they simply accept `torch.Tensor` and PyTorch-land objects and return `torch.Tensor`s. For example, we would implement `np.diag` as +```python +def diag(v, k=0): + return torch.diag(v, k) +``` +In this layer, if a NumPy function is composite (calls other NumPy functions internally), we can simply vendor its implementation, and have it call our PyTorch-land implementations of these functions. In other words, at this level, functions are composable, as any set of functions implemented purely in PyTorch. All these implementations are internal, and are not meant to be seen or used by the final user. + +The second step is then done via type annotations and a decorator. Each type annotation has then a map NumPy-land -> PyTorch-land associated, that maps the set of inputs accepted by NumPy for that argument into a PyTorch-land object (think a `torch.Tensor` or a PyTorch dtype). For example, for `np.diag` we would write +```python +def diag(v: ArrayLike, k=0): + return torch.diag(v, k) +``` + +Then, we would wrap these Python-land functions in a `normalizer` decorator and expose them in the public `torch.np` module. This decorator is in charge of gathering all the inputs at runtime and normalizing them according to their annotations. + +We currently have four annotations (and small variations of them): +- `ArrayLike`: The input can be a `torch.numpy.array`, a list of lists, a scalar, or anything that NumPy would accept. It returns a `torch.Tensor`. +- `DTypeLike`: Takes a `torch.numpy` dtype and returns a PyTorch dtype. +- `AxisLike`: Takes anything that can be accepted as an axis (e.g. a tuple or an `ndarray`) and returns a tuple. +- `OutArray`: Asserts that the input is a `torch.numpy.ndarray`. This is used to implement the `out` arg. + +Note that none of the code here makes use of NumPy. We are writing `torch.numpy.ndarray` above to make more explicit our intents, but there shouldn't be any ambiguity here. + +**OBS(Lezcano)**: `DTypeLike` should be `Optional[DTypeLike]` +**OBS(Lezcano)**: Should we have a `NotImplementedType` to mark the args that are not being implemented? We could then assert that either that parameter has not been provided, and if it has, it has the same value as the default. The goal here would be to either use all the args of a function in its implementation, or mark explicitly those that we don't use. + +**Implmenting out**: In PyTorch, the `out` kwarg is, as the name says, a keyword-only argument. It is for this reason that, in PrimTorch, we were able to implement it as [a decorator](https://github.com/pytorch/pytorch/blob/ce4df4cc596aa10534ac6d54912f960238264dfd/torch/_prims_common/wrappers.py#L187-L282). This is not the case in NumPy. In NumPy `out` is a positional arg that is often interleaved with other parameters. This is the reason why we use the `OutArray` label to mark these. We then implement the `out` semantics in the `@normalizer` wrapper in a generic way. + +**Ufuncs and reductions**: Ufuncs (unary and binary) and reductions are two sets of functions that are particularly regular. For these functions, we implement (some of) their args in a generic way. We then simply forward the computations to PyTorch, perhaps working around some PyTorch limitations. + +### The `ndarray` class +Once we have all the free functions implemented, implementing an `ndarray` class is rather simple. We simply register all the free functions as methods or dunder methods apropriately. We also forward the properties to the properties within the PyTorch tensor and we are done. + +### DTypes +**Default dtypes**. One of the issues that most often user when moving their codebase from NumPy to JAX was the default dtype changing from `float64` to `float32`. So much so, that this is one noted as one of [JAX's shap edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision). Following the spirit of making everything match NumPy by default, we choose the NumPy defaults whenever the `dtype` was not chosen in a factory function. + +**TODO(Lezcano)**: I just realised that we do not have a clean way to change the default dtype of `torch.numpy` to those from PyTorch. We should implement that utility flag, similar to [`torch.set_default_dtype`](https://pytorch.org/docs/stable/generated/torch.set_default_dtype.html). Perhaps call it `torch.numpy.use_torch_defaults()` and then add a way for users to be able to set their own int/float/complex defaults. +**TODO(Lezcano)**: Do we just use them just in factory functions, or do we also use them anywhere else -> Check + +**NumPy scalars**. NumPy's type system is tricky. At first sight, it looks quite a bit like PyTorch's, but having a few more dtypes like `np.uint16` or `np.longdouble`. Upon closer inspection, one finds that it also has [NumPy scalar](https://numpy.org/doc/stable/reference/arrays.scalars.html) objects. NumPy scalars are similar to Python scalars but with a set width. NumPy scalars are NumPy's preferred return class for reductions and other operations that return just one element. NumPy scalars do not play particularly well with computations on devices like GPUs, as they live on CPU. Implementing NumPy scalars would mean that we need to synchronize after every `sum()` call, which is less-than-good. Instead, whenever a NumPy scalar would be returned, we will return a 0-D tensor, as PyTorch does. + +**Type promotion**. Another not-so-well-known fact of NumPy's cast system is that it is data-dependent. Python scalars can be used in pretty much any NumPy operation, being able to call any operation that accepts a 0-D array with a Python scalar. If you provide an operation with a Python scalar, these will be casted to the smallest dtype that can represent them, and then they will participate in type promotion, allowing for some rather intersting behaviour +``` +>>> np.asarray([1], dtype=np.int8) + 127 +array([128], dtype=int8) +>>> np.asarray([1], dtype=np.int8) + 128 +array([129], dtype=int16) +``` +This dependent type promotion will be deprecated NumPy 2.0, and will be replaced with [NEP 50](https://numpy.org/neps/nep-0050-scalar-promotion.html). As such, to be forward-looking and for simplicity, we chose to implement the type promotion behaviour proposed in NEP 50, which is much closer to that of Pytorch. + +Note that the decision of going with NEP 50 complements the previous one of returning 0-D arrays in place of NumPy scalars as, currently, 0-D arrays do not participate in type promotion in NumPy (but will do in NumPy 2.0): +``` +int64_0d_array = np.array(1, dtype=np.int64) +np.result_type(np.int8, int64_0d_array) == np.int8 +``` + +## Testing +The testing of the framework was done via ~~copying~~ vendoring tests from the NumPy test suit. Then, we would replace the NumPy imports for imports with `torch.numpy`. The failures on these tests were then triaged and discussed the priority of fixing each of them. + +In the (near) future, we plan to get some real world examples and run them through the library, to test its coverage and correctness. + +## Limitations +One of the known limitations of this approach is the efficiency in eager. Similar to PrimTorch, sometimes we needed to work around some limitations of PyTorch (e.g. support for some operations for `float16`) or some ways PyTorch deviates from NumPy by implementing things manually calling several `torch` operations. This, when executed in eager mode and, in particular, on CUDA devices, will result on a perf-hit. To alleviate this, we tried to dispatch NumPy functions to PyTorch functions with as few indirections as possible, to alleviate the number of kernels called when executed on eager mode. + +There are some known limitations. Some of them are tracked in the second part of the [OP of this issue](https://github.com/Quansight-Labs/numpy_pytorch_interop/issues/73). There are some more in [this issue](https://github.com/Quansight-Labs/numpy_pytorch_interop/issues/86). When landing all this, we will create a comprehensive document with the differences between NumPy and `torch.numpy`. + +## Beyond NumPy +**CUDA**. The current implementation has just been implemented and tested on CPU. We expect CUDA coverage to be as good as the coverage we have with CPU matching CUDA. In the NumPy-only example in the introduction, given that no explicit `device` kwarg is used anywhere in this module, CUDA execution could be turned on via `with torch.device('cuda'):`. In the PyTorch+NumPy example, if the original tensors are on GPU, the whole execution should be performed on the GPU. + +**TODO(Lezcano)**. We should probably test CUDA on the tests. + +**Gradients**. We have not tested gradient tracking either as we are still to find some good examples on which to test it, but it should be a simple corollary of all this effort. In the PyTorch+NumPy scenario, if the original tensors fed into the function do have `requires_grad=True`, the tensors will track the gradients of the internal implementation and then the user could differentiate through the NumPy code. We do not have a way to turn the `requires_grad` flag in the all-NumPy case. Note that this is expected as this would require exposing all the autograd machinery from PyTorch into the API. If a user wants to compute gradients in their program, we expect them to wrap it in a function and apply the PyTorch-NumPy approach. + +**TODO(Lezcano)**. Picking up simple NumPy programs from the internet would be good for these autograd tests. + + +# Bindings to TorchDyamo +**TODO(Lezcano)**: The PR is not there yet cf. [#95849](https://github.com/pytorch/pytorch/pull/95849). From 9803d45b5a202e9c7cc7c97697f033deadef728c Mon Sep 17 00:00:00 2001 From: lezcano Date: Tue, 11 Apr 2023 13:31:38 +0000 Subject: [PATCH 02/12] Break lines --- RFC.md | 248 +++++++++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 196 insertions(+), 52 deletions(-) diff --git a/RFC.md b/RFC.md index 8d96f31a..cd578aab 100644 --- a/RFC.md +++ b/RFC.md @@ -1,20 +1,30 @@ # Summary -This RFC describes a proposal for a translation layer from NumPy into PyTorch. In simple terms, this accounts for implementing most of NumPy's API (`ndarray`, the `np`, `np.linalg`, `np.fft` modules, etc) using `torch.Tensor` and PyTorch ops as backend. + +This RFC describes a proposal for a translation layer from NumPy into PyTorch. +In simple terms, this accounts for implementing most of NumPy's API (`ndarray`, +the `np`, `np.linalg`, `np.fft` modules, etc) using `torch.Tensor` and PyTorch +ops as backend. The this project has two main goals: -1. Have a `torch.numpy` submodule, similar to `jax.numpy` that serves as a drop-in replacement for NumPy when imported as `import torch.numpy as np`. -2. Have TorchDynamo understand and use this layer to be able to trace through NumPy programs as if they were written in PyTorch +1. Have a `torch.numpy` submodule, similar to `jax.numpy` that serves as a + drop-in replacement for NumPy when imported as `import torch.numpy as np`. +2. Have TorchDynamo understand and use this layer to be able to trace through + NumPy programs as if they were written in PyTorch Two corollaries of this work should be: -1. Given NumPy code, one should be able to differentiate through it using PyTorch's autograd engine +1. Given NumPy code, one should be able to differentiate through it using + PyTorch's autograd engine 2. Given NumPy code, one should be able to execute it on CUDA The work is being done at [numpy_pytorch_interop](https://github.com/Quansight-Labs/numpy_pytorch_interop/). # The Translation Layer -In this section we discuss the ideas behind design and implementation of the translation layer from PyTorch to NumPy -## Two examples of expected usage +In this section we discuss the ideas behind design and implementation of the +translation layer from PyTorch to NumPy + +## The two expected uses + Let's start with some examples. Consider the following snippet: @@ -27,7 +37,8 @@ z = np.dot(x, y) w = z.sum() ``` -By changing the first line to `import torch.numpy as np`, the semantics of the program would stay the same, but the implementation would be equivalent to +By changing the first line to `import torch.numpy as np`, the semantics of the +program would stay the same, but the implementation would be equivalent to ```python import torch @@ -37,14 +48,18 @@ z = torch.matmul(x, y) w = z.sum() ``` -Here we already see a couple differences between NumPy and PyTorch. The most obvious one is that the default dtype in NumPy is `float64` rather than `float32`. The less obvious is very sneakily hiding in the last line. +Here we already see a couple differences between NumPy and PyTorch. The most +obvious one is that the default dtype in NumPy is `float64` rather than +`float32`. The less obvious is very sneakily hiding in the last line. ```python >>> type(w) ``` -Reductions and similar operations in NumPy return the infamous NumPy scalars. We'll discuss these and other NumPy quirks and how we dealt with them in the sequel. +Reductions and similar operations in NumPy return the infamous NumPy scalars. +We'll discuss these and other NumPy quirks and how we dealt with them in the +sequel. As expected, this layer also allows for combining NumPy code and PyTorch code. @@ -62,7 +77,10 @@ t_results = torch.empty(5, dtype=torch.float64) t_results[0] = result # store the result in a torch.Tensor ``` -This code mixing NumPy and PyTorch already works, as `torch.Tensor` implements the `__array__` method. For it to work with the compatibility layer, we would need to wrap and unwrap the inputs / outputs. This could be done modifying `fn` as +This code mixing NumPy and PyTorch already works, as `torch.Tensor` implements +the `__array__` method. For it to work with the compatibility layer, we would +need to wrap and unwrap the inputs / outputs. This could be done modifying `fn` +as ```python def fn(x, y): @@ -73,21 +91,41 @@ def fn(x, y): ``` Note that this wrapping / unwrapping process can be easily automated via a decorator. -Even more, if a user wants to use PyTorch as a backend in a code that mixes PyTorch and NumPy, it will mostly be the case that it is because they want to trace through that code. In that setting, TorchDynamo will be able to automatically do the wrapping/unwrapping. +Even more, if a user wants to use PyTorch as a backend in a code that mixes +PyTorch and NumPy, it will mostly be the case that it is because they want to +trace through that code. In that setting, TorchDynamo will be able to +automatically do the wrapping/unwrapping. ## The `torch.numpy` module -The bulk of the work went into implementing a system that allows us to implement NumPy operations in terms of those of PyTorch. The main design goals were + +The bulk of the work went into implementing a system that allows us to +implement NumPy operations in terms of those of PyTorch. The main design goals +were 1. Implement *most* of NumPy's API 2. Preserve NumPy semantics as much as possible -We say *most* of NumPy's API, because NumPy's API is not only massive, but also there are parts of it which cannot be implemented in PyTorch. For example, NumPy has support for arrays of strings, dates, and other `dtype`s that PyTorch does not consider. Negative strides are other example. We put together a list of things that are out of the scope of this project in the [following issue](https://github.com/Quansight-Labs/numpy_pytorch_interop/issues/73). - -For the bulk of the functions, we started by prioritizing most common operations. Then, when bringing tests from the NumPy test suit and running them, we would triage and prioritize how important was to fix each failure we found. Iterating this process, we ended up with a small list of differences between the NumPy and the PyTorch API which we sorted out by hand and finished implementing. That list and the prioritization discussion can be found in [the first few posts of this issue](https://github.com/Quansight-Labs/numpy_pytorch_interop/issues/87). - -The second point of preserving NumPy semantics as much as possible will be used in the sequel to discuss some points like the default dtypes that are used throughout the implementation. +We say *most* of NumPy's API, because NumPy's API is not only massive, but also +there are parts of it which cannot be implemented in PyTorch. For example, +NumPy has support for arrays of strings, dates, and other `dtype`s that PyTorch +does not consider. Negative strides are other example. We put together a list +of things that are out of the scope of this project in the +[following issue](https://github.com/Quansight-Labs/numpy_pytorch_interop/issues/73). + +For the bulk of the functions, we started by prioritizing most common +operations. Then, when bringing tests from the NumPy test suit and running +them, we would triage and prioritize how important was to fix each failure we +found. Iterating this process, we ended up with a small list of differences +between the NumPy and the PyTorch API which we sorted out by hand and finished +implementing. That list and the prioritization discussion can be found in +[the first few posts of this issue](https://github.com/Quansight-Labs/numpy_pytorch_interop/issues/87). + +The second point of preserving NumPy semantics as much as possible will be used +in the sequel to discuss some points like the default dtypes that are used +throughout the implementation. ### Annotation-based preprocessing + NumPy accepts virtually anything that smells like an array as input to its operators ```python >>> np.add(1, 3) @@ -98,85 +136,191 @@ array([6., 7., 8.]) array([1, 2, 3, 4, 5, 6]) ``` -To implement NumPy in terms of PyTorch, for any operation we would need to put the inputs into tensors, perform the operations, and then wrap the tensor into a `torch.numpy.ndarray` (more on this class later). +To implement NumPy in terms of PyTorch, for any operation we would need to put +the inputs into tensors, perform the operations, and then wrap the tensor into +a `torch.numpy.ndarray` (more on this class later). To avoid all this code repetition, we implement the functions in two steps. -First, we implement functions with the NumPy signature, but assuming that in place of NumPy-land elements (`np.array`, array-like functions, `np.dtype`s, etc) they simply accept `torch.Tensor` and PyTorch-land objects and return `torch.Tensor`s. For example, we would implement `np.diag` as +First, we implement functions with the NumPy signature, but assuming that in +place of NumPy-land elements (`np.array`, array-like functions, `np.dtype`s, etc) +they simply accept `torch.Tensor` and PyTorch-land objects and return +`torch.Tensor`s. For example, we would implement `np.diag` as ```python def diag(v, k=0): return torch.diag(v, k) ``` -In this layer, if a NumPy function is composite (calls other NumPy functions internally), we can simply vendor its implementation, and have it call our PyTorch-land implementations of these functions. In other words, at this level, functions are composable, as any set of functions implemented purely in PyTorch. All these implementations are internal, and are not meant to be seen or used by the final user. - -The second step is then done via type annotations and a decorator. Each type annotation has then a map NumPy-land -> PyTorch-land associated, that maps the set of inputs accepted by NumPy for that argument into a PyTorch-land object (think a `torch.Tensor` or a PyTorch dtype). For example, for `np.diag` we would write +In this layer, if a NumPy function is composite (calls other NumPy functions +internally), we can simply vendor its implementation, and have it call our +PyTorch-land implementations of these functions. In other words, at this level, +functions are composable, as any set of functions implemented purely in +PyTorch. All these implementations are internal, and are not meant to be seen +or used by the final user. + +The second step is then done via type annotations and a decorator. Each type +annotation has then a map NumPy-land -> PyTorch-land associated, that maps the +set of inputs accepted by NumPy for that argument into a PyTorch-land object +(think a `torch.Tensor` or a PyTorch dtype). For example, for `np.diag` we +would write ```python def diag(v: ArrayLike, k=0): return torch.diag(v, k) ``` -Then, we would wrap these Python-land functions in a `normalizer` decorator and expose them in the public `torch.np` module. This decorator is in charge of gathering all the inputs at runtime and normalizing them according to their annotations. +Then, we would wrap these Python-land functions in a `normalizer` decorator and +expose them in the public `torch.np` module. This decorator is in charge of +gathering all the inputs at runtime and normalizing them according to their +annotations. We currently have four annotations (and small variations of them): -- `ArrayLike`: The input can be a `torch.numpy.array`, a list of lists, a scalar, or anything that NumPy would accept. It returns a `torch.Tensor`. +- `ArrayLike`: The input can be a `torch.numpy.array`, a list of lists, a + scalar, or anything that NumPy would accept. It returns a `torch.Tensor`. - `DTypeLike`: Takes a `torch.numpy` dtype and returns a PyTorch dtype. -- `AxisLike`: Takes anything that can be accepted as an axis (e.g. a tuple or an `ndarray`) and returns a tuple. -- `OutArray`: Asserts that the input is a `torch.numpy.ndarray`. This is used to implement the `out` arg. +- `AxisLike`: Takes anything that can be accepted as an axis (e.g. a tuple or + an `ndarray`) and returns a tuple. +- `OutArray`: Asserts that the input is a `torch.numpy.ndarray`. This is used + to implement the `out` arg. -Note that none of the code here makes use of NumPy. We are writing `torch.numpy.ndarray` above to make more explicit our intents, but there shouldn't be any ambiguity here. +Note that none of the code here makes use of NumPy. We are writing +`torch.numpy.ndarray` above to make more explicit our intents, but there +shouldn't be any ambiguity here. **OBS(Lezcano)**: `DTypeLike` should be `Optional[DTypeLike]` -**OBS(Lezcano)**: Should we have a `NotImplementedType` to mark the args that are not being implemented? We could then assert that either that parameter has not been provided, and if it has, it has the same value as the default. The goal here would be to either use all the args of a function in its implementation, or mark explicitly those that we don't use. - -**Implmenting out**: In PyTorch, the `out` kwarg is, as the name says, a keyword-only argument. It is for this reason that, in PrimTorch, we were able to implement it as [a decorator](https://github.com/pytorch/pytorch/blob/ce4df4cc596aa10534ac6d54912f960238264dfd/torch/_prims_common/wrappers.py#L187-L282). This is not the case in NumPy. In NumPy `out` is a positional arg that is often interleaved with other parameters. This is the reason why we use the `OutArray` label to mark these. We then implement the `out` semantics in the `@normalizer` wrapper in a generic way. - -**Ufuncs and reductions**: Ufuncs (unary and binary) and reductions are two sets of functions that are particularly regular. For these functions, we implement (some of) their args in a generic way. We then simply forward the computations to PyTorch, perhaps working around some PyTorch limitations. +**OBS(Lezcano)**: Should we have a `NotImplementedType` to mark the args that +are not being implemented? We could then assert that either that parameter has +not been provided, and if it has, it has the same value as the default. The +goal here would be to either use all the args of a function in its +implementation, or mark explicitly those that we don't use. + +**Implmenting out**: In PyTorch, the `out` kwarg is, as the name says, a +keyword-only argument. It is for this reason that, in PrimTorch, we were able +to implement it as +[a decorator](https://github.com/pytorch/pytorch/blob/ce4df4cc596aa10534ac6d54912f960238264dfd/torch/_prims_common/wrappers.py#L187-L282). +This is not the case in NumPy. In NumPy `out` is a positional arg that is often +interleaved with other parameters. This is the reason why we use the `OutArray` +label to mark these. We then implement the `out` semantics in the `@normalizer` +wrapper in a generic way. + +**Ufuncs and reductions**: Ufuncs (unary and binary) and reductions are two +sets of functions that are particularly regular. For these functions, we +implement (some of) their args in a generic way. We then simply forward the +computations to PyTorch, perhaps working around some PyTorch limitations. ### The `ndarray` class -Once we have all the free functions implemented, implementing an `ndarray` class is rather simple. We simply register all the free functions as methods or dunder methods apropriately. We also forward the properties to the properties within the PyTorch tensor and we are done. - -### DTypes -**Default dtypes**. One of the issues that most often user when moving their codebase from NumPy to JAX was the default dtype changing from `float64` to `float32`. So much so, that this is one noted as one of [JAX's shap edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision). Following the spirit of making everything match NumPy by default, we choose the NumPy defaults whenever the `dtype` was not chosen in a factory function. -**TODO(Lezcano)**: I just realised that we do not have a clean way to change the default dtype of `torch.numpy` to those from PyTorch. We should implement that utility flag, similar to [`torch.set_default_dtype`](https://pytorch.org/docs/stable/generated/torch.set_default_dtype.html). Perhaps call it `torch.numpy.use_torch_defaults()` and then add a way for users to be able to set their own int/float/complex defaults. -**TODO(Lezcano)**: Do we just use them just in factory functions, or do we also use them anywhere else -> Check +Once we have all the free functions implemented, implementing an `ndarray` +class is rather simple. We simply register all the free functions as methods or +dunder methods appropriately. We also forward the properties to the properties +within the PyTorch tensor and we are done. -**NumPy scalars**. NumPy's type system is tricky. At first sight, it looks quite a bit like PyTorch's, but having a few more dtypes like `np.uint16` or `np.longdouble`. Upon closer inspection, one finds that it also has [NumPy scalar](https://numpy.org/doc/stable/reference/arrays.scalars.html) objects. NumPy scalars are similar to Python scalars but with a set width. NumPy scalars are NumPy's preferred return class for reductions and other operations that return just one element. NumPy scalars do not play particularly well with computations on devices like GPUs, as they live on CPU. Implementing NumPy scalars would mean that we need to synchronize after every `sum()` call, which is less-than-good. Instead, whenever a NumPy scalar would be returned, we will return a 0-D tensor, as PyTorch does. +### DTypes -**Type promotion**. Another not-so-well-known fact of NumPy's cast system is that it is data-dependent. Python scalars can be used in pretty much any NumPy operation, being able to call any operation that accepts a 0-D array with a Python scalar. If you provide an operation with a Python scalar, these will be casted to the smallest dtype that can represent them, and then they will participate in type promotion, allowing for some rather intersting behaviour -``` +**Default dtypes**. One of the issues that most often user when moving their +codebase from NumPy to JAX was the default dtype changing from `float64` to +`float32`. So much so, that this is one noted as one of +[JAX's shap edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision). +Following the spirit of making everything match NumPy by default, we choose the +NumPy defaults whenever the `dtype` was not chosen in a factory function. + +**TODO(Lezcano)**: I just realised that we do not have a clean way to change +the default dtype of `torch.numpy` to those from PyTorch. We should implement +that utility flag, similar to +[`torch.set_default_dtype`](https://pytorch.org/docs/stable/generated/torch.set_default_dtype.html). +Perhaps call it `torch.numpy.use_torch_defaults()` and then add a way for users +to be able to set their own int/float/complex defaults. +**TODO(Lezcano)**: Do we just use them just in factory functions, or do we also +use them anywhere else -> Check + +**NumPy scalars**. NumPy's type system is tricky. At first sight, it looks +quite a bit like PyTorch's, but having a few more dtypes like `np.uint16` or +`np.longdouble`. Upon closer inspection, one finds that it also has +[NumPy scalar](https://numpy.org/doc/stable/reference/arrays.scalars.html) objects. +NumPy scalars are similar to Python scalars but with a set width. NumPy scalars +are NumPy's preferred return class for reductions and other operations that +return just one element. NumPy scalars do not play particularly well with +computations on devices like GPUs, as they live on CPU. Implementing NumPy +scalars would mean that we need to synchronize after every `sum()` call, which +is less-than-good. Instead, whenever a NumPy scalar would be returned, we will +return a 0-D tensor, as PyTorch does. + +**Type promotion**. Another not-so-well-known fact of NumPy's cast system is +that it is data-dependent. Python scalars can be used in pretty much any NumPy +operation, being able to call any operation that accepts a 0-D array with a +Python scalar. If you provide an operation with a Python scalar, these will be +casted to the smallest dtype that can represent them, and then they will +participate in type promotion, allowing for some rather interesting behaviour +```python >>> np.asarray([1], dtype=np.int8) + 127 array([128], dtype=int8) >>> np.asarray([1], dtype=np.int8) + 128 array([129], dtype=int16) ``` -This dependent type promotion will be deprecated NumPy 2.0, and will be replaced with [NEP 50](https://numpy.org/neps/nep-0050-scalar-promotion.html). As such, to be forward-looking and for simplicity, we chose to implement the type promotion behaviour proposed in NEP 50, which is much closer to that of Pytorch. - -Note that the decision of going with NEP 50 complements the previous one of returning 0-D arrays in place of NumPy scalars as, currently, 0-D arrays do not participate in type promotion in NumPy (but will do in NumPy 2.0): -``` +This dependent type promotion will be deprecated NumPy 2.0, and will be +replaced with [NEP 50](https://numpy.org/neps/nep-0050-scalar-promotion.html). +As such, to be forward-looking and for simplicity, we chose to implement the +type promotion behaviour proposed in NEP 50, which is much closer to that of +Pytorch. + +Note that the decision of going with NEP 50 complements the previous one of +returning 0-D arrays in place of NumPy scalars as, currently, 0-D arrays do not +participate in type promotion in NumPy (but will do in NumPy 2.0): +```python int64_0d_array = np.array(1, dtype=np.int64) np.result_type(np.int8, int64_0d_array) == np.int8 ``` ## Testing -The testing of the framework was done via ~~copying~~ vendoring tests from the NumPy test suit. Then, we would replace the NumPy imports for imports with `torch.numpy`. The failures on these tests were then triaged and discussed the priority of fixing each of them. -In the (near) future, we plan to get some real world examples and run them through the library, to test its coverage and correctness. +The testing of the framework was done via ~~copying~~ vendoring tests from the +NumPy test suit. Then, we would replace the NumPy imports for imports with +`torch.numpy`. The failures on these tests were then triaged and discussed the +priority of fixing each of them. + +In the (near) future, we plan to get some real world examples and run them +through the library, to test its coverage and correctness. ## Limitations -One of the known limitations of this approach is the efficiency in eager. Similar to PrimTorch, sometimes we needed to work around some limitations of PyTorch (e.g. support for some operations for `float16`) or some ways PyTorch deviates from NumPy by implementing things manually calling several `torch` operations. This, when executed in eager mode and, in particular, on CUDA devices, will result on a perf-hit. To alleviate this, we tried to dispatch NumPy functions to PyTorch functions with as few indirections as possible, to alleviate the number of kernels called when executed on eager mode. -There are some known limitations. Some of them are tracked in the second part of the [OP of this issue](https://github.com/Quansight-Labs/numpy_pytorch_interop/issues/73). There are some more in [this issue](https://github.com/Quansight-Labs/numpy_pytorch_interop/issues/86). When landing all this, we will create a comprehensive document with the differences between NumPy and `torch.numpy`. +One of the known limitations of this approach is the efficiency in eager. +Similar to PrimTorch, sometimes we needed to work around some limitations of +PyTorch (e.g. support for some operations for `float16`) or some ways PyTorch +deviates from NumPy by implementing things manually calling several `torch` +operations. This, when executed in eager mode and, in particular, on CUDA +devices, will result on a perf-hit. To alleviate this, we tried to dispatch +NumPy functions to PyTorch functions with as few indirections as possible, to +alleviate the number of kernels called when executed on eager mode. + +There are some known limitations. Some of them are tracked in the second part +of the [OP of this issue](https://github.com/Quansight-Labs/numpy_pytorch_interop/issues/73). +There are some more in [this issue](https://github.com/Quansight-Labs/numpy_pytorch_interop/issues/86). +When landing all this, we will create a comprehensive document with the differences +between NumPy and `torch.numpy`. ## Beyond NumPy -**CUDA**. The current implementation has just been implemented and tested on CPU. We expect CUDA coverage to be as good as the coverage we have with CPU matching CUDA. In the NumPy-only example in the introduction, given that no explicit `device` kwarg is used anywhere in this module, CUDA execution could be turned on via `with torch.device('cuda'):`. In the PyTorch+NumPy example, if the original tensors are on GPU, the whole execution should be performed on the GPU. + +**CUDA**. The current implementation has just been implemented and tested on +CPU. We expect CUDA coverage to be as good as the coverage we have with CPU +matching CUDA. In the NumPy-only example in the introduction, given that no +explicit `device` kwarg is used anywhere in this module, CUDA execution could +be turned on via `with torch.device('cuda'):`. In the PyTorch+NumPy example, if +the original tensors are on GPU, the whole execution should be performed on the +GPU. **TODO(Lezcano)**. We should probably test CUDA on the tests. -**Gradients**. We have not tested gradient tracking either as we are still to find some good examples on which to test it, but it should be a simple corollary of all this effort. In the PyTorch+NumPy scenario, if the original tensors fed into the function do have `requires_grad=True`, the tensors will track the gradients of the internal implementation and then the user could differentiate through the NumPy code. We do not have a way to turn the `requires_grad` flag in the all-NumPy case. Note that this is expected as this would require exposing all the autograd machinery from PyTorch into the API. If a user wants to compute gradients in their program, we expect them to wrap it in a function and apply the PyTorch-NumPy approach. +**Gradients**. We have not tested gradient tracking either as we are still to +find some good examples on which to test it, but it should be a simple +corollary of all this effort. In the PyTorch+NumPy scenario, if the original +tensors fed into the function do have `requires_grad=True`, the tensors will +track the gradients of the internal implementation and then the user could +differentiate through the NumPy code. We do not have a way to turn the +`requires_grad` flag in the all-NumPy case. Note that this is expected as this +would require exposing all the autograd machinery from PyTorch into the API. If +a user wants to compute gradients in their program, we expect them to wrap it +in a function and apply the PyTorch-NumPy approach. **TODO(Lezcano)**. Picking up simple NumPy programs from the internet would be good for these autograd tests. - # Bindings to TorchDyamo + **TODO(Lezcano)**: The PR is not there yet cf. [#95849](https://github.com/pytorch/pytorch/pull/95849). From 7a5b98c0793bd4170f5c7f617c7d16e0ce851357 Mon Sep 17 00:00:00 2001 From: lezcano Date: Wed, 12 Apr 2023 11:30:26 +0000 Subject: [PATCH 03/12] Fix some review comments --- RFC.md | 271 +++++++++++++++++++++++++++++++-------------------------- 1 file changed, 150 insertions(+), 121 deletions(-) diff --git a/RFC.md b/RFC.md index cd578aab..fcdd2ef7 100644 --- a/RFC.md +++ b/RFC.md @@ -1,29 +1,27 @@ -# Summary +# A PyTorch - NumPy compatibility layer +**Authors:** +* @ev-br +* @lezcano +* @rgommers + +## Summary This RFC describes a proposal for a translation layer from NumPy into PyTorch. In simple terms, this accounts for implementing most of NumPy's API (`ndarray`, -the `np`, `np.linalg`, `np.fft` modules, etc) using `torch.Tensor` and PyTorch -ops as backend. +the `numpy`, `numpy.linalg`, `numpy.fft` modules, etc) using `torch.Tensor` +and PyTorch ops as backend. -The this project has two main goals: -1. Have a `torch.numpy` submodule, similar to `jax.numpy` that serves as a - drop-in replacement for NumPy when imported as `import torch.numpy as np`. -2. Have TorchDynamo understand and use this layer to be able to trace through - NumPy programs as if they were written in PyTorch -Two corollaries of this work should be: -1. Given NumPy code, one should be able to differentiate through it using - PyTorch's autograd engine -2. Given NumPy code, one should be able to execute it on CUDA +The this project has a main goal as per the +[initial design document](https://docs.google.com/document/d/1gdUDgZNbumFORRcUaZUVw790CtNYweAM20C1fbWMNd8): +1. Make TorchDynamo understand NumPy calls The work is being done at [numpy_pytorch_interop](https://github.com/Quansight-Labs/numpy_pytorch_interop/). -# The Translation Layer -In this section we discuss the ideas behind design and implementation of the -translation layer from PyTorch to NumPy +## Motivation -## The two expected uses +### An introductory example Let's start with some examples. @@ -37,7 +35,7 @@ z = np.dot(x, y) w = z.sum() ``` -By changing the first line to `import torch.numpy as np`, the semantics of the +When we trace this program with the compat layer, the semantics of the program would stay the same, but the implementation would be equivalent to ```python @@ -78,7 +76,7 @@ t_results[0] = result # store the result in a torch.Tensor ``` This code mixing NumPy and PyTorch already works, as `torch.Tensor` implements -the `__array__` method. For it to work with the compatibility layer, we would +the `__array__` method. For it to work manually with the compatibility layer, we would need to wrap and unwrap the inputs / outputs. This could be done modifying `fn` as @@ -90,13 +88,87 @@ def fn(x, y): return ret.tensor.numpy() ``` -Note that this wrapping / unwrapping process can be easily automated via a decorator. -Even more, if a user wants to use PyTorch as a backend in a code that mixes -PyTorch and NumPy, it will mostly be the case that it is because they want to -trace through that code. In that setting, TorchDynamo will be able to -automatically do the wrapping/unwrapping. +This process would be done automatically by TorchDynamo, so we would simply need to write +```python +@compile +def fn(x, y): + return np.multiply(x, y).sum() +``` + +### The observable behavior + +The two main idea driving the design of this compatibility layer were the following: + +1. The behavior of the layer should be as close to that of NumPy as possible +2. The layer follows NumPy master + +The following design decisions follow from these: + +**Default dtypes**. One of the issues that most often user when moving their +codebase from NumPy to JAX was the default dtype changing from `float64` to +`float32`. So much so, that this is one noted as one of +[JAX's shap edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision). +Following the spirit of making everything match NumPy by default, we choose the +NumPy defaults whenever the `dtype` was not chosen in a factory function. + +**TODO(Lezcano)**: I just realized that we do not have a clean way to change +the default dtype of `torch_np` to those from PyTorch. We should implement +that utility flag, similar to +[`torch.set_default_dtype`](https://pytorch.org/docs/stable/generated/torch.set_default_dtype.html). +Perhaps call it `torch_np.use_torch_defaults()` and then add a way for users +to be able to set their own int/float/complex defaults. +**TODO(Lezcano)**: Do we just use them just in factory functions, or do we also +use them anywhere else -> Check + +**NumPy scalars**. NumPy's type system is tricky. At first sight, it looks +quite a bit like PyTorch's, but having a few more dtypes like `np.uint16` or +`np.longdouble`. Upon closer inspection, one finds that it also has +[NumPy scalar](https://numpy.org/doc/stable/reference/arrays.scalars.html) objects. +NumPy scalars are similar to Python scalars but with a set width. NumPy scalars +are NumPy's preferred return class for reductions and other operations that +return just one element. NumPy scalars do not play particularly well with +computations on devices like GPUs, as they live on CPU. Implementing NumPy +scalars would mean that we need to synchronize after every `sum()` call, which +is less-than-good. Instead, whenever a NumPy scalar would be returned, we will +return a 0-D tensor, as PyTorch does. + +**Type promotion**. Another not-so-well-known fact of NumPy's cast system is +that it is data-dependent. Python scalars can be used in pretty much any NumPy +operation, being able to call any operation that accepts a 0-D array with a +Python scalar. If you provide an operation with a Python scalar, these will be +casted to the smallest dtype that can represent them, and then they will +participate in type promotion, allowing for some rather interesting behaviour +```python +>>> np.asarray([1], dtype=np.int8) + 127 +array([128], dtype=int8) +>>> np.asarray([1], dtype=np.int8) + 128 +array([129], dtype=int16) +``` +This dependent type promotion will be deprecated NumPy 2.0, and will be +replaced with [NEP 50](https://numpy.org/neps/nep-0050-scalar-promotion.html). +As such, to be forward-looking and for simplicity, we chose to implement the +type promotion behaviour proposed in NEP 50, which is much closer to that of +Pytorch. + +Note that the decision of going with NEP 50 complements the previous one of +returning 0-D arrays in place of NumPy scalars as, currently, 0-D arrays do not +participate in type promotion in NumPy (but will do in NumPy 2.0): +```python +int64_0d_array = np.array(1, dtype=np.int64) +np.result_type(np.int8, int64_0d_array) == np.int8 +``` + +**Versioning**. It should be clear from the previous points that NumPy has a +fair amount of questionable and legacy pain points. As such, we decided that +rather than trying to fight these, we would declare that the compat layer +follows the behavior of Numpy's master. Given the stability of NumPy's API and +how battle-tested its main functions are, we do not expect this to become a big +maintenance burden. If anything, it should make our lives easier, as some parts +of NumPy will soon be simplified and we will not need to implement them, as +described above. -## The `torch.numpy` module + +## The `torch_np` module The bulk of the work went into implementing a system that allows us to implement NumPy operations in terms of those of PyTorch. The main design goals @@ -107,9 +179,9 @@ were We say *most* of NumPy's API, because NumPy's API is not only massive, but also there are parts of it which cannot be implemented in PyTorch. For example, -NumPy has support for arrays of strings, dates, and other `dtype`s that PyTorch -does not consider. Negative strides are other example. We put together a list -of things that are out of the scope of this project in the +NumPy has support for arrays of string, datetime, structured and other dtypes. +Negative strides are other example of a feature that is just out of the scope. +We put together a list of things that are out of the scope of this project in the [following issue](https://github.com/Quansight-Labs/numpy_pytorch_interop/issues/73). For the bulk of the functions, we started by prioritizing most common @@ -124,6 +196,11 @@ The second point of preserving NumPy semantics as much as possible will be used in the sequel to discuss some points like the default dtypes that are used throughout the implementation. +**Visibility of the module** For simplicity, this RFC assumes that the +`torch_np` module will not be public, as the decision for it to be made public +was met with different opinions. We discuss these in the "Unresolved Questions" +section. + ### Annotation-based preprocessing NumPy accepts virtually anything that smells like an array as input to its operators @@ -138,7 +215,7 @@ array([1, 2, 3, 4, 5, 6]) To implement NumPy in terms of PyTorch, for any operation we would need to put the inputs into tensors, perform the operations, and then wrap the tensor into -a `torch.numpy.ndarray` (more on this class later). +a `torch_np.ndarray` (more on this class later). To avoid all this code repetition, we implement the functions in two steps. @@ -173,16 +250,16 @@ gathering all the inputs at runtime and normalizing them according to their annotations. We currently have four annotations (and small variations of them): -- `ArrayLike`: The input can be a `torch.numpy.array`, a list of lists, a +- `ArrayLike`: The input can be a `torch_np.array`, a list of lists, a scalar, or anything that NumPy would accept. It returns a `torch.Tensor`. -- `DTypeLike`: Takes a `torch.numpy` dtype and returns a PyTorch dtype. +- `DTypeLike`: Takes a `torch_np` dtype and returns a PyTorch dtype. - `AxisLike`: Takes anything that can be accepted as an axis (e.g. a tuple or an `ndarray`) and returns a tuple. -- `OutArray`: Asserts that the input is a `torch.numpy.ndarray`. This is used +- `OutArray`: Asserts that the input is a `torch_np.ndarray`. This is used to implement the `out` arg. Note that none of the code here makes use of NumPy. We are writing -`torch.numpy.ndarray` above to make more explicit our intents, but there +`torch_np.ndarray` above to make more explicit our intents, but there shouldn't be any ambiguity here. **OBS(Lezcano)**: `DTypeLike` should be `Optional[DTypeLike]` @@ -213,114 +290,66 @@ class is rather simple. We simply register all the free functions as methods or dunder methods appropriately. We also forward the properties to the properties within the PyTorch tensor and we are done. -### DTypes - -**Default dtypes**. One of the issues that most often user when moving their -codebase from NumPy to JAX was the default dtype changing from `float64` to -`float32`. So much so, that this is one noted as one of -[JAX's shap edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision). -Following the spirit of making everything match NumPy by default, we choose the -NumPy defaults whenever the `dtype` was not chosen in a factory function. - -**TODO(Lezcano)**: I just realised that we do not have a clean way to change -the default dtype of `torch.numpy` to those from PyTorch. We should implement -that utility flag, similar to -[`torch.set_default_dtype`](https://pytorch.org/docs/stable/generated/torch.set_default_dtype.html). -Perhaps call it `torch.numpy.use_torch_defaults()` and then add a way for users -to be able to set their own int/float/complex defaults. -**TODO(Lezcano)**: Do we just use them just in factory functions, or do we also -use them anywhere else -> Check - -**NumPy scalars**. NumPy's type system is tricky. At first sight, it looks -quite a bit like PyTorch's, but having a few more dtypes like `np.uint16` or -`np.longdouble`. Upon closer inspection, one finds that it also has -[NumPy scalar](https://numpy.org/doc/stable/reference/arrays.scalars.html) objects. -NumPy scalars are similar to Python scalars but with a set width. NumPy scalars -are NumPy's preferred return class for reductions and other operations that -return just one element. NumPy scalars do not play particularly well with -computations on devices like GPUs, as they live on CPU. Implementing NumPy -scalars would mean that we need to synchronize after every `sum()` call, which -is less-than-good. Instead, whenever a NumPy scalar would be returned, we will -return a 0-D tensor, as PyTorch does. - -**Type promotion**. Another not-so-well-known fact of NumPy's cast system is -that it is data-dependent. Python scalars can be used in pretty much any NumPy -operation, being able to call any operation that accepts a 0-D array with a -Python scalar. If you provide an operation with a Python scalar, these will be -casted to the smallest dtype that can represent them, and then they will -participate in type promotion, allowing for some rather interesting behaviour -```python ->>> np.asarray([1], dtype=np.int8) + 127 -array([128], dtype=int8) ->>> np.asarray([1], dtype=np.int8) + 128 -array([129], dtype=int16) -``` -This dependent type promotion will be deprecated NumPy 2.0, and will be -replaced with [NEP 50](https://numpy.org/neps/nep-0050-scalar-promotion.html). -As such, to be forward-looking and for simplicity, we chose to implement the -type promotion behaviour proposed in NEP 50, which is much closer to that of -Pytorch. - -Note that the decision of going with NEP 50 complements the previous one of -returning 0-D arrays in place of NumPy scalars as, currently, 0-D arrays do not -participate in type promotion in NumPy (but will do in NumPy 2.0): -```python -int64_0d_array = np.array(1, dtype=np.int64) -np.result_type(np.int8, int64_0d_array) == np.int8 -``` - -## Testing +### Testing The testing of the framework was done via ~~copying~~ vendoring tests from the NumPy test suit. Then, we would replace the NumPy imports for imports with -`torch.numpy`. The failures on these tests were then triaged and discussed the +`torch_np`. The failures on these tests were then triaged and discussed the priority of fixing each of them. In the (near) future, we plan to get some real world examples and run them through the library, to test its coverage and correctness. -## Limitations - -One of the known limitations of this approach is the efficiency in eager. -Similar to PrimTorch, sometimes we needed to work around some limitations of -PyTorch (e.g. support for some operations for `float16`) or some ways PyTorch -deviates from NumPy by implementing things manually calling several `torch` -operations. This, when executed in eager mode and, in particular, on CUDA -devices, will result on a perf-hit. To alleviate this, we tried to dispatch -NumPy functions to PyTorch functions with as few indirections as possible, to -alleviate the number of kernels called when executed on eager mode. +### Limitations -There are some known limitations. Some of them are tracked in the second part -of the [OP of this issue](https://github.com/Quansight-Labs/numpy_pytorch_interop/issues/73). +A number of known limitations are tracked in the second part of the +[OP of this issue](https://github.com/Quansight-Labs/numpy_pytorch_interop/issues/73). There are some more in [this issue](https://github.com/Quansight-Labs/numpy_pytorch_interop/issues/86). When landing all this, we will create a comprehensive document with the differences -between NumPy and `torch.numpy`. +between NumPy and `torch_np`. -## Beyond NumPy +### Beyond Plain NumPy -**CUDA**. The current implementation has just been implemented and tested on -CPU. We expect CUDA coverage to be as good as the coverage we have with CPU -matching CUDA. In the NumPy-only example in the introduction, given that no -explicit `device` kwarg is used anywhere in this module, CUDA execution could -be turned on via `with torch.device('cuda'):`. In the PyTorch+NumPy example, if -the original tensors are on GPU, the whole execution should be performed on the -GPU. +**GPU**. The current implementation has just been implemented and tested on +CPU. We expect GPU coverage to be as good as the coverage we have with CPU +matching GPU. If the original tensors are on GPU, the whole execution should +be performed on the GPU. **TODO(Lezcano)**. We should probably test CUDA on the tests. **Gradients**. We have not tested gradient tracking either as we are still to find some good examples on which to test it, but it should be a simple -corollary of all this effort. In the PyTorch+NumPy scenario, if the original -tensors fed into the function do have `requires_grad=True`, the tensors will -track the gradients of the internal implementation and then the user could -differentiate through the NumPy code. We do not have a way to turn the -`requires_grad` flag in the all-NumPy case. Note that this is expected as this -would require exposing all the autograd machinery from PyTorch into the API. If -a user wants to compute gradients in their program, we expect them to wrap it -in a function and apply the PyTorch-NumPy approach. +corollary of all this effort. If the original tensors fed into the function do +have `requires_grad=True`, the tensors will track the gradients of the internal +implementation and then the user could differentiate through the NumPy code. **TODO(Lezcano)**. Picking up simple NumPy programs from the internet would be good for these autograd tests. -# Bindings to TorchDyamo +### Bindings to TorchDyamo + +The bindings for NumPy at the TorchDynamo level are currently being developed at [#95849](https://github.com/pytorch/pytorch/pull/95849). + + +## Unresolved Questions + +A question was left open in the initial discussion. Should the module `torch_np` be publicly exposed as `torch.numpy` or not? + +A few arguments in favor of making it public: +* People could use it in their NumPy programs just by changing the import to + `import torch.numpy as np`. This could be a selling point similar to JAX's + `jax.numpy`, which could incentivize adoption. +* People would not need to use the whole PyTorch 2.0 stack to start using + PyTorch in their codebases + * See [this experiment in scikit-learn](https://github.com/scikit-learn/scikit-learn/pull/25956) + where they got a 7x speed-up on CPU on a layer just by using `torch.linalg`. +* Since the layer is rather thin and in pure Python, if there are bugs, + external contributors could easily help fixing them or extend the supported + functionality. -**TODO(Lezcano)**: The PR is not there yet cf. [#95849](https://github.com/pytorch/pytorch/pull/95849). +A few arguments against: +* The compat introduces a number of type conversions that may produce somewhat + slow code when used in eager mode. + * [Note] Keeping this in mind, we tried to use in the implementations as few + operators as possible, to make it reasonably fast in eager mode. +* Exposing `torch.numpy` would create a less performant secondary entry point + to many of the functions in PyTorch. This could be a trap for new users. From e84e63565a96be29de69fd4406280a2eaefc4952 Mon Sep 17 00:00:00 2001 From: lezcano Date: Wed, 12 Apr 2023 15:14:47 +0000 Subject: [PATCH 04/12] General improvements --- RFC.md | 178 ++++++++++++++++++++++++++++----------------------------- 1 file changed, 89 insertions(+), 89 deletions(-) diff --git a/RFC.md b/RFC.md index fcdd2ef7..dd761c66 100644 --- a/RFC.md +++ b/RFC.md @@ -16,14 +16,12 @@ The this project has a main goal as per the [initial design document](https://docs.google.com/document/d/1gdUDgZNbumFORRcUaZUVw790CtNYweAM20C1fbWMNd8): 1. Make TorchDynamo understand NumPy calls -The work is being done at [numpy_pytorch_interop](https://github.com/Quansight-Labs/numpy_pytorch_interop/). +The work is currently being done at [numpy_pytorch_interop](https://github.com/Quansight-Labs/numpy_pytorch_interop/). ## Motivation -### An introductory example - -Let's start with some examples. +### Introductory examples Consider the following snippet: ```python @@ -46,8 +44,8 @@ z = torch.matmul(x, y) w = z.sum() ``` -Here we already see a couple differences between NumPy and PyTorch. The most -obvious one is that the default dtype in NumPy is `float64` rather than +Here, we can already spot a couple differences between NumPy and PyTorch. +The most obvious one is that the default dtype in NumPy is `float64` rather than `float32`. The less obvious is very sneakily hiding in the last line. ```python @@ -57,10 +55,10 @@ obvious one is that the default dtype in NumPy is `float64` rather than Reductions and similar operations in NumPy return the infamous NumPy scalars. We'll discuss these and other NumPy quirks and how we dealt with them in the -sequel. +[design decision section](#design-decisions). -As expected, this layer also allows for combining NumPy code and PyTorch code. +Let's now have a look at a toy example of how this layer would be used. ```python import torch import numpy as np @@ -75,41 +73,32 @@ t_results = torch.empty(5, dtype=torch.float64) t_results[0] = result # store the result in a torch.Tensor ``` -This code mixing NumPy and PyTorch already works, as `torch.Tensor` implements -the `__array__` method. For it to work manually with the compatibility layer, we would -need to wrap and unwrap the inputs / outputs. This could be done modifying `fn` -as - -```python -def fn(x, y): - x = np.asarray(x) - y = np.asarray(y) - ret = np.multiply(x, y).sum() - return ret.tensor.numpy() -``` +Note that this code mixing NumPy and PyTorch already works, as `torch.Tensor` +implements the `__array__` method. Now, the compatibility layer allows us to +trace through it. In order to do that, there would be no necessary changes, +other than simply ask `torch.compile` to trace through it: -This process would be done automatically by TorchDynamo, so we would simply need to write ```python @compile def fn(x, y): return np.multiply(x, y).sum() ``` -### The observable behavior +### Design decisions -The two main idea driving the design of this compatibility layer were the following: +The two main ideas driving the design of this compatibility layer are the following: 1. The behavior of the layer should be as close to that of NumPy as possible 2. The layer follows NumPy master The following design decisions follow from these: -**Default dtypes**. One of the issues that most often user when moving their -codebase from NumPy to JAX was the default dtype changing from `float64` to -`float32`. So much so, that this is one noted as one of +**Default dtypes**. One of the most common issues that bites people when migrating their +codebases from NumPy to JAX is the default dtype changing from `float64` to +`float32`. So much so that this is noted as one of [JAX's shap edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision). Following the spirit of making everything match NumPy by default, we choose the -NumPy defaults whenever the `dtype` was not chosen in a factory function. +NumPy defaults whenever the `dtype` was not made explicit in a factory function. **TODO(Lezcano)**: I just realized that we do not have a clean way to change the default dtype of `torch_np` to those from PyTorch. We should implement @@ -121,23 +110,34 @@ to be able to set their own int/float/complex defaults. use them anywhere else -> Check **NumPy scalars**. NumPy's type system is tricky. At first sight, it looks -quite a bit like PyTorch's, but having a few more dtypes like `np.uint16` or -`np.longdouble`. Upon closer inspection, one finds that it also has +like PyTorch's, but with few more dtypes like `np.uint16` or `np.longdouble`. +Upon closer inspection, one finds that it also has [NumPy scalar](https://numpy.org/doc/stable/reference/arrays.scalars.html) objects. NumPy scalars are similar to Python scalars but with a set width. NumPy scalars are NumPy's preferred return class for reductions and other operations that return just one element. NumPy scalars do not play particularly well with computations on devices like GPUs, as they live on CPU. Implementing NumPy scalars would mean that we need to synchronize after every `sum()` call, which -is less-than-good. Instead, whenever a NumPy scalar would be returned, we will -return a 0-D tensor, as PyTorch does. +would be terrible performance-wise. In this implementation, we choose to represent +NumPy scalars as 0-D arrays. This may cause small divergences in some cases like + +```python +>>> np.int32(2) * [1, 2, 3] # scalar decays to a python int +[1, 2, 3, 1, 2, 3] + +>>> np.asarray(2) * [1, 2, 3] # zero-dim array is an array-like +array([2, 4, 6]) +``` + +but we don't expect these to pose a big issue in practice. Note that in this +implementation `torch_np.int32(2)` would return the same as `torch_np.asarray(2)`. **Type promotion**. Another not-so-well-known fact of NumPy's cast system is that it is data-dependent. Python scalars can be used in pretty much any NumPy operation, being able to call any operation that accepts a 0-D array with a Python scalar. If you provide an operation with a Python scalar, these will be -casted to the smallest dtype that can represent them, and then they will -participate in type promotion, allowing for some rather interesting behaviour +casted to the smallest dtype they can be represented in, and then, they will +participate in type promotion. This allows for for some rather interesting behaviour ```python >>> np.asarray([1], dtype=np.int8) + 127 array([128], dtype=int8) @@ -146,33 +146,36 @@ array([129], dtype=int16) ``` This dependent type promotion will be deprecated NumPy 2.0, and will be replaced with [NEP 50](https://numpy.org/neps/nep-0050-scalar-promotion.html). -As such, to be forward-looking and for simplicity, we chose to implement the +For simplicity and to be forward-looking, we chose to implement the type promotion behaviour proposed in NEP 50, which is much closer to that of Pytorch. Note that the decision of going with NEP 50 complements the previous one of returning 0-D arrays in place of NumPy scalars as, currently, 0-D arrays do not -participate in type promotion in NumPy (but will do in NumPy 2.0): +participate in type promotion in NumPy (but will do in NumPy 2.0 under NEP 50): ```python int64_0d_array = np.array(1, dtype=np.int64) np.result_type(np.int8, int64_0d_array) == np.int8 ``` **Versioning**. It should be clear from the previous points that NumPy has a -fair amount of questionable and legacy pain points. As such, we decided that -rather than trying to fight these, we would declare that the compat layer -follows the behavior of Numpy's master. Given the stability of NumPy's API and -how battle-tested its main functions are, we do not expect this to become a big -maintenance burden. If anything, it should make our lives easier, as some parts -of NumPy will soon be simplified and we will not need to implement them, as -described above. +fair amount of questionable and legacy pain points. It is for this reason that +we decided that rather than fighting these, we would declare that the compat +layer follows the behavior of Numpy's master (even, in some cases, of NumPy +2.0). Given the stability of NumPy's API and how battle-tested its main +functions are, we do not expect this to become a big maintenance burden. If +anything, it should make our lives easier, as some parts of NumPy will soon be +simplified, saving us the pain of having to implement all the pre-existing +corner-cases. + +For reference NumPy 2.0 is expected to land at the end of this year. ## The `torch_np` module The bulk of the work went into implementing a system that allows us to implement NumPy operations in terms of those of PyTorch. The main design goals -were +here were 1. Implement *most* of NumPy's API 2. Preserve NumPy semantics as much as possible @@ -180,30 +183,26 @@ were We say *most* of NumPy's API, because NumPy's API is not only massive, but also there are parts of it which cannot be implemented in PyTorch. For example, NumPy has support for arrays of string, datetime, structured and other dtypes. -Negative strides are other example of a feature that is just out of the scope. +Negative strides are other example of a feature that is just not supported in PyTorch. We put together a list of things that are out of the scope of this project in the [following issue](https://github.com/Quansight-Labs/numpy_pytorch_interop/issues/73). -For the bulk of the functions, we started by prioritizing most common -operations. Then, when bringing tests from the NumPy test suit and running -them, we would triage and prioritize how important was to fix each failure we -found. Iterating this process, we ended up with a small list of differences -between the NumPy and the PyTorch API which we sorted out by hand and finished -implementing. That list and the prioritization discussion can be found in -[the first few posts of this issue](https://github.com/Quansight-Labs/numpy_pytorch_interop/issues/87). - -The second point of preserving NumPy semantics as much as possible will be used -in the sequel to discuss some points like the default dtypes that are used -throughout the implementation. +For the bulk of the functions, we started by prioritizing the most common +operations. Then, when bringing tests from the NumPy test suit, we would triage +and prioritize how important was to fix each failure we found. Iterating this +process, we ended up with a small list of differences between the NumPy and the +PyTorch API which we prioritized by hand. That list and the prioritization +discussion can be found in [this issue](https://github.com/Quansight-Labs/numpy_pytorch_interop/issues/87). **Visibility of the module** For simplicity, this RFC assumes that the `torch_np` module will not be public, as the decision for it to be made public -was met with different opinions. We discuss these in the "Unresolved Questions" -section. +was met with different opinions. +We discuss these in the section [unresolved questions](#unresolved-questions). ### Annotation-based preprocessing -NumPy accepts virtually anything that smells like an array as input to its operators +NumPy accepts virtually anything that smells like an array as an input + ```python >>> np.add(1, 3) 4 @@ -213,8 +212,8 @@ array([6., 7., 8.]) array([1, 2, 3, 4, 5, 6]) ``` -To implement NumPy in terms of PyTorch, for any operation we would need to put -the inputs into tensors, perform the operations, and then wrap the tensor into +To implement NumPy in terms of PyTorch, for any operation we would need to map +inputs into tensors, perform the operations, and then wrap the tensor into a `torch_np.ndarray` (more on this class later). To avoid all this code repetition, we implement the functions in two steps. @@ -223,31 +222,33 @@ First, we implement functions with the NumPy signature, but assuming that in place of NumPy-land elements (`np.array`, array-like functions, `np.dtype`s, etc) they simply accept `torch.Tensor` and PyTorch-land objects and return `torch.Tensor`s. For example, we would implement `np.diag` as + ```python def diag(v, k=0): return torch.diag(v, k) ``` + In this layer, if a NumPy function is composite (calls other NumPy functions internally), we can simply vendor its implementation, and have it call our PyTorch-land implementations of these functions. In other words, at this level, -functions are composable, as any set of functions implemented purely in -PyTorch. All these implementations are internal, and are not meant to be seen -or used by the final user. +functions are composable, as they are simply regular PyTorch functions. +All these implementations are internal, and are not meant to be seen or used +by the final user. The second step is then done via type annotations and a decorator. Each type -annotation has then a map NumPy-land -> PyTorch-land associated, that maps the -set of inputs accepted by NumPy for that argument into a PyTorch-land object -(think a `torch.Tensor` or a PyTorch dtype). For example, for `np.diag` we -would write +annotation has an associated function from NumPy-land into PyTorch-land. This +function converts the set of inputs accepted by NumPy for that argument into a +PyTorch-land object (think a `torch.Tensor` or a PyTorch dtype). For example, +for `np.diag` we would write + ```python def diag(v: ArrayLike, k=0): return torch.diag(v, k) ``` -Then, we would wrap these Python-land functions in a `normalizer` decorator and -expose them in the public `torch.np` module. This decorator is in charge of -gathering all the inputs at runtime and normalizing them according to their -annotations. +Then, we wrap these Python-land functions with a `normalizer` decorator and +expose them in the `torch_np` module. This decorator is in charge of gathering +all the inputs at runtime and normalizing them according to their annotations. We currently have four annotations (and small variations of them): - `ArrayLike`: The input can be a `torch_np.array`, a list of lists, a @@ -258,9 +259,9 @@ We currently have four annotations (and small variations of them): - `OutArray`: Asserts that the input is a `torch_np.ndarray`. This is used to implement the `out` arg. -Note that none of the code here makes use of NumPy. We are writing -`torch_np.ndarray` above to make more explicit our intents, but there -shouldn't be any ambiguity here. +Note that none of the code in this implementation makes use of NumPy. We are +writing `torch_np.ndarray` above to make more explicit our intents, but there +shouldn't be any ambiguity. **OBS(Lezcano)**: `DTypeLike` should be `Optional[DTypeLike]` **OBS(Lezcano)**: Should we have a `NotImplementedType` to mark the args that @@ -271,30 +272,28 @@ implementation, or mark explicitly those that we don't use. **Implmenting out**: In PyTorch, the `out` kwarg is, as the name says, a keyword-only argument. It is for this reason that, in PrimTorch, we were able -to implement it as -[a decorator](https://github.com/pytorch/pytorch/blob/ce4df4cc596aa10534ac6d54912f960238264dfd/torch/_prims_common/wrappers.py#L187-L282). +to implement it as [a decorator](https://github.com/pytorch/pytorch/blob/ce4df4cc596aa10534ac6d54912f960238264dfd/torch/_prims_common/wrappers.py#L187-L282). This is not the case in NumPy. In NumPy `out` is a positional arg that is often interleaved with other parameters. This is the reason why we use the `OutArray` -label to mark these. We then implement the `out` semantics in the `@normalizer` +annotation to mark these. We then implement the `out` semantics in the `@normalizer` wrapper in a generic way. **Ufuncs and reductions**: Ufuncs (unary and binary) and reductions are two sets of functions that are particularly regular. For these functions, we -implement (some of) their args in a generic way. We then simply forward the -computations to PyTorch, perhaps working around some PyTorch limitations. - -### The `ndarray` class +implement their args in a generic way as a preprocessing or postprocessing. -Once we have all the free functions implemented, implementing an `ndarray` -class is rather simple. We simply register all the free functions as methods or -dunder methods appropriately. We also forward the properties to the properties -within the PyTorch tensor and we are done. +**The ndarray class** Once we have all the free functions implemented as +functions form `torch_np.ndarray`s to `torch_np.ndarray`s, implementing the +methods from the `ndarray` class is rather simple. We simply register all the +free functions as methods or dunder methods appropriately. We also forward the +properties to the properties within the PyTorch tensor and we are done. +This creates a circular dependency which we break with a local import. ### Testing The testing of the framework was done via ~~copying~~ vendoring tests from the -NumPy test suit. Then, we would replace the NumPy imports for imports with -`torch_np`. The failures on these tests were then triaged and discussed the +NumPy test suit. Then, we would replace the NumPy imports with `torch_np` +imports. The failures on these tests were then triaged and discussed the priority of fixing each of them. In the (near) future, we plan to get some real world examples and run them @@ -305,7 +304,7 @@ through the library, to test its coverage and correctness. A number of known limitations are tracked in the second part of the [OP of this issue](https://github.com/Quansight-Labs/numpy_pytorch_interop/issues/73). There are some more in [this issue](https://github.com/Quansight-Labs/numpy_pytorch_interop/issues/86). -When landing all this, we will create a comprehensive document with the differences +When landing this RFC, we will create a comprehensive document with the differences between NumPy and `torch_np`. ### Beyond Plain NumPy @@ -332,7 +331,8 @@ The bindings for NumPy at the TorchDynamo level are currently being developed at ## Unresolved Questions -A question was left open in the initial discussion. Should the module `torch_np` be publicly exposed as `torch.numpy` or not? +A question was left open in the initial discussion. Should the module +`torch_np` be publicly exposed as `torch.numpy` or not? A few arguments in favor of making it public: * People could use it in their NumPy programs just by changing the import to From e3c492bc1549a473607edc2dfbefb6e4175d73d5 Mon Sep 17 00:00:00 2001 From: lezcano Date: Mon, 17 Apr 2023 11:47:01 +0000 Subject: [PATCH 05/12] Address Evgeni's review --- RFC.md | 41 ++++++++++++++++++++++++++--------------- 1 file changed, 26 insertions(+), 15 deletions(-) diff --git a/RFC.md b/RFC.md index dd761c66..1367fb04 100644 --- a/RFC.md +++ b/RFC.md @@ -73,10 +73,11 @@ t_results = torch.empty(5, dtype=torch.float64) t_results[0] = result # store the result in a torch.Tensor ``` -Note that this code mixing NumPy and PyTorch already works, as `torch.Tensor` -implements the `__array__` method. Now, the compatibility layer allows us to -trace through it. In order to do that, there would be no necessary changes, -other than simply ask `torch.compile` to trace through it: +Note that this code mixing NumPy and PyTorch already works in eager mode with +CPU tensors, as `torch.Tensor` implements the `__array__` method. Now, the +compatibility layer allows us to trace through it. In order to do that, there +would be no necessary changes, other than simply ask `torch.compile` to trace +through it: ```python @compile @@ -89,7 +90,7 @@ def fn(x, y): The two main ideas driving the design of this compatibility layer are the following: 1. The behavior of the layer should be as close to that of NumPy as possible -2. The layer follows NumPy master +2. The layer follows the most recent NumPy release The following design decisions follow from these: @@ -129,8 +130,8 @@ NumPy scalars as 0-D arrays. This may cause small divergences in some cases like array([2, 4, 6]) ``` -but we don't expect these to pose a big issue in practice. Note that in this -implementation `torch_np.int32(2)` would return the same as `torch_np.asarray(2)`. +but we don't expect these to pose a big issue in practice. Note that in the +proposed implementation `np.int32(2)` would return the same as `np.asarray(2)`. **Type promotion**. Another not-so-well-known fact of NumPy's cast system is that it is data-dependent. Python scalars can be used in pretty much any NumPy @@ -161,15 +162,23 @@ np.result_type(np.int8, int64_0d_array) == np.int8 **Versioning**. It should be clear from the previous points that NumPy has a fair amount of questionable and legacy pain points. It is for this reason that we decided that rather than fighting these, we would declare that the compat -layer follows the behavior of Numpy's master (even, in some cases, of NumPy -2.0). Given the stability of NumPy's API and how battle-tested its main -functions are, we do not expect this to become a big maintenance burden. If -anything, it should make our lives easier, as some parts of NumPy will soon be -simplified, saving us the pain of having to implement all the pre-existing +layer follows the behavior of Numpy's most recent release (even, in some cases, +of NumPy 2.0). Given the stability of NumPy's API and how battle-tested its +main functions are, we do not expect this to become a big maintenance burden. +If anything, it should make our lives easier, as some parts of NumPy will soon +be simplified, saving us the pain of having to implement all the pre-existing corner-cases. For reference NumPy 2.0 is expected to land at the end of this year. +**Randomness**. PyTorch and NumPy use different random number generation methods. +In particular, NumPy recently moved to a [new API](https://numpy.org/doc/stable/reference/random/index.html) +with a `Generator` object which has sampling methods on it. The current compat. +layer does not implement this new API, as the default bit generator in NumPy is a +`PCG64`, while on PyTorch we use a `MT19937` on CPU and a `Philox`. From this, it +follows that this API will not give any reproducibility guarantees when it comes +to randomness. + ## The `torch_np` module @@ -188,7 +197,7 @@ We put together a list of things that are out of the scope of this project in th [following issue](https://github.com/Quansight-Labs/numpy_pytorch_interop/issues/73). For the bulk of the functions, we started by prioritizing the most common -operations. Then, when bringing tests from the NumPy test suit, we would triage +operations. Then, when bringing tests from the NumPy test suite, we would triage and prioritize how important was to fix each failure we found. Iterating this process, we ended up with a small list of differences between the NumPy and the PyTorch API which we prioritized by hand. That list and the prioritization @@ -201,7 +210,7 @@ We discuss these in the section [unresolved questions](#unresolved-questions). ### Annotation-based preprocessing -NumPy accepts virtually anything that smells like an array as an input +NumPy accepts virtually anything that smells like an array as an input. ```python >>> np.add(1, 3) @@ -212,6 +221,7 @@ array([6., 7., 8.]) array([1, 2, 3, 4, 5, 6]) ``` +NumPy calls all these objects `array_like` objects. To implement NumPy in terms of PyTorch, for any operation we would need to map inputs into tensors, perform the operations, and then wrap the tensor into a `torch_np.ndarray` (more on this class later). @@ -248,7 +258,8 @@ def diag(v: ArrayLike, k=0): Then, we wrap these Python-land functions with a `normalizer` decorator and expose them in the `torch_np` module. This decorator is in charge of gathering -all the inputs at runtime and normalizing them according to their annotations. +all the inputs at runtime and normalizing them (i.e., converting `torch_np` +objects to PyTorch counterparts) according to their annotations. We currently have four annotations (and small variations of them): - `ArrayLike`: The input can be a `torch_np.array`, a list of lists, a From ac07b1fcaeda124d317b7766a6b9617df5838d4b Mon Sep 17 00:00:00 2001 From: lezcano Date: Mon, 17 Apr 2023 11:47:53 +0000 Subject: [PATCH 06/12] Remove note on NotImplemented annotation --- RFC.md | 5 ----- 1 file changed, 5 deletions(-) diff --git a/RFC.md b/RFC.md index 1367fb04..6ba3b1f3 100644 --- a/RFC.md +++ b/RFC.md @@ -275,11 +275,6 @@ writing `torch_np.ndarray` above to make more explicit our intents, but there shouldn't be any ambiguity. **OBS(Lezcano)**: `DTypeLike` should be `Optional[DTypeLike]` -**OBS(Lezcano)**: Should we have a `NotImplementedType` to mark the args that -are not being implemented? We could then assert that either that parameter has -not been provided, and if it has, it has the same value as the default. The -goal here would be to either use all the args of a function in its -implementation, or mark explicitly those that we don't use. **Implmenting out**: In PyTorch, the `out` kwarg is, as the name says, a keyword-only argument. It is for this reason that, in PrimTorch, we were able From 575b4a756319a16cd95ccf7c707404140bf8add8 Mon Sep 17 00:00:00 2001 From: lezcano Date: Mon, 17 Apr 2023 11:52:37 +0000 Subject: [PATCH 07/12] Remove note on default dtypes --- RFC.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/RFC.md b/RFC.md index 6ba3b1f3..92e6c4b6 100644 --- a/RFC.md +++ b/RFC.md @@ -107,8 +107,6 @@ that utility flag, similar to [`torch.set_default_dtype`](https://pytorch.org/docs/stable/generated/torch.set_default_dtype.html). Perhaps call it `torch_np.use_torch_defaults()` and then add a way for users to be able to set their own int/float/complex defaults. -**TODO(Lezcano)**: Do we just use them just in factory functions, or do we also -use them anywhere else -> Check **NumPy scalars**. NumPy's type system is tricky. At first sight, it looks like PyTorch's, but with few more dtypes like `np.uint16` or `np.longdouble`. From 2d86a3ff9b75b90ddf324c59c60409a9e5db9339 Mon Sep 17 00:00:00 2001 From: lezcano Date: Thu, 20 Apr 2023 14:31:36 +0000 Subject: [PATCH 08/12] Exceptions added --- RFC.md | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/RFC.md b/RFC.md index 92e6c4b6..bed3ce5e 100644 --- a/RFC.md +++ b/RFC.md @@ -87,13 +87,25 @@ def fn(x, y): ### Design decisions -The two main ideas driving the design of this compatibility layer are the following: +The main ideas driving the design of this compatibility layer are the following: -1. The behavior of the layer should be as close to that of NumPy as possible -2. The layer follows the most recent NumPy release +1. The goal is to transform valid NumPy programs into their equivalent PyTorch +2. The behavior of the layer should be as close to that of NumPy as possible +3. The layer follows the most recent NumPy release The following design decisions follow from these: +**A superset of NumPy**. Same as PyTorch has spotty support for `float16` on +CPU, and less-than-good support for `complex32`, NumPy has a number of +well-known edge-cases. The decision of translating just valid NumPy programs, +often allows us to implement a superset of the functionality of NumPy with more +predictable and consistent behavior than NumPy itself. + +**Exceptions may be different**. We avoid entirely modelling the exception +system in NumPy. As seen in the implementation of PrimTorch, modelling the +error cases of a given system is terribly difficult. We avoid this altogether +and we choose not to offer any guarantee here. + **Default dtypes**. One of the most common issues that bites people when migrating their codebases from NumPy to JAX is the default dtype changing from `float64` to `float32`. So much so that this is noted as one of @@ -101,13 +113,6 @@ codebases from NumPy to JAX is the default dtype changing from `float64` to Following the spirit of making everything match NumPy by default, we choose the NumPy defaults whenever the `dtype` was not made explicit in a factory function. -**TODO(Lezcano)**: I just realized that we do not have a clean way to change -the default dtype of `torch_np` to those from PyTorch. We should implement -that utility flag, similar to -[`torch.set_default_dtype`](https://pytorch.org/docs/stable/generated/torch.set_default_dtype.html). -Perhaps call it `torch_np.use_torch_defaults()` and then add a way for users -to be able to set their own int/float/complex defaults. - **NumPy scalars**. NumPy's type system is tricky. At first sight, it looks like PyTorch's, but with few more dtypes like `np.uint16` or `np.longdouble`. Upon closer inspection, one finds that it also has @@ -130,6 +135,9 @@ array([2, 4, 6]) but we don't expect these to pose a big issue in practice. Note that in the proposed implementation `np.int32(2)` would return the same as `np.asarray(2)`. +In general, we try to avoid unnecessary graph breaks whenever we can. For +example, we may choose to return a tensor of shape `(2, *)` rather than a list +of pairs, to avoid unnecessary graph breaks. **Type promotion**. Another not-so-well-known fact of NumPy's cast system is that it is data-dependent. Python scalars can be used in pretty much any NumPy @@ -326,8 +334,6 @@ corollary of all this effort. If the original tensors fed into the function do have `requires_grad=True`, the tensors will track the gradients of the internal implementation and then the user could differentiate through the NumPy code. -**TODO(Lezcano)**. Picking up simple NumPy programs from the internet would be good for these autograd tests. - ### Bindings to TorchDyamo The bindings for NumPy at the TorchDynamo level are currently being developed at [#95849](https://github.com/pytorch/pytorch/pull/95849). From b85a999791262d91a45783ebd58c7d7594f8a28b Mon Sep 17 00:00:00 2001 From: lezcano Date: Fri, 21 Apr 2023 10:12:26 +0000 Subject: [PATCH 09/12] Remove TODOs --- RFC.md | 4 ---- 1 file changed, 4 deletions(-) diff --git a/RFC.md b/RFC.md index bed3ce5e..d43a1fac 100644 --- a/RFC.md +++ b/RFC.md @@ -280,8 +280,6 @@ Note that none of the code in this implementation makes use of NumPy. We are writing `torch_np.ndarray` above to make more explicit our intents, but there shouldn't be any ambiguity. -**OBS(Lezcano)**: `DTypeLike` should be `Optional[DTypeLike]` - **Implmenting out**: In PyTorch, the `out` kwarg is, as the name says, a keyword-only argument. It is for this reason that, in PrimTorch, we were able to implement it as [a decorator](https://github.com/pytorch/pytorch/blob/ce4df4cc596aa10534ac6d54912f960238264dfd/torch/_prims_common/wrappers.py#L187-L282). @@ -326,8 +324,6 @@ CPU. We expect GPU coverage to be as good as the coverage we have with CPU matching GPU. If the original tensors are on GPU, the whole execution should be performed on the GPU. -**TODO(Lezcano)**. We should probably test CUDA on the tests. - **Gradients**. We have not tested gradient tracking either as we are still to find some good examples on which to test it, but it should be a simple corollary of all this effort. If the original tensors fed into the function do From 44ee780ac1c86a4498568ad55d37bba5e5eadc4c Mon Sep 17 00:00:00 2001 From: lezcano Date: Fri, 28 Apr 2023 10:15:16 +0000 Subject: [PATCH 10/12] Address review comments --- RFC.md | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/RFC.md b/RFC.md index d43a1fac..27a11499 100644 --- a/RFC.md +++ b/RFC.md @@ -85,6 +85,10 @@ def fn(x, y): return np.multiply(x, y).sum() ``` +Then, TorchDynamo would will cast `x` and `y` to our internal implementation of `ndarray`, +and will dispatch `np.multiply` and `sum` to our implementations in terms of `torch` +functions effectively turning this function into a pure PyTorch function. + ### Design decisions The main ideas driving the design of this compatibility layer are the following: @@ -112,6 +116,8 @@ codebases from NumPy to JAX is the default dtype changing from `float64` to [JAX's shap edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision). Following the spirit of making everything match NumPy by default, we choose the NumPy defaults whenever the `dtype` was not made explicit in a factory function. +We also provide a function `set_default_dtype` that allows to change this behavior +dynamically. **NumPy scalars**. NumPy's type system is tricky. At first sight, it looks like PyTorch's, but with few more dtypes like `np.uint16` or `np.longdouble`. @@ -123,7 +129,8 @@ return just one element. NumPy scalars do not play particularly well with computations on devices like GPUs, as they live on CPU. Implementing NumPy scalars would mean that we need to synchronize after every `sum()` call, which would be terrible performance-wise. In this implementation, we choose to represent -NumPy scalars as 0-D arrays. This may cause small divergences in some cases like +NumPy scalars as 0-D arrays. This may cause small divergences in some cases. For example, +consider the following NumPy behavior: ```python >>> np.int32(2) * [1, 2, 3] # scalar decays to a python int @@ -133,7 +140,7 @@ NumPy scalars as 0-D arrays. This may cause small divergences in some cases like array([2, 4, 6]) ``` -but we don't expect these to pose a big issue in practice. Note that in the +We don't expect these to pose a big issue in practice. Note that in the proposed implementation `np.int32(2)` would return the same as `np.asarray(2)`. In general, we try to avoid unnecessary graph breaks whenever we can. For example, we may choose to return a tensor of shape `(2, *)` rather than a list @@ -151,7 +158,7 @@ array([128], dtype=int8) >>> np.asarray([1], dtype=np.int8) + 128 array([129], dtype=int16) ``` -This dependent type promotion will be deprecated NumPy 2.0, and will be +This data-dependent type promotion will be deprecated NumPy 2.0, and will be replaced with [NEP 50](https://numpy.org/neps/nep-0050-scalar-promotion.html). For simplicity and to be forward-looking, we chose to implement the type promotion behaviour proposed in NEP 50, which is much closer to that of @@ -270,7 +277,8 @@ objects to PyTorch counterparts) according to their annotations. We currently have four annotations (and small variations of them): - `ArrayLike`: The input can be a `torch_np.array`, a list of lists, a scalar, or anything that NumPy would accept. It returns a `torch.Tensor`. -- `DTypeLike`: Takes a `torch_np` dtype and returns a PyTorch dtype. +- `DTypeLike`: Takes a `torch_np` dtype, and any other object that Numpy dtypes + accept (strings, typecodes...) and returns a PyTorch dtype. - `AxisLike`: Takes anything that can be accepted as an axis (e.g. a tuple or an `ndarray`) and returns a tuple. - `OutArray`: Asserts that the input is a `torch_np.ndarray`. This is used @@ -302,18 +310,19 @@ This creates a circular dependency which we break with a local import. ### Testing The testing of the framework was done via ~~copying~~ vendoring tests from the -NumPy test suit. Then, we would replace the NumPy imports with `torch_np` +NumPy test suite. Then, we would replace the NumPy imports with `torch_np` imports. The failures on these tests were then triaged and discussed the priority of fixing each of them. -In the (near) future, we plan to get some real world examples and run them -through the library, to test its coverage and correctness. +In the end, to have a last check that this tool was sound, we pulled five +examples of NumPy code from different sources and we run it with this library. +We were able to successfully the five examples successfully with close to no code changes. +You can read about these in the [README](https://github.com/Quansight-Labs/numpy_pytorch_interop). ### Limitations A number of known limitations are tracked in the second part of the [OP of this issue](https://github.com/Quansight-Labs/numpy_pytorch_interop/issues/73). -There are some more in [this issue](https://github.com/Quansight-Labs/numpy_pytorch_interop/issues/86). When landing this RFC, we will create a comprehensive document with the differences between NumPy and `torch_np`. From 332a2d0456ce95b8fe9642bd6830c6ee28a13995 Mon Sep 17 00:00:00 2001 From: Ralf Gommers Date: Fri, 28 Apr 2023 13:22:46 +0100 Subject: [PATCH 11/12] Explicitly include the main goal in the abstract. --- RFC.md | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/RFC.md b/RFC.md index 27a11499..2b03b975 100644 --- a/RFC.md +++ b/RFC.md @@ -11,10 +11,15 @@ In simple terms, this accounts for implementing most of NumPy's API (`ndarray`, the `numpy`, `numpy.linalg`, `numpy.fft` modules, etc) using `torch.Tensor` and PyTorch ops as backend. - -The this project has a main goal as per the -[initial design document](https://docs.google.com/document/d/1gdUDgZNbumFORRcUaZUVw790CtNYweAM20C1fbWMNd8): -1. Make TorchDynamo understand NumPy calls +The main goal is: **make TorchDynamo understand NumPy calls**. +This should enable an end user to combine code that uses the PyTorch API with +code that uses the NumPy API, in a way that allows TorchDynamo to understand +those function calls and build up an execution graph. To enable this, it is key +that there is a translation layer from NumPy to PyTorch function calls, which +TorchDynamo can use in order to build up its execution graph from PyTorch +functions/primitives only. For niche functions in NumPy that don’t have a +PyTorch equivalent, it’s okay to graph break and still call NumPy to execute +the function call. The work is currently being done at [numpy_pytorch_interop](https://github.com/Quansight-Labs/numpy_pytorch_interop/). From 0a75be41d7d456a1c24aec5c33a5652dceef56de Mon Sep 17 00:00:00 2001 From: Ralf Gommers Date: Fri, 28 Apr 2023 13:56:37 +0100 Subject: [PATCH 12/12] A full copy-edit and fixes for a few minor inaccuracies --- RFC.md | 137 ++++++++++++++++++++++++++++++--------------------------- 1 file changed, 71 insertions(+), 66 deletions(-) diff --git a/RFC.md b/RFC.md index 2b03b975..2eb0cb0a 100644 --- a/RFC.md +++ b/RFC.md @@ -92,23 +92,24 @@ def fn(x, y): Then, TorchDynamo would will cast `x` and `y` to our internal implementation of `ndarray`, and will dispatch `np.multiply` and `sum` to our implementations in terms of `torch` -functions effectively turning this function into a pure PyTorch function. +functions, effectively turning this function into a pure PyTorch function. ### Design decisions The main ideas driving the design of this compatibility layer are the following: -1. The goal is to transform valid NumPy programs into their equivalent PyTorch +1. The goal is to transform valid NumPy and mixed PyTorch-NumPy programs into + their equivalent PyTorch-only execution. 2. The behavior of the layer should be as close to that of NumPy as possible 3. The layer follows the most recent NumPy release The following design decisions follow from these: -**A superset of NumPy**. Same as PyTorch has spotty support for `float16` on -CPU, and less-than-good support for `complex32`, NumPy has a number of -well-known edge-cases. The decision of translating just valid NumPy programs, -often allows us to implement a superset of the functionality of NumPy with more -predictable and consistent behavior than NumPy itself. +**A superset of NumPy**. NumPy has a number of well-known edge-cases (as does +PyTorch, like spotty support for `float16` on CPU and `complex32` in general). +The decision to translate only valid NumPy programs, often allows us to +implement a superset of the functionality of NumPy with more predictable and +consistent behavior than NumPy itself has. **Exceptions may be different**. We avoid entirely modelling the exception system in NumPy. As seen in the implementation of PrimTorch, modelling the @@ -118,9 +119,9 @@ and we choose not to offer any guarantee here. **Default dtypes**. One of the most common issues that bites people when migrating their codebases from NumPy to JAX is the default dtype changing from `float64` to `float32`. So much so that this is noted as one of -[JAX's shap edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision). +[JAX's sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision). Following the spirit of making everything match NumPy by default, we choose the -NumPy defaults whenever the `dtype` was not made explicit in a factory function. +NumPy default dtype whenever the `dtype` was not made explicit in a factory function. We also provide a function `set_default_dtype` that allows to change this behavior dynamically. @@ -128,9 +129,10 @@ dynamically. like PyTorch's, but with few more dtypes like `np.uint16` or `np.longdouble`. Upon closer inspection, one finds that it also has [NumPy scalar](https://numpy.org/doc/stable/reference/arrays.scalars.html) objects. -NumPy scalars are similar to Python scalars but with a set width. NumPy scalars -are NumPy's preferred return class for reductions and other operations that -return just one element. NumPy scalars do not play particularly well with +NumPy scalars are similar to Python scalars but with a fixed precision and +array-like methods attached. NumPy scalars are NumPy's preferred return class +for reductions and other operations that return just one element. +NumPy scalars do not play particularly well with computations on devices like GPUs, as they live on CPU. Implementing NumPy scalars would mean that we need to synchronize after every `sum()` call, which would be terrible performance-wise. In this implementation, we choose to represent @@ -149,13 +151,13 @@ We don't expect these to pose a big issue in practice. Note that in the proposed implementation `np.int32(2)` would return the same as `np.asarray(2)`. In general, we try to avoid unnecessary graph breaks whenever we can. For example, we may choose to return a tensor of shape `(2, *)` rather than a list -of pairs, to avoid unnecessary graph breaks. +of pairs, to avoid a graph break. -**Type promotion**. Another not-so-well-known fact of NumPy's cast system is -that it is data-dependent. Python scalars can be used in pretty much any NumPy +**Type promotion**. Another not-so-well-known fact of NumPy's dtype system and casting rules +is that it is data-dependent. Python scalars can be used in pretty much any NumPy operation, being able to call any operation that accepts a 0-D array with a Python scalar. If you provide an operation with a Python scalar, these will be -casted to the smallest dtype they can be represented in, and then, they will +cast to the smallest dtype they can be represented in, and only then will they participate in type promotion. This allows for for some rather interesting behaviour ```python >>> np.asarray([1], dtype=np.int8) + 127 @@ -163,11 +165,12 @@ array([128], dtype=int8) >>> np.asarray([1], dtype=np.int8) + 128 array([129], dtype=int16) ``` -This data-dependent type promotion will be deprecated NumPy 2.0, and will be -replaced with [NEP 50](https://numpy.org/neps/nep-0050-scalar-promotion.html). +This data-dependent type promotion will be removed in NumPy 2.0 (planned for Dec'23), and will be +replaced with [NEP 50](https://numpy.org/neps/nep-0050-scalar-promotion.html) +(already implemented in NumPy, it needs to be enabled via a private global switch now). For simplicity and to be forward-looking, we chose to implement the type promotion behaviour proposed in NEP 50, which is much closer to that of -Pytorch. +PyTorch. Note that the decision of going with NEP 50 complements the previous one of returning 0-D arrays in place of NumPy scalars as, currently, 0-D arrays do not @@ -178,24 +181,22 @@ np.result_type(np.int8, int64_0d_array) == np.int8 ``` **Versioning**. It should be clear from the previous points that NumPy has a -fair amount of questionable and legacy pain points. It is for this reason that +fair amount of questionable behavior and legacy pain points. It is for this reason that we decided that rather than fighting these, we would declare that the compat -layer follows the behavior of Numpy's most recent release (even, in some cases, +layer follows the behavior of NumPy's most recent release (even, in some cases, of NumPy 2.0). Given the stability of NumPy's API and how battle-tested its main functions are, we do not expect this to become a big maintenance burden. If anything, it should make our lives easier, as some parts of NumPy will soon be simplified, saving us the pain of having to implement all the pre-existing corner-cases. -For reference NumPy 2.0 is expected to land at the end of this year. - **Randomness**. PyTorch and NumPy use different random number generation methods. In particular, NumPy recently moved to a [new API](https://numpy.org/doc/stable/reference/random/index.html) -with a `Generator` object which has sampling methods on it. The current compat. -layer does not implement this new API, as the default bit generator in NumPy is a -`PCG64`, while on PyTorch we use a `MT19937` on CPU and a `Philox`. From this, it -follows that this API will not give any reproducibility guarantees when it comes -to randomness. +with a `Generator` object which has sampling methods on it. The current compat +layer does not implement this new API, as the default bit generator in NumPy is +`PCG64`, while on PyTorch we use `MT19937` on CPU and `Philox` on non-CPU devices. +From this, it follows that this API will not give any reproducibility +guarantees when it comes to randomness. ## The `torch_np` module @@ -210,20 +211,21 @@ here were We say *most* of NumPy's API, because NumPy's API is not only massive, but also there are parts of it which cannot be implemented in PyTorch. For example, NumPy has support for arrays of string, datetime, structured and other dtypes. -Negative strides are other example of a feature that is just not supported in PyTorch. +Negative strides are another example of a feature that is not supported in PyTorch. We put together a list of things that are out of the scope of this project in the [following issue](https://github.com/Quansight-Labs/numpy_pytorch_interop/issues/73). For the bulk of the functions, we started by prioritizing the most common -operations. Then, when bringing tests from the NumPy test suite, we would triage -and prioritize how important was to fix each failure we found. Iterating this -process, we ended up with a small list of differences between the NumPy and the -PyTorch API which we prioritized by hand. That list and the prioritization +operations. Then, when bringing tests from the NumPy test suite, we triaged +and prioritized how important it was to fix each failure we found. Doing this +iteratively, we ended up with a small list of differences between the NumPy and +PyTorch APIs, which we prioritized by hand. That list and the prioritization discussion can be found in [this issue](https://github.com/Quansight-Labs/numpy_pytorch_interop/issues/87). **Visibility of the module** For simplicity, this RFC assumes that the -`torch_np` module will not be public, as the decision for it to be made public -was met with different opinions. +`torch_np` module will not be public, as the initial suggestion for it to be +made public was met with mixed opinions. This topic can be revisited in the +future if desired. We discuss these in the section [unresolved questions](#unresolved-questions). ### Annotation-based preprocessing @@ -261,7 +263,7 @@ internally), we can simply vendor its implementation, and have it call our PyTorch-land implementations of these functions. In other words, at this level, functions are composable, as they are simply regular PyTorch functions. All these implementations are internal, and are not meant to be seen or used -by the final user. +by the end user. The second step is then done via type annotations and a decorator. Each type annotation has an associated function from NumPy-land into PyTorch-land. This @@ -287,41 +289,43 @@ We currently have four annotations (and small variations of them): - `AxisLike`: Takes anything that can be accepted as an axis (e.g. a tuple or an `ndarray`) and returns a tuple. - `OutArray`: Asserts that the input is a `torch_np.ndarray`. This is used - to implement the `out` arg. + to implement the `out` keyword. Note that none of the code in this implementation makes use of NumPy. We are -writing `torch_np.ndarray` above to make more explicit our intents, but there +writing `torch_np.ndarray` above to make more explicit our intent, but there shouldn't be any ambiguity. -**Implmenting out**: In PyTorch, the `out` kwarg is, as the name says, a -keyword-only argument. It is for this reason that, in PrimTorch, we were able -to implement it as [a decorator](https://github.com/pytorch/pytorch/blob/ce4df4cc596aa10534ac6d54912f960238264dfd/torch/_prims_common/wrappers.py#L187-L282). -This is not the case in NumPy. In NumPy `out` is a positional arg that is often -interleaved with other parameters. This is the reason why we use the `OutArray` -annotation to mark these. We then implement the `out` semantics in the `@normalizer` -wrapper in a generic way. +**Implementing `out`**: In PyTorch, the `out` kwarg is a keyword-only argument. +It is for this reason that, in PrimTorch, we were able to implement it as [a +decorator](https://github.com/pytorch/pytorch/blob/ce4df4cc596aa10534ac6d54912f960238264dfd/torch/_prims_common/wrappers.py#L187-L282). +This is not the case in NumPy. In NumPy, `out` can be used both as a positional +and a keyword argument, and is often interleaved with other parameters. This is +the reason why we use the `OutArray` annotation to mark these. We then +implement the `out` semantics in the `@normalizer` wrapper in a generic way. **Ufuncs and reductions**: Ufuncs (unary and binary) and reductions are two sets of functions that are particularly regular. For these functions, we -implement their args in a generic way as a preprocessing or postprocessing. +implement support for their arguments in a generic way as a preprocessing or +postprocessing step. -**The ndarray class** Once we have all the free functions implemented as -functions form `torch_np.ndarray`s to `torch_np.ndarray`s, implementing the +**The `ndarray` class** Once we have all the free functions implemented as +functions from `torch_np.ndarray`s to `torch_np.ndarray`s, implementing the methods from the `ndarray` class is rather simple. We simply register all the free functions as methods or dunder methods appropriately. We also forward the -properties to the properties within the PyTorch tensor and we are done. -This creates a circular dependency which we break with a local import. +properties of `ndarray to the corresponding properties of `torch.Tensor` and we +are done. This creates a circular dependency which we break with a local +import. ### Testing The testing of the framework was done via ~~copying~~ vendoring tests from the NumPy test suite. Then, we would replace the NumPy imports with `torch_np` -imports. The failures on these tests were then triaged and discussed the -priority of fixing each of them. +imports. The failures on these tests were then triaged, and either fixed or marked +`xfail` depending on our assessment of the priority of implementing a fix. In the end, to have a last check that this tool was sound, we pulled five -examples of NumPy code from different sources and we run it with this library. -We were able to successfully the five examples successfully with close to no code changes. +examples of NumPy code from different sources and ran it with this library (eager mode execution). +We were able to run the five examples successfully with close to no code changes. You can read about these in the [README](https://github.com/Quansight-Labs/numpy_pytorch_interop). ### Limitations @@ -331,25 +335,26 @@ A number of known limitations are tracked in the second part of the When landing this RFC, we will create a comprehensive document with the differences between NumPy and `torch_np`. -### Beyond Plain NumPy +### Beyond plain NumPy -**GPU**. The current implementation has just been implemented and tested on -CPU. We expect GPU coverage to be as good as the coverage we have with CPU -matching GPU. If the original tensors are on GPU, the whole execution should -be performed on the GPU. +**GPU**. The current implementation so far only been implemented and tested on +CPU. We expect GPU coverage to be as good as the coverage we have with CPU-GPU +matching tests in the PyTorch test suite. If the original tensors are on GPU, +the execution should be performed fully on GPU. **Gradients**. We have not tested gradient tracking either as we are still to find some good examples on which to test it, but it should be a simple -corollary of all this effort. If the original tensors fed into the function do +corollary of all this effort. If the original tensors fed into a function have `requires_grad=True`, the tensors will track the gradients of the internal -implementation and then the user could differentiate through the NumPy code. +implementation and then the user can differentiate through their NumPy code. -### Bindings to TorchDyamo +### Bindings to TorchDynamo -The bindings for NumPy at the TorchDynamo level are currently being developed at [#95849](https://github.com/pytorch/pytorch/pull/95849). +The bindings for NumPy at the TorchDynamo level are currently being developed in +[pytorch#95849](https://github.com/pytorch/pytorch/pull/95849). -## Unresolved Questions +## Unresolved questions A question was left open in the initial discussion. Should the module `torch_np` be publicly exposed as `torch.numpy` or not? @@ -369,7 +374,7 @@ A few arguments in favor of making it public: A few arguments against: * The compat introduces a number of type conversions that may produce somewhat slow code when used in eager mode. - * [Note] Keeping this in mind, we tried to use in the implementations as few - operators as possible, to make it reasonably fast in eager mode. + * [Note] Keeping this in mind, we tried to use as few operators as possible, + in the implementation, to make it reasonably fast in eager mode. * Exposing `torch.numpy` would create a less performant secondary entry point to many of the functions in PyTorch. This could be a trap for new users.