Skip to content

Commit 0262cdc

Browse files
committed
fix: unit test for matching args
1 parent 685c8df commit 0262cdc

File tree

5 files changed

+52
-15
lines changed

5 files changed

+52
-15
lines changed

src/sagemaker/jumpstart/estimator.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -663,7 +663,7 @@ def deploy(
663663
role: Optional[str] = None,
664664
predictor_cls: Optional[callable] = None,
665665
env: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
666-
name: Optional[str] = None,
666+
model_name: Optional[str] = None,
667667
vpc_config: Optional[Dict[str, List[Union[str, PipelineVariable]]]] = None,
668668
sagemaker_session: Optional[session.Session] = None,
669669
enable_network_isolation: Union[bool, PipelineVariable] = None,
@@ -675,6 +675,7 @@ def deploy(
675675
container_log_level: Optional[Union[int, PipelineVariable]] = None,
676676
dependencies: Optional[List[str]] = None,
677677
git_config: Optional[Dict[str, str]] = None,
678+
use_compiled_model: bool = False,
678679
) -> PredictorBase:
679680
"""Creates endpoint from training job.
680681
@@ -766,7 +767,7 @@ def deploy(
766767
function on the created endpoint name. (Default: None).
767768
env (Optional[dict[str, str] or dict[str, PipelineVariable]]): Environment variables
768769
to run with ``image_uri`` when hosted in SageMaker. (Default: None).
769-
name (Optional[str]): The model name. If None, a default model name will be
770+
model_name (Optional[str]): The model name. If None, a default model name will be
770771
selected on each ``deploy``. (Default: None).
771772
vpc_config (Optional[Union[dict[str, list[str]],dict[str, list[PipelineVariable]]]]):
772773
The VpcConfig set on the model (Default: None)
@@ -909,6 +910,8 @@ def deploy(
909910
the SageMaker Python SDK attempts to use either the CodeCommit
910911
credential helper or local credential storage for authentication.
911912
(Default: None).
913+
use_compiled_model (bool): Flag to select whether to use compiled
914+
(optimized) model. (Default: False).
912915
"""
913916

914917
self.orig_predictor_cls = predictor_cls
@@ -948,7 +951,7 @@ def deploy(
948951
role=role,
949952
predictor_cls=predictor_cls,
950953
env=env,
951-
name=name,
954+
model_name=model_name,
952955
vpc_config=vpc_config,
953956
sagemaker_session=sagemaker_session,
954957
enable_network_isolation=enable_network_isolation,
@@ -960,6 +963,7 @@ def deploy(
960963
container_log_level=container_log_level,
961964
dependencies=dependencies,
962965
git_config=git_config,
966+
use_compiled_model=use_compiled_model,
963967
)
964968

965969
predictor = super(JumpStartEstimator, self).deploy(

src/sagemaker/jumpstart/factory/estimator.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,6 @@ def get_deploy_kwargs(
247247
role: Optional[str] = None,
248248
predictor_cls: Optional[callable] = None,
249249
env: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
250-
name: Optional[str] = None,
251250
vpc_config: Optional[Dict[str, List[Union[str, PipelineVariable]]]] = None,
252251
sagemaker_session: Optional[Session] = None,
253252
enable_network_isolation: Union[bool, PipelineVariable] = None,
@@ -261,6 +260,8 @@ def get_deploy_kwargs(
261260
git_config: Optional[Dict[str, str]] = None,
262261
tolerate_deprecated_model: Optional[bool] = None,
263262
tolerate_vulnerable_model: Optional[bool] = None,
263+
use_compiled_model: Optional[bool] = None,
264+
model_name: Optional[str] = None,
264265
) -> JumpStartEstimatorDeployKwargs:
265266
"""Returns kwargs required to call `deploy` on `sagemaker.estimator.Estimator` object."""
266267

@@ -301,7 +302,7 @@ def get_deploy_kwargs(
301302
env=env,
302303
predictor_cls=predictor_cls,
303304
role=role,
304-
name=name,
305+
name=model_name,
305306
vpc_config=vpc_config,
306307
sagemaker_session=sagemaker_session,
307308
enable_network_isolation=enable_network_isolation,
@@ -344,7 +345,7 @@ def get_deploy_kwargs(
344345
inference_recommendation_id=model_deploy_kwargs.inference_recommendation_id,
345346
explainer_config=model_deploy_kwargs.explainer_config,
346347
role=model_init_kwargs.role,
347-
name=model_init_kwargs.name,
348+
model_name=model_init_kwargs.name,
348349
vpc_config=model_init_kwargs.vpc_config,
349350
sagemaker_session=model_init_kwargs.sagemaker_session,
350351
enable_network_isolation=model_init_kwargs.enable_network_isolation,
@@ -356,6 +357,7 @@ def get_deploy_kwargs(
356357
git_config=model_init_kwargs.git_config,
357358
tolerate_vulnerable_model=model_init_kwargs.tolerate_vulnerable_model,
358359
tolerate_deprecated_model=model_init_kwargs.tolerate_deprecated_model,
360+
use_compiled_model=use_compiled_model,
359361
)
360362

361363
return estimator_deploy_kwargs

src/sagemaker/jumpstart/types.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -974,7 +974,6 @@ class JumpStartEstimatorDeployKwargs(JumpStartKwargs):
974974
"inference_recommendation_id",
975975
"explainer_config",
976976
"role",
977-
"name",
978977
"vpc_config",
979978
"sagemaker_session",
980979
"enable_network_isolation",
@@ -986,6 +985,8 @@ class JumpStartEstimatorDeployKwargs(JumpStartKwargs):
986985
"git_config",
987986
"tolerate_deprecated_model",
988987
"tolerate_vulnerable_model",
988+
"model_name",
989+
"use_compiled_model",
989990
]
990991

991992
SERIALIZATION_EXCLUSION_SET = {
@@ -1023,7 +1024,7 @@ def __init__(
10231024
role: Optional[str] = None,
10241025
predictor_cls: Optional[callable] = None,
10251026
env: Optional[Dict[str, Union[str, Any]]] = None,
1026-
name: Optional[str] = None,
1027+
model_name: Optional[str] = None,
10271028
vpc_config: Optional[Dict[str, List[Union[str, Any]]]] = None,
10281029
sagemaker_session: Optional[Any] = None,
10291030
enable_network_isolation: Union[bool, Any] = None,
@@ -1037,6 +1038,7 @@ def __init__(
10371038
git_config: Optional[Dict[str, str]] = None,
10381039
tolerate_deprecated_model: Optional[bool] = None,
10391040
tolerate_vulnerable_model: Optional[bool] = None,
1041+
use_compiled_model: bool = False,
10401042
) -> None:
10411043
"""Instantiates JumpStartEstimatorInitKwargs object."""
10421044

@@ -1066,7 +1068,7 @@ def __init__(
10661068
self.inference_recommendation_id = inference_recommendation_id
10671069
self.explainer_config = explainer_config
10681070
self.role = role
1069-
self.name = name
1071+
self.model_name = model_name
10701072
self.vpc_config = vpc_config
10711073
self.sagemaker_session = sagemaker_session
10721074
self.enable_network_isolation = enable_network_isolation
@@ -1078,3 +1080,4 @@ def __init__(
10781080
self.git_config = git_config
10791081
self.tolerate_deprecated_model = tolerate_deprecated_model
10801082
self.tolerate_vulnerable_model = tolerate_vulnerable_model
1083+
self.use_compiled_model = use_compiled_model

tests/unit/sagemaker/jumpstart/estimator/test_estimator.py

+23-5
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414
import time
15-
from typing import Optional
15+
from typing import Optional, Set
1616
from unittest import mock
1717
import unittest
1818
from inspect import signature
@@ -142,8 +142,9 @@ def test_non_prepacked(
142142
predictor_cls=Predictor,
143143
role=execution_role,
144144
wait=True,
145+
use_compiled_model=False,
145146
enable_network_isolation=False,
146-
name="blahblahblah-9876",
147+
model_name="blahblahblah-9876",
147148
endpoint_name="blahblahblah-9876",
148149
)
149150

@@ -234,6 +235,7 @@ def test_prepacked(
234235
predictor_cls=Predictor,
235236
role=execution_role,
236237
wait=True,
238+
use_compiled_model=False,
237239
enable_network_isolation=False,
238240
)
239241

@@ -503,7 +505,8 @@ def evaluate_estimator_workflow_with_kwargs(
503505
"predictor_cls": Predictor,
504506
"role": init_kwargs["role"],
505507
"enable_network_isolation": False,
506-
"name": "blahblahblah-1234",
508+
"use_compiled_model": False,
509+
"model_name": "blahblahblah-1234",
507510
"endpoint_name": "blahblahblah-1234",
508511
},
509512
deploy_kwargs,
@@ -512,6 +515,15 @@ def evaluate_estimator_workflow_with_kwargs(
512515
mock_estimator_deploy.assert_called_once_with(**expected_deploy_kwargs)
513516

514517
def test_jumpstart_estimator_kwargs_match_parent_class(self):
518+
519+
"""If you add arguments to <Estimator constructor>, this test will fail.
520+
Please add the new argument to the skip set below,
521+
and cut a ticket sev-3 to JumpStart team: AWS > SageMaker > JumpStart"""
522+
523+
init_args_to_skip: Set[str] = set(["kwargs"])
524+
fit_args_to_skip: Set[str] = set()
525+
deploy_args_to_skip: Set[str] = set(["kwargs"])
526+
515527
parent_class_init = Estimator.__init__
516528
parent_class_init_args = set(signature(parent_class_init).parameters.keys())
517529

@@ -525,6 +537,7 @@ def test_jumpstart_estimator_kwargs_match_parent_class(self):
525537
"tolerate_vulnerable_model",
526538
"tolerate_deprecated_model",
527539
}
540+
assert parent_class_init_args - js_class_init_args == init_args_to_skip
528541

529542
parent_class_fit = Estimator.fit
530543
parent_class_fit_args = set(signature(parent_class_fit).parameters.keys())
@@ -533,6 +546,7 @@ def test_jumpstart_estimator_kwargs_match_parent_class(self):
533546
js_class_fit_args = set(signature(js_class_fit).parameters.keys())
534547

535548
assert js_class_fit_args - parent_class_fit_args == set()
549+
assert parent_class_fit_args - js_class_fit_args == fit_args_to_skip
536550

537551
model_class_init = Model.__init__
538552
model_class_init_args = set(signature(model_class_init).parameters.keys())
@@ -546,7 +560,9 @@ def test_jumpstart_estimator_kwargs_match_parent_class(self):
546560
assert js_class_deploy_args - parent_class_deploy_args == model_class_init_args - {
547561
"model_data",
548562
"self",
563+
"name",
549564
}
565+
assert parent_class_deploy_args - js_class_deploy_args == deploy_args_to_skip
550566

551567
@mock.patch("sagemaker.jumpstart.estimator.get_init_kwargs")
552568
@mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__")
@@ -842,9 +858,10 @@ def test_training_passes_role_to_deploy(
842858
},
843859
predictor_cls=Predictor,
844860
wait=True,
861+
use_compiled_model=False,
845862
role=mock_role,
846863
enable_network_isolation=False,
847-
name="blahblahblah-3456",
864+
model_name="blahblahblah-3456",
848865
endpoint_name="blahblahblah-3456",
849866
)
850867

@@ -916,9 +933,10 @@ def test_training_passes_session_to_deploy(
916933
},
917934
predictor_cls=Predictor,
918935
wait=True,
936+
use_compiled_model=False,
919937
role=mock_role,
920938
enable_network_isolation=False,
921-
name="blahblahblah-3456",
939+
model_name="blahblahblah-3456",
922940
endpoint_name="blahblahblah-3456",
923941
)
924942

tests/unit/sagemaker/jumpstart/model/test_model.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414
from inspect import signature
15-
from typing import Optional
15+
from typing import Optional, Set
1616
from unittest import mock
1717
import unittest
1818
import pytest
@@ -322,6 +322,14 @@ def evaluate_model_workflow_with_kwargs(
322322
mock_model_deploy.assert_called_once_with(**expected_deploy_kwargs)
323323

324324
def test_jumpstart_model_kwargs_match_parent_class(self):
325+
326+
"""If you add arguments to <Model constructor>, this test will fail.
327+
Please add the new argument to the skip set below,
328+
and cut a ticket sev-3 to JumpStart team: AWS > SageMaker > JumpStart"""
329+
330+
init_args_to_skip: Set[str] = set()
331+
deploy_args_to_skip: Set[str] = set(["kwargs"])
332+
325333
parent_class_init = Model.__init__
326334
parent_class_init_args = set(signature(parent_class_init).parameters.keys())
327335

@@ -336,6 +344,7 @@ def test_jumpstart_model_kwargs_match_parent_class(self):
336344
"tolerate_deprecated_model",
337345
"instance_type",
338346
}
347+
assert parent_class_init_args - js_class_init_args == init_args_to_skip
339348

340349
parent_class_deploy = Model.deploy
341350
parent_class_deploy_args = set(signature(parent_class_deploy).parameters.keys())
@@ -344,6 +353,7 @@ def test_jumpstart_model_kwargs_match_parent_class(self):
344353
js_class_deploy_args = set(signature(js_class_deploy).parameters.keys())
345354

346355
assert js_class_deploy_args - parent_class_deploy_args == set()
356+
assert parent_class_deploy_args - js_class_deploy_args == deploy_args_to_skip
347357

348358
@mock.patch("sagemaker.jumpstart.model.get_init_kwargs")
349359
@mock.patch("sagemaker.jumpstart.model.Model.__init__")

0 commit comments

Comments
 (0)