
Commit 59922f9

add numa to improve cpu inference perf (#2330)
Signed-off-by: Wang, Yi A <[email protected]>
1 parent cd9b15d commit 59922f9

File tree

2 files changed (+35, -8 lines):
  Dockerfile_intel
  server/text_generation_server/models/flash_causal_lm.py


Dockerfile_intel

Lines changed: 4 additions & 8 deletions
@@ -106,7 +106,8 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
     g++ \
     git \
     wget \
-    cmake
+    cmake \
+    libnuma-dev
 
 ENV HUGGINGFACE_HUB_CACHE=/data \
     HF_HUB_ENABLE_HF_TRANSFER=1 \
@@ -135,7 +136,7 @@ RUN conda install -c conda-forge gperftools mkl
 RUN pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.4.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl
 RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchvision-0.19.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl
 RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchaudio-2.4.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl
-RUN pip install triton
+RUN pip install triton numa
 
 WORKDIR /usr/src
 
@@ -147,16 +148,11 @@ RUN cd intel-extension-for-pytorch && git submodule sync && git submodule update
 
 RUN cd torch-ccl && git submodule sync && git submodule update --init --recursive && pip install .
 
-ENV LD_PRELOAD=/opt/conda/lib/libtcmalloc.so:/opt/conda/lib/libiomp5.so
+ENV LD_PRELOAD=/opt/conda/lib/libtcmalloc.so
 ENV CCL_ROOT=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch
 ENV I_MPI_ROOT=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch
 ENV FI_PROVIDER_PATH=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib/prov:/usr/lib64/libfabric
 ENV LD_LIBRARY_PATH=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib:/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/lib
-ENV KMP_BLOCKTIME=1
-ENV KMP_TPAUSE=0
-ENV KMP_FORKJOIN_BARRIER_PATTERN=dist,dist
-ENV KMP_PLAIN_BARRIER_PATTERN=dist,dist
-ENV KMP_REDUCTION_BARRIER_PATTERN=dist,dist
 
 # Install server
 COPY proto proto

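A note on the image changes (not part of the commit itself): libnuma-dev supplies the native NUMA library, and the "numa" pip package (py-libnuma) supplies the Python bindings used by init_cpu_threads_env() below. libiomp5 is dropped from LD_PRELOAD and the static KMP_* barrier settings are removed because thread placement is now computed per rank at runtime rather than fixed image-wide. A minimal probe, assuming a container built from this image, that mirrors the server's find_spec guard and prints the topology the new code will partition:

    # topology_probe.py -- illustrative sketch, not part of this commit.
    import importlib.util

    import psutil  # already a server dependency

    if importlib.util.find_spec("numa") is None:
        print("py-libnuma missing: the server will skip NUMA pinning")
    else:
        import numa

        nodes = numa.get_max_node() + 1  # node count, as the server computes it
        print(f"NUMA nodes:     {nodes}")
        print(f"physical cores: {psutil.cpu_count(logical=False)}")
        print(f"logical CPUs:   {psutil.cpu_count(logical=True)}")
        for node in range(nodes):
            # node_to_cpus() is the list the server slices per rank
            print(f"node {node} CPUs: {list(numa.node_to_cpus(node))}")
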
server/text_generation_server/models/flash_causal_lm.py

Lines changed: 31 additions & 0 deletions
@@ -74,6 +74,36 @@ def get_sliding_windows() -> int:
     return SLIDING_WINDOW
 
 
+def init_cpu_threads_env(rank_id: int, world_size: int):
+    import importlib.util
+
+    if importlib.util.find_spec("numa") is not None:
+        import numa
+        import psutil
+
+        nodes = numa.get_max_node() + 1
+        rank_per_node = math.ceil(world_size / nodes)
+        num_cpus_per_nodes = int(psutil.cpu_count(logical=False) / nodes)
+        node_id = int(rank_id / rank_per_node)
+        rank_offset_per_node = rank_id % rank_per_node
+        if os.getenv("OMP_NUM_THREADS") is None:
+            num_cpus_per_rank = max(int(num_cpus_per_nodes / rank_per_node), 1)
+        else:
+            num_cpus_per_rank = int(os.getenv("OMP_NUM_THREADS"))
+        if len(numa.get_membind()) == nodes:
+            numa.set_membind([node_id])
+        torch.set_num_threads(num_cpus_per_rank)
+        if len(numa.get_affinity(0)) == psutil.cpu_count(logical=True):
+            cpu_start = num_cpus_per_rank * rank_offset_per_node
+            numa.set_affinity(
+                0,
+                list(numa.node_to_cpus(node_id))[
+                    cpu_start : cpu_start + num_cpus_per_rank
+                ],
+            )
+        logger.info(f"affinity={numa.get_affinity(0)}, membind = {numa.get_membind()}")
+
+
 @dataclass
 class FlashCausalLMBatch(Batch):
     batch_id: int
@@ -854,6 +884,7 @@ def __init__(
             device = torch.device("cpu")
             # Float16 doesn't exist on target.
             dtype = torch.bfloat16 if dtype is None else dtype
+            init_cpu_threads_env(rank_id=rank, world_size=world_size)
         else:
             raise NotImplementedError(f"{model_class} is only available on GPU")

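The heart of init_cpu_threads_env() is the arithmetic that memory-binds each rank to one NUMA node and hands it an even slice of that node's physical cores. A dry-run of that arithmetic for a hypothetical 2-node, 32-physical-core machine with world_size=4 (plan_rank() and the hardware numbers are illustrative; the committed code reads the real values from numa and psutil at runtime):

    # partition_dry_run.py -- illustrative sketch, not part of this commit.
    import math

    def plan_rank(rank_id: int, world_size: int, nodes: int, phys_cores: int):
        rank_per_node = math.ceil(world_size / nodes)  # ranks sharing one node
        cores_per_node = phys_cores // nodes
        node_id = rank_id // rank_per_node             # node this rank binds to
        rank_offset = rank_id % rank_per_node          # rank's slot on that node
        # default per-rank thread count when OMP_NUM_THREADS is unset
        threads = max(cores_per_node // rank_per_node, 1)
        cpu_start = threads * rank_offset
        # offsets into numa.node_to_cpus(node_id), not global CPU ids
        return node_id, list(range(cpu_start, cpu_start + threads))

    for rank in range(4):
        node, cpus = plan_rank(rank, world_size=4, nodes=2, phys_cores=32)
        print(f"rank {rank}: membind node {node}, node-local CPU slots {cpus}")

Each rank then calls numa.set_membind() and numa.set_affinity() with its slice and torch.set_num_threads() with the per-rank count; setting OMP_NUM_THREADS explicitly overrides the computed count, as the committed code shows.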