
Commit 9b56d3f

feat: relax mistral requirements (#1351)
Closes #1253, closes #1279
1 parent f3aea78 commit 9b56d3f

File tree

8 files changed: +671 additions, -621 deletions

.github/workflows/build.yaml

Lines changed: 40 additions & 38 deletions

The integration-tests job now depends only on the main CUDA image, and build-and-push-image-rocm waits for both the main image and the integration tests instead of only the runner.

@@ -146,11 +146,50 @@ jobs:
           cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache,mode=min
           cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache,mode=min
 
+  integration-tests:
+    concurrency:
+      group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }}
+      cancel-in-progress: true
+    needs:
+      - start-runner
+      - build-and-push-image # Wait for the docker image to be built
+    runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner
+    env:
+      DOCKER_VOLUME: /cache
+    steps:
+      - uses: actions/checkout@v2
+      - name: Inject slug/short variables
+        uses: rlespinasse/[email protected]
+      - name: Set up Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: 3.9
+      - name: Tailscale
+        uses: tailscale/github-action@7bd8039bf25c23c4ab1b8d6e2cc2da2280601966
+        with:
+          authkey: ${{ secrets.TAILSCALE_AUTHKEY }}
+      - name: Prepare disks
+        run: |
+          sudo mkfs -t ext4 /dev/nvme1n1
+          sudo mkdir ${{ env.DOCKER_VOLUME }}
+          sudo mount /dev/nvme1n1 ${{ env.DOCKER_VOLUME }}
+      - name: Install
+        run: |
+          make install-integration-tests
+      - name: Run tests
+        run: |
+          export DOCKER_IMAGE=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT }}
+          export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
+          pytest -s -vv integration-tests
+
   build-and-push-image-rocm:
     concurrency:
       group: ${{ github.workflow }}-build-and-push-image-rocm-${{ github.head_ref || github.run_id }}
       cancel-in-progress: true
-    needs: start-runner # required to start the main job when the runner is ready
+    needs:
+      - start-runner
+      - build-and-push-image # Wait for the main docker image to be built
+      - integration-tests # Wait for the main integration-tests
     runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner
     permissions:
       contents: write

@@ -235,43 +274,6 @@ jobs:
           cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache-rocm,mode=min
           cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache-rocm,mode=min
 
-  integration-tests:
-    concurrency:
-      group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }}
-      cancel-in-progress: true
-    needs:
-      - start-runner
-      - build-and-push-image # Wait for the docker image to be built
-      - build-and-push-image-rocm
-    runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner
-    env:
-      DOCKER_VOLUME: /cache
-    steps:
-      - uses: actions/checkout@v2
-      - name: Inject slug/short variables
-        uses: rlespinasse/[email protected]
-      - name: Set up Python
-        uses: actions/setup-python@v4
-        with:
-          python-version: 3.9
-      - name: Tailscale
-        uses: tailscale/github-action@7bd8039bf25c23c4ab1b8d6e2cc2da2280601966
-        with:
-          authkey: ${{ secrets.TAILSCALE_AUTHKEY }}
-      - name: Prepare disks
-        run: |
-          sudo mkfs -t ext4 /dev/nvme1n1
-          sudo mkdir ${{ env.DOCKER_VOLUME }}
-          sudo mount /dev/nvme1n1 ${{ env.DOCKER_VOLUME }}
-      - name: Install
-        run: |
-          make install-integration-tests
-      - name: Run tests
-        run: |
-          export DOCKER_IMAGE=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT }}
-          export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
-          pytest -s -vv integration-tests
-
   stop-runner:
     name: Stop self-hosted EC2 runner
     needs:
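Not part of the diff: a minimal Python sketch of reproducing the new Install and Run tests steps outside CI, assuming Docker, make, and a repository checkout are available locally. The image tag and token values are placeholders, not real values from this commit.

import os
import subprocess

# Placeholders: substitute a real image tag and a real Hub token.
os.environ["DOCKER_IMAGE"] = (
    "registry.internal.huggingface.tech/api-inference/community/"
    "text-generation-inference:sha-<short-sha>"
)
os.environ["HUGGING_FACE_HUB_TOKEN"] = "<your-token>"
os.environ["DOCKER_VOLUME"] = "/cache"  # matches the job's env block

# Mirrors the workflow's "Install" and "Run tests" steps.
subprocess.run(["make", "install-integration-tests"], check=True)
subprocess.run(["pytest", "-s", "-vv", "integration-tests"], check=True)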

server/poetry.lock

Lines changed: 587 additions & 508 deletions
Some generated files are not rendered by default.

server/pyproject.toml

Lines changed: 4 additions & 4 deletions

@@ -15,7 +15,7 @@ grpcio-status = "^1.51.1"
 grpcio-reflection = "^1.51.1"
 grpc-interceptor = "^0.15.0"
 typer = "^0.6.1"
-accelerate = { version = "^0.20.0", optional = true }
+accelerate = { version = "^0.25.0", optional = true }
 bitsandbytes = { version = "^0.41.1", optional = true }
 safetensors = "^0.3.2"
 loguru = "^0.6.0"
@@ -24,9 +24,9 @@ opentelemetry-exporter-otlp = "^1.15.0"
 opentelemetry-instrumentation-grpc = "^0.36b0"
 hf-transfer = "^0.1.2"
 sentencepiece = "^0.1.97"
-tokenizers = "^0.13.3"
-huggingface-hub = "^0.16.4"
-transformers = "^4.32.1"
+tokenizers = "^0.15.0"
+huggingface-hub = "^0.19.3"
+transformers = "^4.36.1"
 einops = "^0.6.1"
 texttable = { version = "^1.6.7", optional = true }
 datasets = { version = "^2.14.0", optional = true }
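Not part of the diff: a minimal sketch for checking a local environment against the relaxed pins above, assuming the packaging library is installed. The caret ranges are Poetry syntax, translated here to PEP 440 by hand (for a 0.x version, ^0.25.0 means >=0.25.0,<0.26.0).

from importlib.metadata import PackageNotFoundError, version

from packaging.specifiers import SpecifierSet
from packaging.version import Version

# Relaxed constraints from pyproject.toml, rewritten as PEP 440 ranges.
CONSTRAINTS = {
    "accelerate": SpecifierSet(">=0.25.0,<0.26.0"),       # ^0.25.0
    "tokenizers": SpecifierSet(">=0.15.0,<0.16.0"),       # ^0.15.0
    "huggingface-hub": SpecifierSet(">=0.19.3,<0.20.0"),  # ^0.19.3
    "transformers": SpecifierSet(">=4.36.1,<5.0.0"),      # ^4.36.1
}

for name, spec in CONSTRAINTS.items():
    try:
        installed = Version(version(name))
    except PackageNotFoundError:
        # accelerate is an optional extra and may legitimately be absent.
        print(f"{name}: not installed")
        continue
    print(f"{name}=={installed}: {'ok' if installed in spec else 'MISMATCH'}")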

server/requirements_cuda.txt

Lines changed: 10 additions & 10 deletions

@@ -1,5 +1,5 @@
 backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
-bitsandbytes==0.41.2.post2 ; python_version >= "3.9" and python_version < "3.13"
+bitsandbytes==0.41.3.post2 ; python_version >= "3.9" and python_version < "3.13"
 certifi==2023.11.17 ; python_version >= "3.9" and python_version < "3.13"
 charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
 click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
@@ -8,14 +8,14 @@ deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
 einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
 filelock==3.13.1 ; python_version >= "3.9" and python_version < "3.13"
 fsspec==2023.10.0 ; python_version >= "3.9" and python_version < "3.13"
-googleapis-common-protos==1.61.0 ; python_version >= "3.9" and python_version < "3.13"
+googleapis-common-protos==1.62.0 ; python_version >= "3.9" and python_version < "3.13"
 grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
-grpcio-reflection==1.59.3 ; python_version >= "3.9" and python_version < "3.13"
-grpcio-status==1.59.3 ; python_version >= "3.9" and python_version < "3.13"
-grpcio==1.59.3 ; python_version >= "3.9" and python_version < "3.13"
+grpcio-reflection==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
+grpcio-status==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
+grpcio==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
 hf-transfer==0.1.4 ; python_version >= "3.9" and python_version < "3.13"
-huggingface-hub==0.16.4 ; python_version >= "3.9" and python_version < "3.13"
-idna==3.4 ; python_version >= "3.9" and python_version < "3.13"
+huggingface-hub==0.19.4 ; python_version >= "3.9" and python_version < "3.13"
+idna==3.6 ; python_version >= "3.9" and python_version < "3.13"
 loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
 numpy==1.26.2 ; python_version >= "3.9" and python_version < "3.13"
 opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
@@ -37,11 +37,11 @@ safetensors==0.3.3 ; python_version >= "3.9" and python_version < "3.13"
 scipy==1.11.4 ; python_version >= "3.9" and python_version < "3.13"
 sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
 setuptools==69.0.2 ; python_version >= "3.9" and python_version < "3.13"
-tokenizers==0.13.3 ; python_version >= "3.9" and python_version < "3.13"
+tokenizers==0.15.0 ; python_version >= "3.9" and python_version < "3.13"
 tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13"
-transformers==4.33.3 ; python_version >= "3.9" and python_version < "3.13"
+transformers==4.36.1 ; python_version >= "3.9" and python_version < "3.13"
 typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
-typing-extensions==4.8.0 ; python_version >= "3.9" and python_version < "3.13"
+typing-extensions==4.9.0 ; python_version >= "3.9" and python_version < "3.13"
 urllib3==2.1.0 ; python_version >= "3.9" and python_version < "3.13"
 win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
 wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"

server/requirements_rocm.txt

Lines changed: 9 additions & 9 deletions

@@ -7,14 +7,14 @@ deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
 einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
 filelock==3.13.1 ; python_version >= "3.9" and python_version < "3.13"
 fsspec==2023.10.0 ; python_version >= "3.9" and python_version < "3.13"
-googleapis-common-protos==1.61.0 ; python_version >= "3.9" and python_version < "3.13"
+googleapis-common-protos==1.62.0 ; python_version >= "3.9" and python_version < "3.13"
 grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
-grpcio-reflection==1.59.3 ; python_version >= "3.9" and python_version < "3.13"
-grpcio-status==1.59.3 ; python_version >= "3.9" and python_version < "3.13"
-grpcio==1.59.3 ; python_version >= "3.9" and python_version < "3.13"
+grpcio-reflection==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
+grpcio-status==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
+grpcio==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
 hf-transfer==0.1.4 ; python_version >= "3.9" and python_version < "3.13"
-huggingface-hub==0.16.4 ; python_version >= "3.9" and python_version < "3.13"
-idna==3.4 ; python_version >= "3.9" and python_version < "3.13"
+huggingface-hub==0.19.4 ; python_version >= "3.9" and python_version < "3.13"
+idna==3.6 ; python_version >= "3.9" and python_version < "3.13"
 loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
 numpy==1.26.2 ; python_version >= "3.9" and python_version < "3.13"
 opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
@@ -36,11 +36,11 @@ safetensors==0.3.3 ; python_version >= "3.9" and python_version < "3.13"
 scipy==1.11.4 ; python_version >= "3.9" and python_version < "3.13"
 sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
 setuptools==69.0.2 ; python_version >= "3.9" and python_version < "3.13"
-tokenizers==0.13.3 ; python_version >= "3.9" and python_version < "3.13"
+tokenizers==0.15.0 ; python_version >= "3.9" and python_version < "3.13"
 tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13"
-transformers==4.33.3 ; python_version >= "3.9" and python_version < "3.13"
+transformers==4.36.1 ; python_version >= "3.9" and python_version < "3.13"
 typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
-typing-extensions==4.8.0 ; python_version >= "3.9" and python_version < "3.13"
+typing-extensions==4.9.0 ; python_version >= "3.9" and python_version < "3.13"
 urllib3==2.1.0 ; python_version >= "3.9" and python_version < "3.13"
 win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
 wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"

server/text_generation_server/models/__init__.py

Lines changed: 15 additions & 29 deletions

@@ -55,36 +55,22 @@
     FlashSantacoderSharded,
 )
 from text_generation_server.models.idefics import IDEFICSSharded
+from text_generation_server.models.flash_mistral import FlashMistral
+from text_generation_server.models.flash_mixtral import FlashMixtral
+from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_CUDA
 
 except ImportError as e:
     logger.warning(f"Could not import Flash Attention enabled models: {e}")
     FLASH_ATTENTION = False
+    HAS_FLASH_ATTN_V2_CUDA = False
 
 if FLASH_ATTENTION:
     __all__.append(FlashNeoXSharded)
     __all__.append(FlashRWSharded)
     __all__.append(FlashSantacoderSharded)
     __all__.append(FlashLlama)
     __all__.append(IDEFICSSharded)
-
-MISTRAL = True
-try:
-    from text_generation_server.models.flash_mistral import FlashMistral
-except ImportError as e:
-    logger.warning(f"Could not import Mistral model: {e}")
-    MISTRAL = False
-
-if MISTRAL:
     __all__.append(FlashMistral)
-
-MIXTRAL = True
-try:
-    from text_generation_server.models.flash_mixtral import FlashMixtral
-except ImportError as e:
-    logger.warning(f"Could not import Mixtral model: {e}")
-    MIXTRAL = False
-
-if MIXTRAL:
     __all__.append(FlashMixtral)
 
 
@@ -295,28 +281,28 @@ def get_model(
         )
 
     if model_type == "mistral":
-        if MISTRAL:
+        if (config_dict["sliding_window"] is None and FLASH_ATTENTION) or (
+            config_dict["sliding_window"] > 0 and HAS_FLASH_ATTN_V2_CUDA
+        ):
             return FlashMistral(
                 model_id,
                 revision,
                 quantize=quantize,
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
             )
-        raise NotImplementedError("Mistral models requires flash attention v2")
 
     if model_type == "mixtral":
-        if MIXTRAL:
+        if (config_dict["sliding_window"] is None and FLASH_ATTENTION) or (
+            config_dict["sliding_window"] > 0 and HAS_FLASH_ATTN_V2_CUDA
+        ):
             return FlashMixtral(
                 model_id,
                 revision,
                 quantize=quantize,
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
             )
-        raise NotImplementedError(
-            "Mixtral models requires flash attention v2, stk and megablocks"
-        )
 
     if model_type == "opt":
         return OPTSharded(

@@ -348,17 +334,17 @@ def get_model(
         raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
 
     if sharded:
-        raise ValueError("sharded is not supported for AutoModel")
+        raise NotImplementedError("sharded is not supported for AutoModel")
     if quantize == "gptq":
-        raise ValueError(
+        raise NotImplementedError(
             "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
         )
     if quantize == "awq":
-        raise ValueError("awq quantization is not supported for AutoModel")
+        raise NotImplementedError("awq quantization is not supported for AutoModel")
     elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
-        raise ValueError("4bit quantization is not supported for AutoModel")
+        raise NotImplementedError("4bit quantization is not supported for AutoModel")
     elif quantize == "eetq":
-        raise ValueError("Eetq quantization is not supported for AutoModel")
+        raise NotImplementedError("Eetq quantization is not supported for AutoModel")
     if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
         return CausalLM(
             model_id,
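The gating above replaces the import-time MISTRAL/MIXTRAL flags with a per-model capability check. A minimal sketch of the rule, using the names from the diff: a checkpoint without a sliding window only needs some flash attention build, while a sliding-window checkpoint still needs flash attention v2 on CUDA; when neither holds, get_model no longer raises and instead falls through to the generic AutoModel handling further down.

def can_use_flash_path(config_dict: dict,
                       flash_attention: bool,
                       has_flash_attn_v2_cuda: bool) -> bool:
    # Mirrors the condition added for both "mistral" and "mixtral".
    sliding_window = config_dict["sliding_window"]
    if sliding_window is None:
        # No sliding window: any flash attention build suffices.
        return flash_attention
    # Sliding-window attention still needs flash attention v2 on CUDA.
    return sliding_window > 0 and has_flash_attn_v2_cuda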

server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py

Lines changed: 0 additions & 9 deletions

@@ -27,11 +27,6 @@
 from typing import Optional, List, Tuple
 
 from text_generation_server.utils import paged_attention, flash_attn
-from text_generation_server.utils.flash_attn import (
-    attention,
-    HAS_FLASH_ATTN_V2_ROCM,
-    HAS_FLASH_ATTN_V2_CUDA,
-)
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
@@ -43,10 +38,6 @@
 )
 
 
-if not HAS_FLASH_ATTN_V2_CUDA and not HAS_FLASH_ATTN_V2_ROCM:
-    raise ImportError("Mistral model requires flash attn v2")
-
-
 class MistralConfig(PretrainedConfig):
     model_type = "mistral"
 
server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py

Lines changed: 6 additions & 14 deletions

@@ -27,12 +27,9 @@
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
+from loguru import logger
 
 from text_generation_server.utils import paged_attention, flash_attn
-from text_generation_server.utils.flash_attn import (
-    HAS_FLASH_ATTN_V2_ROCM,
-    HAS_FLASH_ATTN_V2_CUDA,
-)
 from text_generation_server.utils.layers import (
     FastLinear,
     FastRMSNorm,
@@ -44,18 +41,13 @@
     get_linear,
 )
 
-if not HAS_FLASH_ATTN_V2_CUDA and not HAS_FLASH_ATTN_V2_ROCM:
-    raise ImportError("Mixtral model requires flash attn v2")
-
-try:
-    import megablocks.ops as ops
-except ImportError:
-    raise ImportError("Mixtral model requires megablocks to be installed")
-
+HAS_MEGABLOCKS = True
 try:
     import stk
+    import megablocks.ops as ops
 except ImportError:
-    raise ImportError("Mixtral model requires stk to be installed")
+    logger.warning("Mixtral: megablocks is not installed")
+    HAS_MEGABLOCKS = False
 
 
 class MixtralConfig(PretrainedConfig):

@@ -590,7 +582,7 @@ def dense_forward(self, x: torch.Tensor) -> torch.Tensor:
         return out
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        if len(x) > 256:
+        if len(x) > 256 and HAS_MEGABLOCKS:
             return self.sparse_forward(x)
         # This is faster when there is not a lot of tokens
         return self.dense_forward(x)
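The megablocks change follows a common optional-dependency pattern: probe for the accelerated kernels once at import time, record the result in a module-level flag, and pick a fallback at call time instead of failing the import. A minimal sketch with hypothetical names (fast_ops, sparse_forward, dense_forward are illustrative, not from the repository):

import logging

logger = logging.getLogger(__name__)

# Probe the optional accelerated kernels once, at import time.
HAS_FAST_OPS = True
try:
    import fast_ops  # hypothetical optional extra, analogous to stk/megablocks
except ImportError:
    logger.warning("fast_ops is not installed; using the dense fallback")
    HAS_FAST_OPS = False


def sparse_forward(x):
    # Stand-in for the accelerated sparse MoE kernel.
    return fast_ops.sparse_moe(x)


def dense_forward(x):
    # Stand-in for the always-available dense fallback.
    return [2 * v for v in x]


def forward(x):
    # The sparse path only pays off for large batches, and only when the
    # optional kernels imported successfully.
    if len(x) > 256 and HAS_FAST_OPS:
        return sparse_forward(x)
    return dense_forward(x)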
