Skip to content

Commit 55b1462

Browse files
Merge branch 'master' into suryans-feaure-store-collection-type
2 parents 85e1536 + b7a4792 commit 55b1462

File tree

12 files changed

+197
-23
lines changed

12 files changed

+197
-23
lines changed

CONTRIBUTING.md

+6
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,12 @@ For example, see the [Processing API reference](https://github.com/aws/sagemaker
227227

228228
To build the Sphinx docs, run the following command in the `doc/` directory:
229229

230+
```shell
231+
# Initial setup, only required for the first run
232+
pip install -r requirements.txt
233+
pip install -e ../
234+
```
235+
230236
```shell
231237
make html
232238
```

src/sagemaker/estimator.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -3285,7 +3285,6 @@ class Framework(EstimatorBase):
32853285
UNSUPPORTED_DLC_IMAGE_FOR_SM_PARALLELISM = (
32863286
"2.0.1-gpu-py310-cu121",
32873287
"2.0-gpu-py310-cu121",
3288-
"2.1.0-gpu-py310",
32893288
)
32903289

32913290
def __init__(
@@ -3959,7 +3958,7 @@ def _distribution_configuration(self, distribution):
39593958
for unsupported_image in Framework.UNSUPPORTED_DLC_IMAGE_FOR_SM_PARALLELISM:
39603959
if (
39613960
unsupported_image in img_uri and not torch_distributed_enabled
3962-
): # disabling DLC images with CUDA12
3961+
): # disabling DLC images without SMDataParallel or SMModelParallel
39633962
raise ValueError(
39643963
f"SMDistributed is currently incompatible with DLC image: {img_uri}. "
39653964
"(Could be due to CUDA version being greater than 11.)"

src/sagemaker/image_uri_config/huggingface.json

+104-3
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
"4.12": "4.12.3",
1313
"4.17": "4.17.0",
1414
"4.26": "4.26.0",
15-
"4.28": "4.28.1"
15+
"4.28": "4.28.1",
16+
"4.36": "4.36.0"
1617
},
1718
"versions": {
1819
"4.4.2": {
@@ -970,6 +971,53 @@
970971
"gpu": "cu118-ubuntu20.04"
971972
}
972973
}
974+
},
975+
"4.36.0": {
976+
"version_aliases": {
977+
"pytorch2.1": "pytorch2.1.0"
978+
},
979+
"pytorch2.1.0": {
980+
"py_versions": [
981+
"py310"
982+
],
983+
"registries": {
984+
"af-south-1": "626614931356",
985+
"il-central-1": "780543022126",
986+
"ap-east-1": "871362719292",
987+
"ap-northeast-1": "763104351884",
988+
"ap-northeast-2": "763104351884",
989+
"ap-northeast-3": "364406365360",
990+
"ap-south-1": "763104351884",
991+
"ap-southeast-1": "763104351884",
992+
"ap-southeast-2": "763104351884",
993+
"ap-southeast-3": "907027046896",
994+
"ca-central-1": "763104351884",
995+
"cn-north-1": "727897471807",
996+
"cn-northwest-1": "727897471807",
997+
"eu-central-1": "763104351884",
998+
"eu-north-1": "763104351884",
999+
"eu-west-1": "763104351884",
1000+
"eu-west-2": "763104351884",
1001+
"eu-west-3": "763104351884",
1002+
"eu-south-1": "692866216735",
1003+
"me-south-1": "217643126080",
1004+
"me-central-1": "914824155844",
1005+
"sa-east-1": "763104351884",
1006+
"us-east-1": "763104351884",
1007+
"us-east-2": "763104351884",
1008+
"us-gov-east-1": "446045086412",
1009+
"us-gov-west-1": "442386744353",
1010+
"us-iso-east-1": "886529160074",
1011+
"us-isob-east-1": "094389454867",
1012+
"us-west-1": "763104351884",
1013+
"us-west-2": "763104351884",
1014+
"ca-west-1": "204538143572"
1015+
},
1016+
"repository": "huggingface-pytorch-training",
1017+
"container_version": {
1018+
"gpu": "cu121-ubuntu20.04"
1019+
}
1020+
}
9731021
}
9741022
}
9751023
},
@@ -985,7 +1033,8 @@
9851033
"4.12": "4.12.3",
9861034
"4.17": "4.17.0",
9871035
"4.26": "4.26.0",
988-
"4.28": "4.28.1"
1036+
"4.28": "4.28.1",
1037+
"4.37": "4.37.0"
9891038
},
9901039
"versions": {
9911040
"4.6.1": {
@@ -1782,7 +1831,59 @@
17821831
"cpu": "ubuntu20.04"
17831832
}
17841833
}
1834+
},
1835+
"4.37.0": {
1836+
"version_aliases": {
1837+
"pytorch2.1": "pytorch2.1.0"
1838+
},
1839+
"pytorch2.1.0": {
1840+
"py_versions": [
1841+
"py310"
1842+
],
1843+
"registries": {
1844+
"af-south-1": "626614931356",
1845+
"il-central-1": "780543022126",
1846+
"ap-east-1": "871362719292",
1847+
"ap-northeast-1": "763104351884",
1848+
"ap-northeast-2": "763104351884",
1849+
"ap-northeast-3": "364406365360",
1850+
"ap-south-1": "763104351884",
1851+
"ap-south-2": "772153158452",
1852+
"ap-southeast-1": "763104351884",
1853+
"ap-southeast-2": "763104351884",
1854+
"ap-southeast-3": "907027046896",
1855+
"ap-southeast-4": "457447274322",
1856+
"ca-central-1": "763104351884",
1857+
"cn-north-1": "727897471807",
1858+
"cn-northwest-1": "727897471807",
1859+
"eu-central-1": "763104351884",
1860+
"eu-central-2": "380420809688",
1861+
"eu-north-1": "763104351884",
1862+
"eu-west-1": "763104351884",
1863+
"eu-west-2": "763104351884",
1864+
"eu-west-3": "763104351884",
1865+
"eu-south-1": "692866216735",
1866+
"eu-south-2": "503227376785",
1867+
"me-south-1": "217643126080",
1868+
"me-central-1": "914824155844",
1869+
"sa-east-1": "763104351884",
1870+
"us-east-1": "763104351884",
1871+
"us-east-2": "763104351884",
1872+
"us-gov-east-1": "446045086412",
1873+
"us-gov-west-1": "442386744353",
1874+
"us-iso-east-1": "886529160074",
1875+
"us-isob-east-1": "094389454867",
1876+
"us-west-1": "763104351884",
1877+
"us-west-2": "763104351884",
1878+
"ca-west-1": "204538143572"
1879+
},
1880+
"repository": "huggingface-pytorch-inference",
1881+
"container_version": {
1882+
"gpu": "cu118-ubuntu20.04",
1883+
"cpu": "ubuntu22.04"
1884+
}
1885+
}
17851886
}
17861887
}
17871888
}
1788-
}
1889+
}

