# Compiling NumPy code into C++ or CUDA via `torch.compile`

Quansight engineers have implemented support for tracing through NumPy code via
`torch.compile` in PyTorch 2.1. This feature leverages PyTorch's compiler to
generate efficient fused vectorized code without having to modify your original
NumPy code. What is more, it also allows for executing NumPy functions on CUDA
just by running them through `torch.compile` under `torch.device("cuda")`!

In this post, we go over how to use this feature and give a few tips and tricks
to make the most out of it.

## Compiling NumPy code into Parallel C++

We take as our running example one step in a K-Means algorithm.
This piece of code is borrowed from this [NumPy book](https://realpython.com/numpy-array-programming/#clustering-algorithms).

```python
import numpy as np

def get_labels(X, means):
    return np.argmin(np.linalg.norm(X - means[:, None], axis=2), axis=0)
```

We create a synthetic dataset with two blobs of 10M random 2-D points each. We
can see that, given that the means are chosen appropriately, the function
returns the correct cluster for all of them.

```python
npts = 10_000_000
X = np.repeat([[5, 5], [10, 10]], [npts, npts], axis=0)
X = X + np.random.randn(*X.shape)  # 2 distinct "blobs"
means = np.array([[5, 5], [10, 10]])
np_pred = get_labels(X, means)
```

Benchmarking this function gives us a baseline of **1.26s** on an AMD 3970X CPU.
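
For reference, one way to take such a measurement yourself is to average a few
runs with `timeit` (a sketch; timings will of course vary with the hardware, and
this is not necessarily how the figures in this post were obtained):

```python
from timeit import timeit

# Average a handful of runs of the NumPy baseline.
baseline = timeit(lambda: get_labels(X, means), number=5) / 5
print(f"NumPy baseline: {baseline:.2f}s per call")
```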

Compiling this function is now as easy as wrapping it with `torch.compile` and
executing it with the example inputs:

```python
import torch

compiled_fn = torch.compile(get_labels)
torch_pred = compiled_fn(X, means)
assert np.allclose(np_pred, torch_pred)
```

The compiled function yields a 9x speed-up when running it on 1 core. Even
better, as opposed to NumPy, our generated code takes advantage of all the
cores in a processor. As such, when we run it on 32 cores, we get a **57x
speed-up**. Note that PyTorch always uses all the available cores unless
explicitly restricted, so this is the default behavior you get when using
`torch.compile`.
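
The post does not show how the 1-core measurement was taken; a minimal sketch
using PyTorch's standard thread controls could look like this:

```python
# Cap PyTorch's intra-op thread pool to approximate the single-core measurement;
# by default, the compiled code uses every available core.
n_threads = torch.get_num_threads()
torch.set_num_threads(1)
single_core_pred = compiled_fn(X, means)
torch.set_num_threads(n_threads)  # restore the default parallelism
```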

We may inspect the generated C++ code by running the script with the
environment variable `TORCH_LOGS=output_code`. When doing so, we can see that
`torch.compile` was able to compile the broadcasting and the two reductions
into just one for-loop, and parallelize it using OpenMP:
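
If you prefer to enable this logging from within the script rather than through
the environment, recent PyTorch versions also expose a programmatic switch; the
following is a sketch, so check that `torch._logging.set_logs` is available in
your version:

```python
import torch._logging

# Programmatic equivalent of running the script with TORCH_LOGS=output_code:
# dump the code generated by torch.compile for each compiled region.
torch._logging.set_logs(output_code=True)
```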

```c++
extern "C" void kernel(const double* in_ptr0, const long* in_ptr1, long* out_ptr0) {
    #pragma omp parallel num_threads(32)
    #pragma omp for
    for(long i0=0L; i0<20000000L; i0+=1L) {
        auto tmp0 = in_ptr0[2L*i0];
        auto tmp1 = in_ptr1[0L];
        auto tmp5 = in_ptr0[1L + (2L*i0)];
        auto tmp6 = in_ptr1[1L];
        // Rest of the kernel omitted for brevity
```

## Compiling NumPy code into CUDA

Compiling our code so that it runs on CUDA is as simple as setting the
default device to be CUDA:

```python
with torch.device("cuda"):
    cuda_pred = compiled_fn(X, means)
    assert np.allclose(np_pred, cuda_pred)
```

By inspecting the generated code via `TORCH_LOGS=output_code`, we see that,
rather than generating CUDA code directly, `torch.compile` generates rather
readable [Triton](https://triton-lang.org/main/index.html) code:

```python
def triton_(in_ptr0, in_ptr1, out_ptr0, XBLOCK : tl.constexpr):
    xnumel = 20000000
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (2*x0), xmask)
    tmp1 = tl.load(in_ptr1 + (0))
    # Rest of the kernel omitted for brevity
```

Running this small snippet on an RTX 2060 gives an **8x speed-up** over the
original NumPy code. This is something, but it is not particularly impressive,
given the speed-ups we have seen on CPU. Let's have a look at how to squeeze
the most out of our GPU via a couple of minor changes.

**`float64` vs `float32`**. Many GPUs, in particular consumer-grade ones, are
rather sluggish when running operations on `float64`. For this reason, if we
change the data generation to `float32`, the original NumPy code only gets
about 9% faster, but our CUDA code gets **40% faster**, yielding an **11x
speed-up** over the plain NumPy code.
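
As a concrete sketch of what "changing the data generation to `float32`" could
look like for the dataset above (the exact variant used for these measurements
is not shown in the post):

```python
# Same synthetic dataset as before, but generated in float32.
npts = 10_000_000
X = np.repeat([[5, 5], [10, 10]], [npts, npts], axis=0).astype(np.float32)
X = X + np.random.randn(*X.shape).astype(np.float32)  # 2 distinct "blobs"
means = np.array([[5, 5], [10, 10]], dtype=np.float32)
```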

`torch.compile`, by default, respects the NumPy semantics, and as such, it uses
`np.float64` as its default dtype for all its creation ops. As discussed, this
can hinder performance, so it is possible to change this default by setting:

```python
from torch._dynamo import config
config.numpy_default_float = "float32"
```

**CPU <> CUDA copies**. An 11x speed-up is good, but it is not even close to
the CPU numbers. This is caused by a small transformation that `torch.compile`
does behind the scenes. The code above takes NumPy arrays and returns NumPy
arrays. All of these arrays are on CPU, but the computations are performed on
the GPU. This means that every time the function is called, `torch.compile` has
to copy all these arrays from CPU to the GPU, and then copy the result back to
CPU to preserve the original semantics. There is no native solution to this
issue in NumPy, as NumPy does not have the notion of a `device`. That being
said, we can work around it by creating a wrapper around this function so that
it accepts PyTorch tensors and returns PyTorch tensors.

```python
@torch.compile
def tensor_fn(X, means):
    X, means = X.numpy(), means.numpy()
    ret = get_labels(X, means)
    return torch.from_numpy(ret)

def cuda_fn(X, means):
    with torch.device("cuda"):
        return tensor_fn(X, means)
```
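
For completeness, here is one way the wrapper could be exercised; this call
site is a sketch and is not part of the original example:

```python
# Move the inputs to CUDA once, outside the compiled function.
X_t = torch.from_numpy(X).cuda()
means_t = torch.from_numpy(means).cuda()

cuda_pred_t = cuda_fn(X_t, means_t)  # stays on the GPU end to end
assert np.allclose(np_pred, cuda_pred_t.cpu().numpy())
```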

This function now takes tensors in CUDA memory and returns tensors in CUDA
memory, but the function itself is written in NumPy! `torch.compile` uses the
`numpy()` and the `from_numpy()` calls as hints and optimizes them away;
internally, it simply works with PyTorch tensors without moving the memory at
all. When we keep the tensors in CUDA and perform the computations in
`float32`, we see a **200x speed-up** over the initial NumPy implementation on
`float32` arrays.

**Mixing NumPy and PyTorch**. In this example, we had to write a small adaptor
to convert tensors to ndarrays and then back to tensors. In programs that mix
PyTorch and NumPy, converting a tensor into an ndarray is often implemented as
`x.detach().cpu().numpy()`, or simply `x.numpy(force=True)`. Since when running
under `torch.compile` we can run NumPy code in CUDA, we can implement this
conversion pattern as a call to `x.numpy()`, as we did above. Doing so and
running the resulting code under `device("cuda")` will generate efficient CUDA
code from the original NumPy calls without copying the data from CUDA to CPU at
all. Note that the resulting code does not run without `torch.compile`. For it
to run in eager mode, one would need to roll back to `x.numpy(force=True)`.
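
To illustrate the pattern, here is a hypothetical mixed function; the names and
the particular NumPy operation are made up for the example:

```python
@torch.compile
def mixed_fn(x):  # x: a CUDA tensor
    y = torch.sin(x)                  # PyTorch code
    z = np.clip(y.numpy(), 0.0, 1.0)  # NumPy code; under torch.compile the
                                      # conversion is a hint and data stays on CUDA
    return torch.from_numpy(z)

# In eager mode, y.numpy() on a CUDA tensor raises an error; one would need
# y.numpy(force=True), which does copy the data back to CPU.
```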

## Further speed-up tricks

**General advice**. The CUDA code we have shown is already quite efficient, but
it is true that the running example is rather short. When dealing with a larger
program, we may need to tweak parts of it to make it more efficient. A good
place to start is the [`torch.compile` troubleshooting
page](https://pytorch.org/docs/stable/dynamo/troubleshooting.html#performance-profiling).
It showcases a number of ways to inspect the tracing process and how to
identify problematic code that may cause slowdowns.

**Advice when compiling NumPy code**. NumPy, even if rather similar to PyTorch,
is often used very differently. It is rather common to perform computations in
NumPy and then do an if/else depending on values within the array, or to perform
operations in-place, perhaps via boolean masks. These constructs, while
supported by `torch.compile`, hamper its performance. Changes like moving from
in-place indexing to using `np.where`, writing the code in a branchless way, or
avoiding in-place ops in favor of out-of-place ops can go a long way.
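
As a small illustration of this advice (a sketch with made-up function names,
not code from this post), the two functions below compute the same thing, but
the second avoids both the mask-based branch and the in-place update:

```python
def clamp_negatives_inplace(x):
    # In-place update through a boolean mask: harder for torch.compile to optimize.
    x[x < 0] = 0
    return x

def clamp_negatives_branchless(x):
    # Out-of-place, branchless equivalent using np.where.
    return np.where(x < 0, 0, x)
```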

To write fast NumPy code, it is best to avoid loops, but sometimes they are
unavoidable. When tracing through a loop, `torch.compile` will try to fully
unroll it. This is sometimes desirable, but sometimes it may not even be
possible, for example when we have a dynamic stopping condition, as in a while
loop. In these cases, it may be best to just compile the body of the loop,
perhaps compiling a few iterations at a time (loop unrolling).
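
A sketch of this idea, with hypothetical `step` and `run` functions standing in
for the real loop body and stopping condition:

```python
@torch.compile
def step(x):
    # One iteration of a (hypothetical) fixed-point update, written in NumPy.
    return 0.5 * (x + np.sqrt(np.abs(x)))

def run(x, max_iter=1000, tol=1e-6):
    # The dynamic stopping condition stays in eager Python;
    # only the loop body is compiled.
    for _ in range(max_iter):
        new_x = step(x)
        if np.linalg.norm(new_x - x) < tol:
            return new_x
        x = new_x
    return x
```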

**Debugging NumPy code**. Debugging is rather tricky when a compiler is
involved. To figure out whether an error you are hitting is a `torch.compile`
error or an error in the program, you can execute your NumPy program without
`torch.compile` by replacing the NumPy import with `import torch._numpy as np`.
This should only be used for **debugging purposes** and is in no way a
replacement for the PyTorch API, as it is **much slower** and, as a private API,
**may change without notice**.
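
Concretely, the swap described above is a one-line change at the top of the
script:

```python
# import numpy as np           # original import
import torch._numpy as np      # debugging only: private API, much slower,
                               # and it may change without notice
```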

## Differences between NumPy and `torch.compile`d NumPy

**NumPy scalars**. NumPy returns NumPy scalars in almost any case where PyTorch
would return a 0-D tensor (e.g., from `np.sum`). Under `torch.compile`, NumPy
scalars are treated as 0-D arrays. This is just fine in most cases. The only
case where their behavior diverges is when NumPy scalars are implicitly used as
Python scalars. For example:

```python
>>> np.asarray(2) * [1, 2, 3]  # 0-D array is an array-like
array([2, 4, 6])
>>> u = np.int32(2)
>>> u * [1, 2, 3]  # scalar decays into a Python int
[1, 2, 3, 1, 2, 3]
>>> torch.compile(lambda: u * [1, 2, 3])()
array([2, 4, 6])  # acts as a 0-D array, not as a scalar ?!?!
```

If we compile the multiplication (as in the last line above), we see that
`torch.compile` treats `u` as a 0-D array. To recover the eager semantics, we
just need to make the cast explicit:

```python
>>> torch.compile(lambda: int(u) * [1, 2, 3])()
[1, 2, 3, 1, 2, 3]
```

**Type promotion and versioning**. NumPy's type promotion rules may be, at
times, a bit surprising:

```python
>>> np.zeros(1, dtype=np.int8) + 127
array([127], dtype=int8)
>>> np.zeros(1, dtype=np.int8) + 128
array([128], dtype=int16)
```

NumPy 2.0 is changing these rules to follow others that are closer to those of
PyTorch. The relevant technical document is [NEP 50](https://numpy.org/neps/nep-0050-scalar-promotion.html).
`torch.compile` went ahead and implemented NEP 50 rather than the about-to-be-deprecated rules.

In general, `torch.compile` will match the semantics of the latest NumPy release.

## Beyond NumPy: SciPy and scikit-learn

In parallel to this effort of making `torch.compile` understand NumPy code,
other Quansight engineers have designed and proposed a way to support PyTorch
tensors within scikit-learn and SciPy. This was received enthusiastically by
other maintainers of these libraries, as it was shown that using PyTorch as a
backend would often yield considerable speed-ups. Both projects have now merged
initial support for PyTorch tensors across a number of APIs and submodules.

This lays the groundwork for moving towards a future where PyTorch tensors can
be used within other libraries in the Python data ecosystem. What is more, this
will enable running these other libraries on GPUs and even compiling code
mixing these libraries and PyTorch, similar to what we have discussed in
this post.

If you want to learn more about this effort, how to use it, or how to help
move it forward, see this post. [TODO link post]

## Conclusion

PyTorch has committed since its inception to being a framework compatible with
the rest of the Python ecosystem. Enabling the compilation of NumPy programs,
and establishing the tools necessary to do the same for other prominent
libraries, are two more steps in this direction. Quansight and Meta continue
working hand in hand, improving the compatibility between PyTorch and the rest
of the ecosystem.

From Quansight, we would like to thank Mengwei, Voz, and Ed for their
invaluable help in integrating our work with `torch.compile`. We would also
like to thank Meta for funding this project as well as previous work on
improving NumPy compatibility within PyTorch, and the project that led to
supporting PyTorch within scikit-learn and SciPy. These are giant leaps towards
consolidating PyTorch as the framework of choice within the open source Python
data ecosystem.
