@@ -87,27 +87,32 @@ def fn(x, y):
### Design decisions

- The two main ideas driving the design of this compatibility layer are the following:
+ The main ideas driving the design of this compatibility layer are the following:

- 1. The behavior of the layer should be as close to that of NumPy as possible
- 2. The layer follows the most recent NumPy release
+ 1. The goal is to transform valid NumPy programs into their equivalent PyTorch programs
+ 2. The behavior of the layer should be as close to that of NumPy as possible
+ 3. The layer follows the most recent NumPy release
The following design decisions follow from these:
+ **A superset of NumPy**. Just as PyTorch has spotty support for `float16` on
+ CPU and less-than-ideal support for `complex32`, NumPy has a number of
+ well-known edge cases. The decision to translate just valid NumPy programs
+ often allows us to implement a superset of the functionality of NumPy with more
+ predictable and consistent behavior than NumPy itself.
+
+ **Exceptions may be different**. We make no attempt to model NumPy's exception
+ system. As seen in the implementation of PrimTorch, modelling the error cases
+ of a given system is terribly difficult. We avoid this altogether and choose
+ not to offer any guarantees here.
+
**Default dtypes**. One of the most common issues that bites people when migrating their
codebases from NumPy to JAX is the default dtype changing from `float64` to
`float32`. So much so that this is noted as one of
[JAX's sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision).
Following the spirit of making everything match NumPy by default, we choose the
NumPy defaults whenever the `dtype` was not made explicit in a factory function.

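+ To make the gap concrete, here is a minimal sketch contrasting the two sets of
+ defaults. The first part is plain NumPy and PyTorch and can be run as-is; the
+ closing comment about the layer restates the rule above rather than a tested API:
+
+ ```python
+ import numpy as np
+ import torch
+
+ # NumPy factory functions default to float64 when no dtype is passed...
+ print(np.ones(3).dtype)      # float64
+ print(np.arange(3.0).dtype)  # float64
+
+ # ...whereas PyTorch factory functions default to float32.
+ print(torch.ones(3).dtype)   # torch.float32
+
+ # Under the compatibility layer, a factory call without an explicit `dtype`
+ # (e.g. a hypothetical `torch_np.ones(3)`) follows the NumPy default and is
+ # backed by a float64 tensor.
+ ```
+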
- **TODO(Lezcano)**: I just realized that we do not have a clean way to change
- the default dtype of `torch_np` to those from PyTorch. We should implement
- that utility flag, similar to
- [`torch.set_default_dtype`](https://pytorch.org/docs/stable/generated/torch.set_default_dtype.html).
- Perhaps call it `torch_np.use_torch_defaults()` and then add a way for users
- to be able to set their own int/float/complex defaults.
-
**NumPy scalars**. NumPy's type system is tricky. At first sight, it looks
like PyTorch's, but with a few more dtypes like `np.uint16` or `np.longdouble`.
Upon closer inspection, one finds that it also has
@@ -130,6 +135,9 @@ array([2, 4, 6])
but we don't expect these to pose a big issue in practice. Note that in the
proposed implementation `np.int32(2)` would return the same as `np.asarray(2)`.
+ In general, we try to avoid unnecessary graph breaks whenever we can. For
+ example, we may choose to return a tensor of shape `(2, *)` rather than a list
+ of pairs.
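+
+ As an illustration of the kind of choice this refers to (a hypothetical example,
+ not a commitment to any particular function): for a routine that conceptually
+ produces index pairs, a single tensor keeps the computation traceable, whereas
+ materializing Python ints can force a graph break under TorchDynamo:
+
+ ```python
+ import torch
+
+ x = torch.tensor([[0, 3], [5, 7]])
+
+ # Pulling the indices out as a Python list of (row, col) pairs moves the
+ # values into Python, which a tracer typically cannot keep in the graph:
+ pairs = [(int(i), int(j)) for i, j in torch.nonzero(x)]  # [(0, 1), (1, 0), (1, 1)]
+
+ # Returning one tensor of shape (2, number_of_nonzeros) keeps everything
+ # as a single tensor operation:
+ idx = torch.nonzero(x).T                                 # tensor([[0, 1, 1], [1, 0, 1]])
+ ```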
**Type promotion**. Another not-so-well-known fact of NumPy's cast system is
that it is data-dependent. Python scalars can be used in pretty much any NumPy
@@ -326,8 +334,6 @@ corollary of all this effort. If the original tensors fed into the function do
have `requires_grad=True`, the tensors will track the gradients of the internal
implementation and then the user could differentiate through the NumPy code.
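+
+ A minimal sketch of what this enables. The `torch_np` import name appears
+ elsewhere in this document, but the calling convention assumed here (array
+ functions accepting `torch.Tensor` inputs, and the wrapped tensor being
+ reachable via a `.tensor` attribute) is an assumption, not settled API:
+
+ ```python
+ import torch
+ import torch_np as np  # assumed import name for the compatibility layer
+
+ def numpy_fn(x, y):
+     # Written against the NumPy-style API; internally this dispatches to
+     # PyTorch operations on the wrapped tensors.
+     return (np.sin(x) * y).sum()
+
+ x = torch.randn(3, requires_grad=True)
+ y = torch.randn(3, requires_grad=True)
+
+ # Because the inputs require grad, the PyTorch ops behind the NumPy API keep
+ # recording the autograd graph, so we can differentiate through the NumPy code.
+ out = numpy_fn(np.asarray(x), np.asarray(y))
+ out.tensor.backward()  # `.tensor`: assumed accessor for the underlying torch.Tensor
+ print(x.grad, y.grad)
+ ```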
- **TODO(Lezcano)**. Picking up simple NumPy programs from the internet would be good for these autograd tests.
-
### Bindings to TorchDynamo
The bindings for NumPy at the TorchDynamo level are currently being developed at [#95849](https://github.com/pytorch/pytorch/pull/95849).