# Compiling NumPy code into C++ or CUDA via `torch.compile`

Quansight engineers have implemented support for tracing through NumPy code via
`torch.compile` in PyTorch 2.1. This feature leverages PyTorch's compiler to
generate efficient fused vectorized code without having to modify your original
NumPy code. What is more, it also allows executing NumPy functions on CUDA
just by running them through `torch.compile` under `torch.device("cuda")`!

In this post, we go over how to use this feature and give a few tips and tricks
to make the most out of it.


## Compiling NumPy code into Parallel C++

We take as our running example one step in a K-Means algorithm.
This piece of code is borrowed from this [NumPy book](https://realpython.com/numpy-array-programming/#clustering-algorithms)

```python
import numpy as np

def get_labels(X, means):
    return np.argmin(np.linalg.norm(X - means[:, None], axis=2), axis=0)
```

We create a synthetic dataset with 10M random 2-D points. We can see that,
given that the means are chosen appropriately, the function returns the correct
cluster for all of them

```python
npts = 10_000_000
X = np.repeat([[5, 5], [10, 10]], [npts, npts], axis=0)
X = X + np.random.randn(*X.shape)  # 2 distinct "blobs"
means = np.array([[5, 5], [10, 10]])
np_pred = get_labels(X, means)
```

Benchmarking this function gives us a baseline of **1.26s** on an AMD 3970X CPU.
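
The exact benchmarking harness is not shown in the post; a minimal sketch along
these lines is one way to take such a measurement (timings will of course vary
with hardware):

```python
from timeit import timeit

# Hypothetical harness: average wall-clock time of the eager NumPy version
baseline = timeit(lambda: get_labels(X, means), number=5) / 5
print(f"NumPy (eager): {baseline:.2f} s per call")
```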

Compiling this function is now as easy as wrapping it with `torch.compile` and
executing it with the example inputs

```python
import torch

compiled_fn = torch.compile(get_labels)
torch_pred = compiled_fn(X, means)
assert np.allclose(np_pred, torch_pred)
```

The compiled function yields a 9x speed-up when running it on 1 core. Even
better, as opposed to NumPy, our generated code does take advantage of all the
cores in a processor. As such, when we run it on 32 cores, we get a **57x
speed-up**. Note that PyTorch always uses all the available cores unless
explicitly restricted, so this is the default behavior you get when using
`torch.compile`.
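
To experiment with the core count yourself, one available knob is
`torch.set_num_threads`. A sketch, under the assumption that the thread count
is set before the function is compiled so that the generated kernel picks it up:

```python
import os
import torch

torch.set_num_threads(1)               # restrict PyTorch to a single core
single_core_fn = torch.compile(get_labels)
single_core_pred = single_core_fn(X, means)

torch.set_num_threads(os.cpu_count())  # back to using every core
```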

We may inspect the generated C++ code by running the script with the
environment variable `TORCH_LOGS=output_code`. When doing so, we can see that
`torch.compile` was able to compile the broadcasting and the two reductions
into just one for-loop, and parallelize it using OpenMP

```c++
extern "C" void kernel(const double* in_ptr0, const long* in_ptr1, long* out_ptr0) {
    #pragma omp parallel num_threads(32)
    #pragma omp for
    for(long i0=0L; i0<20000000L; i0+=1L) {
        auto tmp0 = in_ptr0[2L*i0];
        auto tmp1 = in_ptr1[0L];
        auto tmp5 = in_ptr0[1L + (2L*i0)];
        auto tmp6 = in_ptr1[1L];
        // Rest of the kernel omitted for brevity
```
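
If setting an environment variable is inconvenient, the same output can also be
requested from Python through the logging API; a sketch, assuming
`torch._logging.set_logs` is available in your PyTorch version:

```python
import torch._logging

# Equivalent to running the script with TORCH_LOGS=output_code
torch._logging.set_logs(output_code=True)
```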

## Compiling NumPy code into CUDA

Compiling our code so that it runs on CUDA is as simple as setting the
default device to be CUDA

```python
with torch.device("cuda"):
    cuda_pred = compiled_fn(X, means)
assert np.allclose(np_pred, cuda_pred)
```
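
When timing the CUDA version, keep in mind that the very first call also pays
the one-off compilation cost, so it should be excluded from the measurement. A
sketch of a steady-state measurement (our own harness, not from the post; since
the inputs and outputs are NumPy arrays on CPU, each call is synchronous):

```python
from timeit import timeit

with torch.device("cuda"):
    compiled_fn(X, means)  # warm-up call: triggers compilation for CUDA
    cuda_time = timeit(lambda: compiled_fn(X, means), number=5) / 5
print(f"Compiled on CUDA: {cuda_time:.3f} s per call")
```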

By inspecting the generated code via `TORCH_LOGS=output_code`, we see that,
rather than generating CUDA code directly, `torch.compile` generates rather
readable [Triton](https://triton-lang.org/main/index.html) code

```python
def triton_(in_ptr0, in_ptr1, out_ptr0, XBLOCK : tl.constexpr):
    xnumel = 20000000
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (2*x0), xmask)
    tmp1 = tl.load(in_ptr1 + (0))
    # Rest of the kernel omitted for brevity
```

Running this small snippet on an RTX 2060 gives an **8x speed-up** over the
original NumPy code. This is something, but it is not particularly impressive,
given the speed-ups we have seen on CPU. Let's have a look at how to squeeze
the most out of our GPU via a couple of minor changes.

**`float64` vs `float32`**. Many GPUs, in particular consumer-grade ones, are
rather sluggish when running operations on `float64`. For this reason, changing
the data generation to `float32` makes the original NumPy code only about 9%
faster, but makes our CUDA code **40% faster**, yielding an **11x speed-up**
over the plain NumPy code.
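
In the running example, that change can be made at data-generation time or, as
sketched here, by casting the existing arrays (variable names are ours):

```python
X32 = X.astype(np.float32)
means32 = means.astype(np.float32)

np_pred32 = get_labels(X32, means32)         # eager NumPy baseline, now in float32
with torch.device("cuda"):
    cuda_pred32 = compiled_fn(X32, means32)  # triggers a recompilation for the new dtype
```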

`torch.compile`, by default, respects the NumPy semantics, and as such, it uses
`np.float64` as its default dtype for all its creation ops. As discussed, this
can hinder performance, so it is possible to change this default by setting

```python
from torch._dynamo import config
config.numpy_default_float = "float32"
```

**CPU <> CUDA copies**. An 11x speed-up is good, but it is not even close to
the CPU numbers. This is caused by a small transformation that `torch.compile`
does behind the scenes. The code above takes NumPy arrays and returns NumPy
arrays. All of these arrays are on CPU, but the computations are performed on
the GPU. This means that every time the function is called, `torch.compile` has
to copy all these arrays from CPU to the GPU, and then copy the result back to
CPU to preserve the original semantics. There is no native solution to this
issue in NumPy, as NumPy does not have the notion of a `device`. That being
said, we can work around it by creating a wrapper for this function so that it
accepts PyTorch tensors and returns PyTorch tensors.

```python
@torch.compile
def tensor_fn(X, means):
    X, means = X.numpy(), means.numpy()
    ret = get_labels(X, means)
    return torch.from_numpy(ret)

def cuda_fn(X, means):
    with torch.device("cuda"):
        return tensor_fn(X, means)
```
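
A usage sketch for this wrapper (variable names are ours): the inputs now need
to be CUDA tensors, and the result stays on the GPU.

```python
X_cuda = torch.from_numpy(X).cuda()
means_cuda = torch.from_numpy(means).cuda()

cuda_labels = cuda_fn(X_cuda, means_cuda)  # PyTorch tensor, still on the GPU
assert np.allclose(np_pred, cuda_labels.cpu().numpy())
```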

This function now takes tensors in CUDA memory and returns tensors in CUDA
memory, but the function itself is written in NumPy! `torch.compile` uses the
`numpy()` and `from_numpy()` calls as hints and optimizes them away, so that
internally it simply works with PyTorch tensors without moving the memory at
all. When we keep the tensors in CUDA and perform the computations in
`float32`, we see a **200x speed-up** over the initial NumPy implementation on
`float32` arrays.

**Mixing NumPy and PyTorch**. In this example, we had to write a small adaptor
to convert tensors to ndarrays and then back to tensors. In programs that mix
PyTorch and NumPy, converting a tensor into an ndarray is often implemented as
`x.detach().cpu().numpy()`, or simply `x.numpy(force=True)`. Since, when running
under `torch.compile`, we can run NumPy code in CUDA, we can implement this
conversion pattern as a call to `x.numpy()`, as we did above. Doing so and
running the resulting code under `device("cuda")` will generate efficient CUDA
code from the original NumPy calls without copying the data from CUDA to CPU at
all. Note that the resulting code does not run without `torch.compile`. For it
to run in eager mode, one would need to roll back to `x.numpy(force=True)`.
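
As an illustration of this pattern (a made-up mixed function, not from the
post), the adaptor boils down to bracketing the NumPy section with `numpy()`
and `from_numpy()`:

```python
@torch.compile
def mixed_fn(t):
    # Under torch.compile, .numpy() is just a hint: no copy to CPU happens,
    # and the NumPy calls below are traced together with the rest of the graph.
    arr = t.numpy()
    arr = np.clip(arr, 0.0, 1.0)
    return torch.from_numpy(arr) * 2

t = torch.rand(1024, device="cuda")
out = mixed_fn(t)  # runs end to end on the GPU; would fail in eager mode
```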

## Further Speed-up tricks

**General advice**. The CUDA code we have shown is already quite efficient, but
it is true that the running example is rather short. When dealing with larger
programs, we may need to tweak parts of them to make them more efficient. A
good place to start is the [`torch.compile` troubleshooting
page](https://pytorch.org/docs/stable/dynamo/troubleshooting.html#performance-profiling).
It showcases a number of ways to inspect the tracing process and how to
identify problematic code that may cause slowdowns.

**Advice when compiling NumPy code**. NumPy, even if rather similar to PyTorch,
is often used very differently. It is rather common to perform computations in
NumPy and then do an if/else depending on values within the array, or perform
operations in-place, perhaps via boolean masks. These constructions, while
supported by `torch.compile`, hamper its performance. Changes like moving from
in-place indexing to using `np.where`, writing the code in a branchless way, or
avoiding in-place ops in favor of out-of-place ops can go a long way.
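
As a small, made-up illustration of that advice, here is an in-place masked
update next to an out-of-place, branchless equivalent:

```python
# In-place boolean-mask update: supported, but harder for the compiler to fuse
def relu_inplace(x):
    x[x < 0] = 0.0
    return x

# Out-of-place and branchless: typically compiles to tighter code
def relu_branchless(x):
    return np.where(x < 0, 0.0, x)
```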

To write fast NumPy code, it is best to avoid loops, but sometimes they are
unavoidable. When tracing through a loop, `torch.compile` will try to fully
unroll it. This is sometimes desirable, but sometimes it may not even be
possible, for example when there is a dynamic stopping condition, as in a while
loop. In these cases, it may be best to just compile the body of the loop,
perhaps a few iterations at a time (loop unrolling).
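
A sketch of that idea (our own toy example): the data-dependent `while` stays
in eager Python, and only the body is compiled.

```python
@torch.compile
def body(x):
    # One Newton step towards sqrt(10); stands in for an arbitrary NumPy update
    return 0.5 * (x + 10.0 / x)

x = np.full(1024, 5.0)
while np.abs(x * x - 10.0).max() > 1e-8:  # dynamic stopping condition
    x = body(x)
```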

**Debugging NumPy code**. Debugging is rather tricky when a compiler is
involved. To figure out whether an error you are hitting is a `torch.compile`
error, or an error from the program, you can execute your NumPy program without
`torch.compile` by replacing the NumPy import with `import torch._numpy as np`.
This should only be used for **debugging purposes** and is in no way a
replacement for the PyTorch API, as it is **much slower** and, as a private API,
**may change without notice**.
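
Concretely, the swap is a one-line change at the top of the program (debugging
only):

```python
# import numpy as np          # regular import, for normal runs
import torch._numpy as np     # debugging only: slow, private API, may change

# The rest of the program stays exactly the same and now runs through
# PyTorch's NumPy layer in eager mode, without torch.compile.
```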

## Differences between NumPy and `torch.compile`d NumPy

**NumPy scalars**. NumPy returns NumPy scalars in almost any case where PyTorch
would return a 0-D tensor (e.g. from `np.sum`). Under `torch.compile`, NumPy
scalars are treated as 0-D arrays. This is just fine in most cases. The only
case when their behavior diverges is when NumPy scalars are implicitly used as
Python scalars. For example,

```python
>>> np.asarray(2) * [1, 2, 3]  # 0-D array is an array-like
array([2, 4, 6])
>>> u = np.int32(2)
>>> u * [1, 2, 3]  # scalar decays into a Python int
[1, 2, 3, 1, 2, 3]
>>> torch.compile(lambda: u * [1, 2, 3])()
array([2, 4, 6])  # acts as a 0-D array, not as a scalar ?!?!
```

If we compile the first two lines, we see that `torch.compile` treats `u` as a
0-D array. To recover the eager semantics, we just need to make the casting
explicit

```python
>>> torch.compile(lambda: int(u) * [1, 2, 3])()
[1, 2, 3, 1, 2, 3]
```

**Type promotion and versioning**. NumPy's type promotion rules may be, at
times, a bit surprising

```python
>>> np.zeros(1, dtype=np.int8) + 127
array([127], dtype=int8)
>>> np.zeros(1, dtype=np.int8) + 128
array([128], dtype=int16)
```

NumPy 2.0 is changing these rules to follow others that are closer to those of
PyTorch. The relevant technical document is [NEP 50](https://numpy.org/neps/nep-0050-scalar-promotion.html).
`torch.compile` went ahead and implemented NEP 50 rather than the about-to-be-deprecated rules.

In general, `torch.compile` will match the semantics of the latest NumPy release.

## Beyond NumPy: SciPy and scikit-learn

In parallel to this effort of making `torch.compile` understand NumPy code,
other Quansight engineers have designed and proposed a way to support PyTorch
tensors within scikit-learn and SciPy. This was received enthusiastically by
other maintainers from these libraries, as it was shown that using PyTorch as a
backend would often yield considerable speed-ups. Both projects have now merged
initial support for PyTorch tensors across a number of APIs and submodules.

This is a stepping stone towards a future where PyTorch tensors can be used
within other libraries in the Python data ecosystem. What is more, it will
enable running these other libraries on GPUs and even compiling code that mixes
these libraries and PyTorch, similar to what we have discussed in this post.

If you want to learn more about this effort, how to use it, or how to help
move it forward, see this post. [TODO link post]

## Conclusion

Since its inception, PyTorch has been committed to being a framework compatible
with the rest of the Python ecosystem. Enabling the compilation of NumPy
programs, and establishing the tools necessary to do the same for other
prominent libraries, are two more steps in this direction. Quansight and Meta
continue working hand in hand, improving the compatibility between PyTorch and
the rest of the ecosystem.

From Quansight, we would like to thank Mengwei, Voz, and Ed for their
invaluable help in integrating our work with `torch.compile`. We would also
like to thank Meta for funding this project as well as previous work on
improving NumPy compatibility within PyTorch, and the project that led to
supporting PyTorch within scikit-learn and SciPy. These are giant leaps towards
consolidating PyTorch as the framework of choice within the open source Python
data ecosystem.