Skip to content

Commit 77acc63

Browse files
Merge branch 'master' into fix-spark-processor
2 parents 6681603 + c0929cc commit 77acc63

File tree

5 files changed

+47
-47
lines changed

5 files changed

+47
-47
lines changed

src/sagemaker/utils.py

Lines changed: 47 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -759,32 +759,55 @@ def update_container_with_inference_params(
759759
dict: dict with inference recommender params
760760
"""
761761

762-
if (
763-
framework is not None
764-
and framework_version is not None
765-
and nearest_model_name is not None
766-
and data_input_configuration is not None
767-
):
762+
if framework is not None and framework_version is not None and nearest_model_name is not None:
768763
if container_list is not None:
769764
for obj in container_list:
770-
obj.update(
771-
{
772-
"Framework": framework,
773-
"FrameworkVersion": framework_version,
774-
"NearestModelName": nearest_model_name,
775-
"ModelInput": {
776-
"DataInputConfig": data_input_configuration,
777-
},
778-
}
765+
construct_container_object(
766+
obj, data_input_configuration, framework, framework_version, nearest_model_name
779767
)
780768
if container_obj is not None:
781-
container_obj.update(
782-
{
783-
"Framework": framework,
784-
"FrameworkVersion": framework_version,
785-
"NearestModelName": nearest_model_name,
786-
"ModelInput": {
787-
"DataInputConfig": data_input_configuration,
788-
},
789-
}
769+
construct_container_object(
770+
container_obj,
771+
data_input_configuration,
772+
framework,
773+
framework_version,
774+
nearest_model_name,
790775
)
776+
777+
778+
def construct_container_object(
779+
obj, data_input_configuration, framework, framework_version, nearest_model_name
780+
):
781+
"""Function to construct container object.
782+
783+
Args:
784+
framework (str): Machine learning framework of the model package container image
785+
(default: None).
786+
framework_version (str): Framework version of the Model Package Container Image
787+
(default: None).
788+
nearest_model_name (str): Name of a pre-trained machine learning benchmarked by
789+
Amazon SageMaker Inference Recommender (default: None).
790+
data_input_configuration (str): Input object for the model (default: None).
791+
container_obj (dict): object to be updated.
792+
container_list (list): list to be updated.
793+
794+
Returns:
795+
dict: container object
796+
"""
797+
798+
obj.update(
799+
{
800+
"Framework": framework,
801+
"FrameworkVersion": framework_version,
802+
"NearestModelName": nearest_model_name,
803+
}
804+
)
805+
806+
if data_input_configuration is not None:
807+
obj.update(
808+
{
809+
"ModelInput": {
810+
"DataInputConfig": data_input_configuration,
811+
},
812+
}
813+
)

tests/integ/sagemaker/workflow/test_model_create_and_registration.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,6 @@ def test_conditional_pytorch_training_model_registration(
9999
framework = "TENSORFLOW"
100100
framework_version = "2.9"
101101
nearest_model_name = "resnet50"
102-
data_input_configuration = '{"input_1":[1,224,224,3]}'
103102

104103
# If image_uri is not provided, the instance_type should not be a pipeline variable
105104
# since instance_type is used to retrieve image_uri in compile time (PySDK)
@@ -132,7 +131,6 @@ def test_conditional_pytorch_training_model_registration(
132131
framework=framework,
133132
framework_version=framework_version,
134133
nearest_model_name=nearest_model_name,
135-
data_input_configuration=data_input_configuration,
136134
)
137135

138136
model = Model(
@@ -219,7 +217,6 @@ def test_mxnet_model_registration(
219217
framework = "TENSORFLOW"
220218
framework_version = "2.9"
221219
nearest_model_name = "resnet50"
222-
data_input_configuration = '{"input_1":[1,224,224,3]}'
223220

224221
model = MXNetModel(
225222
entry_point=entry_point,
@@ -244,7 +241,6 @@ def test_mxnet_model_registration(
244241
framework=framework,
245242
framework_version=framework_version,
246243
nearest_model_name=nearest_model_name,
247-
data_input_configuration=data_input_configuration,
248244
)
249245

250246
pipeline = Pipeline(
@@ -293,7 +289,6 @@ def test_sklearn_xgboost_sip_model_registration(
293289
framework = "TENSORFLOW"
294290
framework_version = "2.9"
295291
nearest_model_name = "resnet50"
296-
data_input_configuration = '{"input_1":[1,224,224,3]}'
297292

298293
# The instance_type should not be a pipeline variable
299294
# since it is used to retrieve image_uri in compile time (PySDK)
@@ -450,7 +445,6 @@ def test_sklearn_xgboost_sip_model_registration(
450445
framework=framework,
451446
framework_version=framework_version,
452447
nearest_model_name=nearest_model_name,
453-
data_input_configuration=data_input_configuration,
454448
)
455449

456450
pipeline = Pipeline(

tests/unit/sagemaker/workflow/test_pipeline_session.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,6 @@ def test_pipeline_session_context_for_model_step_without_instance_types(
199199
framework="TENSORFLOW",
200200
framework_version="2.9",
201201
nearest_model_name="resnet50",
202-
data_input_configuration='{"input_1":[1,224,224,3]}',
203202
)
204203

205204
expected_output = {
@@ -221,9 +220,6 @@ def test_pipeline_session_context_for_model_step_without_instance_types(
221220
"Framework": "TENSORFLOW",
222221
"FrameworkVersion": "2.9",
223222
"NearestModelName": "resnet50",
224-
"ModelInput": {
225-
"DataInputConfig": '{"input_1":[1,224,224,3]}',
226-
},
227223
}
228224
],
229225
"SupportedContentTypes": ["text/csv"],

tests/unit/sagemaker/workflow/test_step_collections.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -446,7 +446,6 @@ def test_register_model_tf(estimator_tf, model_metrics, drift_check_baselines):
446446
framework="TENSORFLOW",
447447
framework_version="2.9",
448448
nearest_model_name="resnet50",
449-
data_input_configuration='{"input_1":[1,224,224,3]}',
450449
)
451450
assert ordered(register_model.request_dicts()) == ordered(
452451
[
@@ -523,7 +522,6 @@ def test_register_model_sip(estimator, model_metrics, drift_check_baselines):
523522
framework="TENSORFLOW",
524523
framework_version="2.9",
525524
nearest_model_name="resnet50",
526-
data_input_configuration='{"input_1":[1,224,224,3]}',
527525
)
528526
assert ordered(register_model.request_dicts()) == ordered(
529527
[
@@ -542,9 +540,6 @@ def test_register_model_sip(estimator, model_metrics, drift_check_baselines):
542540
"Framework": "TENSORFLOW",
543541
"FrameworkVersion": "2.9",
544542
"NearestModelName": "resnet50",
545-
"ModelInput": {
546-
"DataInputConfig": '{"input_1":[1,224,224,3]}',
547-
},
548543
},
549544
{
550545
"Image": "fakeimage2",
@@ -553,9 +548,6 @@ def test_register_model_sip(estimator, model_metrics, drift_check_baselines):
553548
"Framework": "TENSORFLOW",
554549
"FrameworkVersion": "2.9",
555550
"NearestModelName": "resnet50",
556-
"ModelInput": {
557-
"DataInputConfig": '{"input_1":[1,224,224,3]}',
558-
},
559551
},
560552
],
561553
"SupportedContentTypes": ["content_type"],
@@ -619,7 +611,6 @@ def test_register_model_with_model_repack_with_estimator(
619611
framework="TENSORFLOW",
620612
framework_version="2.9",
621613
nearest_model_name="resnet50",
622-
data_input_configuration='{"input_1":[1,224,224,3]}',
623614
)
624615

625616
request_dicts = register_model.request_dicts()

tests/unit/test_estimator.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3260,7 +3260,6 @@ def test_register_default_image_without_instance_type_args(sagemaker_session):
32603260
framework = "TENSORFLOW"
32613261
framework_version = "2.9"
32623262
nearest_model_name = "resnet50"
3263-
data_input_config = '{"input_1":[1,224,224,3]}'
32643263

32653264
estimator.register(
32663265
content_types=content_types,
@@ -3271,7 +3270,6 @@ def test_register_default_image_without_instance_type_args(sagemaker_session):
32713270
framework=framework,
32723271
framework_version=framework_version,
32733272
nearest_model_name=nearest_model_name,
3274-
data_input_configuration=data_input_config,
32753273
)
32763274
sagemaker_session.create_model.assert_not_called()
32773275

@@ -3319,7 +3317,6 @@ def test_register_inference_image(sagemaker_session):
33193317
framework = "TENSORFLOW"
33203318
framework_version = "2.9"
33213319
nearest_model_name = "resnet50"
3322-
data_input_config = '{"input_1":[1,224,224,3]}'
33233320

33243321
estimator.register(
33253322
content_types=content_types,
@@ -3333,7 +3330,6 @@ def test_register_inference_image(sagemaker_session):
33333330
framework=framework,
33343331
framework_version=framework_version,
33353332
nearest_model_name=nearest_model_name,
3336-
data_input_configuration=data_input_config,
33373333
)
33383334
sagemaker_session.create_model.assert_not_called()
33393335

0 commit comments

Comments
 (0)