src/sagemaker/serve/builder/djl_builder.py

+4
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,8 @@ def __init__(self):
8383
self.mode = None
8484
self.model_server = None
8585
self.image_uri = None
86+
self.image_config = None
87+
self.vpc_config = None
8688
self._original_deploy = None
8789
self.secret_key = None
8890
self.engine = None
@@ -138,6 +140,8 @@ def _create_djl_model(self) -> Type[Model]:
138140
"source_dir": code_dir,
139141
"env": self.env_vars,
140142
"hf_hub_token": self.env_vars.get("HUGGING_FACE_HUB_TOKEN"),
143+
"image_config": self.image_config,
144+
"vpc_config": self.vpc_config,
141145
}
142146

143147
if self.engine == _DjlEngine.DEEPSPEED:

src/sagemaker/serve/builder/model_builder.py

+36-10
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
from sagemaker.serve.validations.check_image_and_hardware_type import (
5555
validate_image_uri_and_hardware,
5656
)
57+
from sagemaker.workflow.entities import PipelineVariable
5758
from sagemaker.huggingface.llm_utils import get_huggingface_model_metadata
5859

5960
logger = logging.getLogger(__name__)
@@ -81,7 +82,6 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers):
8182
8283
* ``Mode.SAGEMAKER_ENDPOINT``: Launch on a SageMaker endpoint
8384
* ``Mode.LOCAL_CONTAINER``: Launch locally with a container
84-
8585
shared_libs (List[str]): Any shared libraries you want to bring into
8686
the model packaging.
8787
dependencies (Optional[Dict[str, Any]): The dependencies of the model
@@ -122,6 +122,15 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers):
122122
``invoke`` and ``load`` functions.
123123
image_uri (Optional[str]): The container image uri (which is derived from a
124124
SageMaker-based container).
125+
image_config (dict[str, str] or dict[str, PipelineVariable]): Specifies
126+
whether the image of model container is pulled from ECR, or private
127+
registry in your VPC. By default it is set to pull model container
128+
image from ECR. (default: None).
129+
vpc_config ( Optional[Dict[str, List[Union[str, PipelineVariable]]]]):
130+
The VpcConfig set on the model (default: None)
131+
* 'Subnets' (List[Union[str, PipelineVariable]]): List of subnet ids.
132+
* 'SecurityGroupIds' (List[Union[str, PipelineVariable]]]): List of security group
133+
ids.
125134
model_server (Optional[ModelServer]): The model server to which to deploy.
126135
You need to provide this argument when you specify an ``image_uri``
127136
in order for model builder to build the artifacts correctly (according
@@ -204,6 +213,23 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers):
204213
image_uri: Optional[str] = field(
205214
default=None, metadata={"help": "Define the container image uri"}
206215
)
216+
image_config: Optional[Dict[str, Union[str, PipelineVariable]]] = field(
217+
default=None,
218+
metadata={
219+
"help": "Specifies whether the image of model container is pulled from ECR,"
220+
" or private registry in your VPC. By default it is set to pull model "
221+
"container image from ECR. (default: None)."
222+
},
223+
)
224+
vpc_config: Optional[Dict[str, List[Union[str, PipelineVariable]]]] = field(
225+
default=None,
226+
metadata={
227+
"help": "The VpcConfig set on the model (default: None)."
228+
"* 'Subnets' (List[Union[str, PipelineVariable]]): List of subnet ids."
229+
"* ''SecurityGroupIds'' (List[Union[str, PipelineVariable]]): List of"
230+
" security group ids."
231+
},
232+
)
207233
model_server: Optional[ModelServer] = field(
208234
default=None, metadata={"help": "Define the model server to deploy to."}
209235
)
@@ -386,6 +412,8 @@ def _create_model(self):
386412
# TODO: we should create model as per the framework
387413
self.pysdk_model = Model(
388414
image_uri=self.image_uri,
415+
image_config=self.image_config,
416+
vpc_config=self.vpc_config,
389417
model_data=self.s3_upload_path,
390418
role=self.serve_settings.role_arn,
391419
env=self.env_vars,
@@ -543,15 +571,16 @@ def build(
543571
self,
544572
mode: Type[Mode] = None,
545573
role_arn: str = None,
546-
sagemaker_session: str = None,
574+
sagemaker_session: Optional[Session] = None,
547575
) -> Type[Model]:
548576
"""Create a deployable ``Model`` instance with ``ModelBuilder``.
549577
550578
Args:
551579
mode (Type[Mode], optional): The mode. Defaults to ``None``.
552580
role_arn (str, optional): The IAM role arn. Defaults to ``None``.
553-
sagemaker_session (str, optional): The SageMaker session to use
554-
for the execution. Defaults to ``None``.
581+
sagemaker_session (Optional[Session]): Session object which manages interactions
582+
with Amazon SageMaker APIs and any other AWS services needed. If not specified, the
583+
function creates one using the default AWS configuration chain.
555584
556585
Returns:
557586
Type[Model]: A deployable ``Model`` object.
@@ -562,10 +591,7 @@ def build(
562591
self.mode = mode
563592
if role_arn:
564593
self.role_arn = role_arn
565-
if sagemaker_session:
566-
self.sagemaker_session = sagemaker_session
567-
elif not self.sagemaker_session:
568-
self.sagemaker_session = Session()
594+
self.sagemaker_session = sagemaker_session or Session()
569595

570596
self.sagemaker_session.settings._local_download_dir = self.model_path
571597

@@ -607,7 +633,7 @@ def save(
607633
self,
608634
save_path: Optional[str] = None,
609635
s3_path: Optional[str] = None,
610-
sagemaker_session: Optional[str] = None,
636+
sagemaker_session: Optional[Session] = None,
611637
role_arn: Optional[str] = None,
612638
) -> Type[Model]:
613639
"""WARNING: This function is expremental and not intended for production use.
@@ -618,7 +644,7 @@ def save(
618644
save_path (Optional[str]): The path where you want to save resources.
619645
s3_path (Optional[str]): The path where you want to upload resources.
620646
"""
621-
self.sagemaker_session = sagemaker_session if sagemaker_session else Session()
647+
self.sagemaker_session = sagemaker_session or Session()
622648

623649
if role_arn:
624650
self.role_arn = role_arn

src/sagemaker/serve/builder/tgi_builder.py

+4
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@ def __init__(self):
7676
self.mode = None
7777
self.model_server = None
7878
self.image_uri = None
79+
self.image_config = None
80+
self.vpc_config = None
7981
self._original_deploy = None
8082
self.hf_model_config = None
8183
self._default_tensor_parallel_degree = None
@@ -134,6 +136,8 @@ def _create_tgi_model(self) -> Type[Model]:
134136

135137
pysdk_model = HuggingFaceModel(
136138
image_uri=self.image_uri,
139+
image_config=self.image_config,
140+
vpc_config=self.vpc_config,
137141
env=self.env_vars,
138142
role=self.role_arn,
139143
sagemaker_session=self.sagemaker_session,

src/sagemaker/serve/model_server/triton/triton_builder.py

+2
Original file line numberDiff line numberDiff line change
@@ -413,6 +413,8 @@ def _auto_detect_image_for_triton(self):
413413
def _create_triton_model(self) -> Type[Model]:
414414
self.pysdk_model = Model(
415415
image_uri=self.image_uri,
416+
image_config=self.image_config,
417+
vpc_config=self.vpc_config,
416418
model_data=self.s3_upload_path,
417419
role=self.serve_settings.role_arn,
418420
env=self.env_vars,

tests/unit/sagemaker/serve/builder/test_djl_builder.py

+13
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
LocalModelInvocationException,
3131
)
3232
from sagemaker.serve.utils.predictors import DjlLocalModePredictor
33+
from tests.unit.sagemaker.serve.constants import MOCK_IMAGE_CONFIG, MOCK_VPC_CONFIG
3334

3435
mock_model_id = "TheBloke/Llama-2-7b-chat-fp16"
3536
mock_t5_model_id = "google/flan-t5-xxl"
@@ -113,6 +114,8 @@ def test_build_deploy_for_djl_local_container(
113114
schema_builder=mock_schema_builder,
114115
mode=Mode.LOCAL_CONTAINER,
115116
model_server=ModelServer.DJL_SERVING,
117+
image_config=MOCK_IMAGE_CONFIG,
118+
vpc_config=MOCK_VPC_CONFIG,
116119
)
117120

118121
builder._prepare_for_mode = MagicMock()
@@ -132,6 +135,8 @@ def test_build_deploy_for_djl_local_container(
132135
assert builder._default_max_new_tokens == 256
133136
assert builder.schema_builder.sample_input["parameters"]["max_new_tokens"] == 256
134137
assert builder.nb_instance_type == "ml.g5.24xlarge"
138+
assert model.image_config == MOCK_IMAGE_CONFIG
139+
assert model.vpc_config == MOCK_VPC_CONFIG
135140
assert "deepspeed" in builder.image_uri
136141

137142
builder.modes[str(Mode.LOCAL_CONTAINER)] = MagicMock()
@@ -176,6 +181,8 @@ def test_build_for_djl_local_container_faster_transformer(
176181
schema_builder=mock_schema_builder,
177182
mode=Mode.LOCAL_CONTAINER,
178183
model_server=ModelServer.DJL_SERVING,
184+
image_config=MOCK_IMAGE_CONFIG,
185+
vpc_config=MOCK_VPC_CONFIG,
179186
)
180187
model = builder.build()
181188
builder.serve_settings.telemetry_opt_out = True
@@ -185,6 +192,8 @@ def test_build_for_djl_local_container_faster_transformer(
185192
model.generate_serving_properties()
186193
== mock_expected_fastertransformer_serving_properties
187194
)
195+
assert model.image_config == MOCK_IMAGE_CONFIG
196+
assert model.vpc_config == MOCK_VPC_CONFIG
188197
assert "fastertransformer" in builder.image_uri
189198

190199
@patch(
@@ -212,11 +221,15 @@ def test_build_for_djl_local_container_deepspeed(
212221
schema_builder=mock_schema_builder,
213222
mode=Mode.LOCAL_CONTAINER,
214223
model_server=ModelServer.DJL_SERVING,
224+
image_config=MOCK_IMAGE_CONFIG,
225+
vpc_config=MOCK_VPC_CONFIG,
215226
)
216227
model = builder.build()
217228
builder.serve_settings.telemetry_opt_out = True
218229

219230
assert isinstance(model, DeepSpeedModel)
231+
assert model.image_config == MOCK_IMAGE_CONFIG
232+
assert model.vpc_config == MOCK_VPC_CONFIG
220233
assert model.generate_serving_properties() == mock_expected_deepspeed_serving_properties
221234
assert "deepspeed" in builder.image_uri
222235

0 commit comments

Comments
 (0)