Skip to content

Commit 9cfab6f

Browse files
bhaozknikure
authored andcommitted
fix: refactor endpoint type enums, comments, docstrings, method names… (#1406)
1 parent f36644e commit 9cfab6f

File tree

14 files changed

+64
-59
lines changed

14 files changed

+64
-59
lines changed

src/sagemaker/base_predictor.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -399,9 +399,9 @@ def update_endpoint(
399399
new_endpoint_config_name = name_from_base(current_endpoint_config_name)
400400

401401
if self._get_component_name():
402-
endpoint_type = EndpointType.GEN2
402+
endpoint_type = EndpointType.INFERENCE_COMPONENT_BASED
403403
else:
404-
endpoint_type = EndpointType.GEN1
404+
endpoint_type = EndpointType.MODEL_BASED
405405

406406
self.sagemaker_session.create_endpoint_config_from_existing(
407407
current_endpoint_config_name,
@@ -442,8 +442,8 @@ def delete_endpoint(self, delete_endpoint_config=True):
442442
def delete_predictor(self, wait: bool = False) -> None:
443443
"""Delete the Amazon SageMaker inference component or endpoint backing this predictor.
444444
445-
Delete the corresponding inference component if the endpoint is a Generation2
446-
endpoint.
445+
Delete the corresponding inference component if the endpoint is a inference component
446+
based endpoint.
447447
Otherwise delete the endpoint where this predictor is hosted.
448448
"""
449449

@@ -485,8 +485,9 @@ def update_predictor(
485485
https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html#your-algorithms-inference-algo-ping-requests
486486
(Default: None).
487487
resources (Optional[ResourceRequirements]): The compute resource requirements
488-
for a model to be deployed to an endpoint. Only EndpointType.GEN2 supports
489-
this feature. (Default: None).
488+
for a model to be deployed to an endpoint.
489+
Only EndpointType.INFERENCE_COMPONENT_BASED supports this feature.
490+
(Default: None).
490491
"""
491492
if self.component_name is None:
492493
raise ValueError(

src/sagemaker/enums.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -24,5 +24,7 @@
2424
class EndpointType(Enum):
2525
"""Types of endpoint"""
2626

27-
GEN1 = "gen1" # Amazon SageMaker Endpoint Generation 1
28-
GEN2 = "gen2" # Amazon SageMaker Endpoint Generation 2
27+
MODEL_BASED = "ModelBased" # Amazon SageMaker Model Based Endpoint
28+
INFERENCE_COMPONENT_BASED = (
29+
"InferenceComponentBased" # Amazon SageMaker Inference Component Based Endpoint
30+
)

src/sagemaker/jumpstart/factory/model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -561,7 +561,7 @@ def get_deploy_kwargs(
561561

562562
deploy_kwargs = _add_tags_to_kwargs(kwargs=deploy_kwargs)
563563

564-
if endpoint_type == EndpointType.GEN2:
564+
if endpoint_type == EndpointType.INFERENCE_COMPONENT_BASED:
565565
deploy_kwargs = _add_resources_to_kwargs(kwargs=deploy_kwargs)
566566
deploy_kwargs.endpoint_type = endpoint_type
567567
deploy_kwargs.managed_instance_scaling = managed_instance_scaling

src/sagemaker/jumpstart/model.py

+8-6
Original file line numberDiff line numberDiff line change
@@ -263,8 +263,9 @@ def __init__(
263263
can be just the name if your account owns the Model Package.
264264
``model_data`` is not required. (Default: None).
265265
resources (Optional[ResourceRequirements]): The compute resource requirements
266-
for a model to be deployed to an endpoint. Only EndpointType.GEN2 supports
267-
this feature. (Default: None).
266+
for a model to be deployed to an endpoint.
267+
Only EndpointType.INFERENCE_COMPONENT_BASED supports this feature.
268+
(Default: None).
268269
Raises:
269270
ValueError: If the model ID is not recognized by JumpStart.
270271
"""
@@ -460,7 +461,7 @@ def deploy(
460461
endpoint_logging: Optional[bool] = False,
461462
resources: Optional[ResourceRequirements] = None,
462463
managed_instance_scaling: Optional[str] = None,
463-
endpoint_type: EndpointType = EndpointType.GEN1,
464+
endpoint_type: EndpointType = EndpointType.MODEL_BASED,
464465
) -> PredictorBase:
465466
"""Creates endpoint by calling base ``Model`` class `deploy` method.
466467
@@ -547,13 +548,14 @@ def deploy(
547548
endpoint_logging (Optiona[bool]): If set to true, live logging will be emitted as
548549
the SageMaker Endpoint starts up. (Default: False).
549550
resources (Optional[ResourceRequirements]): The compute resource requirements
550-
for a model to be deployed to an endpoint. Only EndpointType.GEN2 supports
551-
this feature. (Default: None).
551+
for a model to be deployed to an endpoint. Only
552+
EndpointType.INFERENCE_COMPONENT_BASED supports this feature.
553+
(Default: None).
552554
managed_instance_scaling (Optional[Dict]): Managed intance scaling options,
553555
if configured Amazon SageMaker will manage the instance number behind the
554556
endpoint.
555557
endpoint_type (EndpointType): The type of endpoint used to deploy models.
556-
(Default: EndpointType.GEN1).
558+
(Default: EndpointType.MODEL_BASED).
557559
"""
558560

559561
deploy_kwargs = get_deploy_kwargs(

src/sagemaker/model.py

+8-7
Original file line numberDiff line numberDiff line change
@@ -312,8 +312,9 @@ def __init__(
312312
the SageMaker Python SDK attempts to use either the CodeCommit
313313
credential helper or local credential storage for authentication.
314314
resources (Optional[ResourceRequirements]): The compute resource requirements
315-
for a model to be deployed to an endpoint. Only EndpointType.GEN2 supports
316-
this feature. (Default: None).
315+
for a model to be deployed to an endpoint. Only
316+
EndpointType.INFERENCE_COMPONENT_BASED supports this feature.
317+
(Default: None).
317318
318319
"""
319320
self.model_data = model_data
@@ -1275,7 +1276,7 @@ def deploy(
12751276
accept_eula: Optional[bool] = None,
12761277
endpoint_logging=False,
12771278
resources: Optional[ResourceRequirements] = None,
1278-
endpoint_type: EndpointType = EndpointType.GEN1,
1279+
endpoint_type: EndpointType = EndpointType.MODEL_BASED,
12791280
managed_instance_scaling: Optional[str] = None,
12801281
**kwargs,
12811282
):
@@ -1367,13 +1368,13 @@ def deploy(
13671368
endpoint_logging (Optiona[bool]): If set to true, live logging will be emitted as
13681369
the SageMaker Endpoint starts up. (Default: False).
13691370
resources (Optional[ResourceRequirements]): The compute resource requirements
1370-
for a model to be deployed to an endpoint. Only EndpointType.GEN2 supports
1371-
this feature. (Default: None).
1371+
for a model to be deployed to an endpoint. Only
1372+
EndpointType.INFERENCE_COMPONENT_BASED supports this feature. (Default: None).
13721373
managed_instance_scaling (Optional[Dict]): Managed instance scaling options,
13731374
if configured Amazon SageMaker will manage the instance number behind the
13741375
Endpoint. (Default: None).
13751376
endpoint_type (Optional[EndpointType]): The type of an endpoint used to deploy models.
1376-
(Default: EndpointType.GEN1).
1377+
(Default: EndpointType.MODEL_BASED).
13771378
Raises:
13781379
ValueError: If arguments combination check failed in these circumstances:
13791380
- If no role is specified or
@@ -1474,7 +1475,7 @@ def deploy(
14741475
self._base_name = "-".join((self._base_name, compiled_model_suffix))
14751476

14761477
# Support multiple models on same endpoint
1477-
if endpoint_type == EndpointType.GEN2:
1478+
if endpoint_type == EndpointType.INFERENCE_COMPONENT_BASED:
14781479
if endpoint_name:
14791480
self.endpoint_name = endpoint_name
14801481
else:

src/sagemaker/session.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -4069,7 +4069,7 @@ def create_endpoint_config_from_existing(
40694069
new_data_capture_config_dict=None,
40704070
new_production_variants=None,
40714071
new_explainer_config_dict=None,
4072-
endpoint_type=EndpointType.GEN1,
4072+
endpoint_type=EndpointType.MODEL_BASED,
40734073
):
40744074
"""Create an Amazon SageMaker endpoint configuration from an existing one.
40754075
@@ -4119,7 +4119,7 @@ def create_endpoint_config_from_existing(
41194119
production_variants = (
41204120
new_production_variants or existing_endpoint_config_desc["ProductionVariants"]
41214121
)
4122-
if endpoint_type == EndpointType.GEN2:
4122+
if endpoint_type == EndpointType.INFERENCE_COMPONENT_BASED:
41234123
# Make a copy of Production variants and remove the InitialVariantWeight
41244124
# in the copy
41254125
copy_production_variants = deepcopy(production_variants)
@@ -5278,9 +5278,9 @@ def endpoint_from_production_variants(
52785278
sagemaker_config=load_sagemaker_config() if (self is None) else None,
52795279
)
52805280

5281-
# For Amazon SageMaker Generation 2 Endpoint, it will not pass Model names
5282-
# during Endpoint creation. Instead, ExecutionRoleArn will be needed in the
5283-
# EndpointConfig to create Endpoint
5281+
# For Amazon SageMaker inference component based endpoint, it will not pass
5282+
# Model names during endpoint creation. Instead, ExecutionRoleArn will be
5283+
# needed in the endpoint config to create Endpoint
52845284
model_names = [pv["ModelName"] for pv in production_variants if "ModelName" in pv]
52855285
if len(model_names) == 0:
52865286
# Currently, SageMaker Python SDK allow using RoleName to deploy models.

tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def test_non_prepacked_jumpstart_model(setup):
6767
assert response is not None
6868

6969

70-
def test_non_prepacked_jumpstart_model_deployed_on_gen2_endpoint(setup):
70+
def test_non_prepacked_jumpstart_model_deployed_on_inference_component_based_endpoint(setup):
7171

7272
model_id = "huggingface-llm-falcon-7b-instruct-bf16" # default g5.2xlarge
7373

@@ -77,7 +77,7 @@ def test_non_prepacked_jumpstart_model_deployed_on_gen2_endpoint(setup):
7777
sagemaker_session=get_sm_session(),
7878
)
7979

80-
predictor = model.deploy(endpoint_type=EndpointType.GEN2)
80+
predictor = model.deploy(endpoint_type=EndpointType.INFERENCE_COMPONENT_BASED)
8181

8282
inference_input = {
8383
"inputs": "Girafatron is obsessed with giraffes, the most glorious animal on the "

tests/integ/test_huggingface.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -177,9 +177,9 @@ def test_huggingface_inference(
177177

178178

179179
@pytest.mark.skip(
180-
reason="re-enable when above GEN1 endpoint hugging face inference test enabled",
180+
reason="re-enable when above MODEL_BASED endpoint hugging face inference test enabled",
181181
)
182-
def test_huggingface_inference_gen2_endpoint(
182+
def test_huggingface_inference_inference_component_based_endpoint(
183183
sagemaker_session,
184184
gpu_pytorch_instance_type,
185185
huggingface_inference_latest_version,
@@ -204,7 +204,7 @@ def test_huggingface_inference_gen2_endpoint(
204204
instance_type=gpu_pytorch_instance_type,
205205
initial_instance_count=1,
206206
endpoint_name=endpoint_name,
207-
endpoint_type=EndpointType.GEN2,
207+
endpoint_type=EndpointType.INFERENCE_COMPONENT_BASED,
208208
resources=ResourceRequirements(
209209
requests={
210210
"num_accelerators": 1, # NumberOfCpuCoresRequired

tests/integ/test_generation_two_endpoint.py renamed to tests/integ/test_inference_component_based_endpoint.py

+6-9
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def test_deploy_single_model_with_endpoint_name(tfs_model, resources):
120120
1,
121121
"ml.m5.large",
122122
endpoint_name=endpoint_name,
123-
endpoint_type=EndpointType.GEN2,
123+
endpoint_type=EndpointType.INFERENCE_COMPONENT_BASED,
124124
resources=resources,
125125
)
126126

@@ -140,10 +140,7 @@ def test_deploy_single_model_with_endpoint_name(tfs_model, resources):
140140
predictor.delete_endpoint()
141141

142142

143-
@pytest.mark.slow_test
144-
@pytest.mark.skip(
145-
reason="Disable until us-west-2 production become stable",
146-
)
143+
@pytest.mark.release
147144
def test_deploy_update_predictor_with_other_model(
148145
tfs_model,
149146
resources,
@@ -155,7 +152,7 @@ def test_deploy_update_predictor_with_other_model(
155152
1,
156153
"ml.m5.4xlarge",
157154
endpoint_name=endpoint_name,
158-
endpoint_type=EndpointType.GEN2,
155+
endpoint_type=EndpointType.INFERENCE_COMPONENT_BASED,
159156
resources=resources,
160157
)
161158

@@ -172,7 +169,7 @@ def test_deploy_update_predictor_with_other_model(
172169
1,
173170
"ml.m5.4xlarge",
174171
endpoint_name=endpoint_name,
175-
endpoint_type=EndpointType.GEN2,
172+
endpoint_type=EndpointType.INFERENCE_COMPONENT_BASED,
176173
resources=resources,
177174
)
178175
xgboost_predictor.serializer = CSVSerializer()
@@ -208,7 +205,7 @@ def test_deploy_multi_models_without_endpoint_name(tfs_model, resources):
208205
tfs_predictor1 = tfs_model.deploy(
209206
1,
210207
"ml.m5.large",
211-
endpoint_type=EndpointType.GEN2,
208+
endpoint_type=EndpointType.INFERENCE_COMPONENT_BASED,
212209
resources=resources,
213210
)
214211

@@ -221,7 +218,7 @@ def test_deploy_multi_models_without_endpoint_name(tfs_model, resources):
221218
1,
222219
"ml.m5.large",
223220
endpoint_name=endpoint_name,
224-
endpoint_type=EndpointType.GEN2,
221+
endpoint_type=EndpointType.INFERENCE_COMPONENT_BASED,
225222
resources=resources,
226223
)
227224

tests/unit/sagemaker/jumpstart/model/test_model.py

+10-8
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from tests.unit.sagemaker.jumpstart.utils import (
3535
get_special_model_spec,
3636
overwrite_dictionary,
37-
get_special_model_spec_for_gen2_endpoint,
37+
get_special_model_spec_for_inference_component_based_endpoint,
3838
)
3939

4040
execution_role = "fake role! do not use!"
@@ -125,7 +125,7 @@ def test_non_prepacked(
125125
@mock.patch("sagemaker.jumpstart.model.Model.__init__")
126126
@mock.patch("sagemaker.jumpstart.model.Model.deploy")
127127
@mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region)
128-
def test_non_prepacked_gen2_endpoint(
128+
def test_non_prepacked_inference_component_based_endpoint(
129129
self,
130130
mock_model_deploy: mock.Mock,
131131
mock_model_init: mock.Mock,
@@ -141,7 +141,9 @@ def test_non_prepacked_gen2_endpoint(
141141
mock_is_valid_model_id.return_value = True
142142
model_id, _ = "js-trainable-model", "*"
143143

144-
mock_get_model_specs.side_effect = get_special_model_spec_for_gen2_endpoint
144+
mock_get_model_specs.side_effect = (
145+
get_special_model_spec_for_inference_component_based_endpoint
146+
)
145147

146148
mock_session.return_value = sagemaker_session
147149

@@ -180,7 +182,7 @@ def test_non_prepacked_gen2_endpoint(
180182
resources=resource_requirements,
181183
)
182184

183-
model.deploy(endpoint_type=EndpointType.GEN2)
185+
model.deploy(endpoint_type=EndpointType.INFERENCE_COMPONENT_BASED)
184186

185187
mock_model_deploy.assert_called_once_with(
186188
initial_instance_count=1,
@@ -193,7 +195,7 @@ def test_non_prepacked_gen2_endpoint(
193195
],
194196
endpoint_logging=False,
195197
resources=resource_requirements,
196-
endpoint_type=EndpointType.GEN2,
198+
endpoint_type=EndpointType.INFERENCE_COMPONENT_BASED,
197199
)
198200

199201
@mock.patch("sagemaker.utils.sagemaker_timestamp")
@@ -203,7 +205,7 @@ def test_non_prepacked_gen2_endpoint(
203205
@mock.patch("sagemaker.jumpstart.model.Model.__init__")
204206
@mock.patch("sagemaker.jumpstart.model.Model.deploy")
205207
@mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region)
206-
def test_non_prepacked_gen2_endpoint_no_default_pass_custom_resources(
208+
def test_non_prepacked_inference_component_based_endpoint_no_default_pass_custom_resources(
207209
self,
208210
mock_model_deploy: mock.Mock,
209211
mock_model_init: mock.Mock,
@@ -254,7 +256,7 @@ def test_non_prepacked_gen2_endpoint_no_default_pass_custom_resources(
254256
)
255257

256258
model.deploy(
257-
endpoint_type=EndpointType.GEN2,
259+
endpoint_type=EndpointType.INFERENCE_COMPONENT_BASED,
258260
resources=custom_resource_requirements,
259261
)
260262

@@ -268,7 +270,7 @@ def test_non_prepacked_gen2_endpoint_no_default_pass_custom_resources(
268270
],
269271
endpoint_logging=False,
270272
resources=custom_resource_requirements,
271-
endpoint_type=EndpointType.GEN2,
273+
endpoint_type=EndpointType.INFERENCE_COMPONENT_BASED,
272274
)
273275

274276
@mock.patch("sagemaker.jumpstart.model.is_valid_model_id")

tests/unit/sagemaker/jumpstart/utils.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -112,16 +112,16 @@ def get_special_model_spec(
112112
return specs
113113

114114

115-
def get_special_model_spec_for_gen2_endpoint(
115+
def get_special_model_spec_for_inference_component_based_endpoint(
116116
region: str = None,
117117
model_id: str = None,
118118
version: str = None,
119119
s3_client: boto3.client = None,
120120
) -> JumpStartModelSpecs:
121121
"""This function mocks cache accessor functions. For this mock,
122122
we only retrieve model specs based on the model ID and adding
123-
generation 2 endpoint specific specification. This is reserved
124-
for special specs.
123+
inference component based endpoint specific specification.
124+
This is reserved for special specs.
125125
"""
126126
model_spec_dict = SPECIAL_MODEL_SPECS_DICT[model_id]
127127
model_spec_dict["hosting_resource_requirements"] = {

tests/unit/sagemaker/model/test_deploy.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1009,7 +1009,7 @@ def test_deploy_with_name_and_resources(sagemaker_session):
10091009
MODEL_IMAGE, MODEL_DATA, name=MODEL_NAME, role=ROLE, sagemaker_session=sagemaker_session
10101010
)
10111011

1012-
endpoint_name = "Gen2-endpoint-test"
1012+
endpoint_name = "inference-component-based-endpoint-test"
10131013
model.deploy(
10141014
endpoint_name=endpoint_name,
10151015
instance_type=INSTANCE_TYPE,

tests/unit/sagemaker/model/test_model.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -858,7 +858,7 @@ def test_script_mode_model_uses_jumpstart_base_name(repack_model, sagemaker_sess
858858

859859
@patch("sagemaker.utils.repack_model")
860860
@patch("sagemaker.fw_utils.tar_and_upload_dir")
861-
def test_all_framework_models_generation_two_endpoint_deploy_path(
861+
def test_all_framework_models_inference_component_based_endpoint_deploy_path(
862862
repack_model, tar_and_uload_dir, sagemaker_session
863863
):
864864
framework_model_classes_to_kwargs = {
@@ -893,7 +893,7 @@ def test_all_framework_models_generation_two_endpoint_deploy_path(
893893
).deploy(
894894
instance_type="ml.m2.xlarge",
895895
initial_instance_count=INSTANCE_COUNT,
896-
endpoint_type=EndpointType.GEN2,
896+
endpoint_type=EndpointType.INFERENCE_COMPONENT_BASED,
897897
resources=ResourceRequirements(
898898
requests={
899899
"num_accelerators": 1,
@@ -904,7 +904,7 @@ def test_all_framework_models_generation_two_endpoint_deploy_path(
904904
),
905905
)
906906

907-
# Verified Generation2 endpoint and inference component creation
907+
# Verified inference component based endpoint and inference component creation
908908
# path
909909
sagemaker_session.endpoint_in_service_or_not.assert_called_once()
910910
sagemaker_session.create_model.assert_called_once()

0 commit comments

Comments
 (0)