Skip to content

Commit dbceda5

Browse files
committed
feature: Add DJL Inference support
1 parent 849ed29 commit dbceda5

File tree

4 files changed

+1367
-0
lines changed

4 files changed

+1367
-0
lines changed
+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Exposes the DJL Inference predictor and model classes defined in the model module."""
14+
from __future__ import absolute_import
15+
16+
from sagemaker.djl_inference.model import DJLPredictor # noqa: F401
17+
from sagemaker.djl_inference.model import DJLModel # noqa: F401
18+
from sagemaker.djl_inference.model import DeepSpeedModel # noqa: F401
19+
from sagemaker.djl_inference.model import HuggingFaceAccelerateModel # noqa: F401
+50
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Constants for DJL Inference: model types, DeepSpeed architectures, instance families."""
14+
from __future__ import absolute_import
15+
16+
# Model-type identifier for Stable Diffusion. It is also listed as a member of
# both DeepSpeed architecture sets defined below.
STABLE_DIFFUSION_MODEL_TYPE = "stable-diffusion"
17+
18+
# Set of model architecture identifiers, plus the Stable Diffusion model type.
# Presumably the architectures for which DeepSpeed is the recommended engine --
# TODO confirm against the code that consumes this set.
# NOTE(review): the contents are currently identical to
# DEEPSPEED_SUPPORTED_ARCHITECTURES; keep the two in sync or document the
# intended difference if they are meant to diverge.
DEEPSPEED_RECOMMENDED_ARCHITECTURES = {
19+
"bloom",
20+
"opt",
21+
"gpt_neox",
22+
"gptj",
23+
"gpt_neo",
24+
"gpt2",
25+
"xlm-roberta",
26+
"roberta",
27+
"bert",
28+
STABLE_DIFFUSION_MODEL_TYPE,
29+
}
30+
31+
# Set of model architecture identifiers, plus the Stable Diffusion model type.
# Presumably the architectures that can be served with DeepSpeed at all --
# TODO confirm against the code that consumes this set.
# NOTE(review): the contents are currently identical to
# DEEPSPEED_RECOMMENDED_ARCHITECTURES.
DEEPSPEED_SUPPORTED_ARCHITECTURES = {
32+
"bloom",
33+
"opt",
34+
"gpt_neox",
35+
"gptj",
36+
"gpt_neo",
37+
"gpt2",
38+
"xlm-roberta",
39+
"roberta",
40+
"bert",
41+
STABLE_DIFFUSION_MODEL_TYPE,
42+
}
43+
44+
# Set of instance-family strings ("ml.<family>" prefixes plus "local_gpu").
# Presumably used to validate the instance type requested for a DJL model --
# TODO confirm against the caller.
ALLOWED_INSTANCE_FAMILIES = {
45+
"ml.g4dn",
46+
"ml.g5",
47+
"ml.p3",
48+
"ml.p4",
49+
"local_gpu",
50+
}

0 commit comments

Comments
 (0)