diff --git a/server/requirements_cuda.txt b/server/requirements_cuda.txt index bc1b8891216..fd6c760b9f1 100644 --- a/server/requirements_cuda.txt +++ b/server/requirements_cuda.txt @@ -37,9 +37,9 @@ safetensors==0.3.3 ; python_version >= "3.9" and python_version < "3.13" scipy==1.11.4 ; python_version >= "3.9" and python_version < "3.13" sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13" setuptools==69.0.2 ; python_version >= "3.9" and python_version < "3.13" -tokenizers==0.13.3 ; python_version >= "3.9" and python_version < "3.13" +tokenizers==0.15.0 ; python_version >= "3.9" and python_version < "3.13" tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13" -transformers==4.33.3 ; python_version >= "3.9" and python_version < "3.13" +transformers==4.35.2 ; python_version >= "3.9" and python_version < "3.13" typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13" typing-extensions==4.8.0 ; python_version >= "3.9" and python_version < "3.13" urllib3==2.1.0 ; python_version >= "3.9" and python_version < "3.13" diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 5b1b5715c5e..5e6e8856622 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -248,7 +248,7 @@ def get_model( ) if model_type == "mistral": - if MISTRAL: + if FLASH_ATTENTION: return FlashMistral( model_id, revision, @@ -256,7 +256,14 @@ def get_model( dtype=dtype, trust_remote_code=trust_remote_code, ) - raise NotImplementedError("Mistral model requires flash attention v2") + else: + return CausalLM( + model_id, + revision, + quantize=quantize, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) if model_type == "opt": return OPTSharded(