Skip to content

Commit 1f3fbe7

Browse files
committed
fix: make 'ModelInput' field optional for inference recommendation
1 parent ef39168 commit 1f3fbe7

File tree

5 files changed

+49
-47
lines changed

5 files changed

+49
-47
lines changed

src/sagemaker/utils.py

Lines changed: 49 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -749,32 +749,57 @@ def update_container_with_inference_params(
749749
dict: dict with inference recommender params
750750
"""
751751

752-
if (
753-
framework is not None
754-
and framework_version is not None
755-
and nearest_model_name is not None
756-
and data_input_configuration is not None
757-
):
752+
if framework is not None and framework_version is not None and nearest_model_name is not None:
758753
if container_list is not None:
759754
for obj in container_list:
760-
obj.update(
761-
{
762-
"Framework": framework,
763-
"FrameworkVersion": framework_version,
764-
"NearestModelName": nearest_model_name,
765-
"ModelInput": {
766-
"DataInputConfig": data_input_configuration,
767-
},
768-
}
755+
construct_container_object(
756+
obj, data_input_configuration, framework, framework_version, nearest_model_name
769757
)
770758
if container_obj is not None:
771-
container_obj.update(
772-
{
773-
"Framework": framework,
774-
"FrameworkVersion": framework_version,
775-
"NearestModelName": nearest_model_name,
776-
"ModelInput": {
777-
"DataInputConfig": data_input_configuration,
778-
},
779-
}
759+
construct_container_object(
760+
container_obj,
761+
data_input_configuration,
762+
framework,
763+
framework_version,
764+
nearest_model_name,
780765
)
766+
767+
768+
def construct_container_object(
769+
obj, data_input_configuration, framework, framework_version, nearest_model_name
770+
):
771+
"""Function to construct container object.
772+
773+
Args:
774+
framework (str): Machine learning framework of the model package container image
775+
(default: None).
776+
framework_version (str): Framework version of the Model Package Container Image
777+
(default: None).
778+
nearest_model_name (str): Name of a pre-trained machine learning benchmarked by
779+
Amazon SageMaker Inference Recommender (default: None).
780+
data_input_configuration (str): Input object for the model (default: None).
781+
container_obj (dict): object to be updated.
782+
container_list (list): list to be updated.
783+
784+
Returns:
785+
dict: container object
786+
"""
787+
if data_input_configuration is not None:
788+
obj.update(
789+
{
790+
"Framework": framework,
791+
"FrameworkVersion": framework_version,
792+
"NearestModelName": nearest_model_name,
793+
"ModelInput": {
794+
"DataInputConfig": data_input_configuration,
795+
},
796+
}
797+
)
798+
else:
799+
obj.update(
800+
{
801+
"Framework": framework,
802+
"FrameworkVersion": framework_version,
803+
"NearestModelName": nearest_model_name,
804+
}
805+
)

tests/integ/sagemaker/workflow/test_model_create_and_registration.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,6 @@ def test_conditional_pytorch_training_model_registration(
9999
framework = "TENSORFLOW"
100100
framework_version = "2.9"
101101
nearest_model_name = "resnet50"
102-
data_input_configuration = '{"input_1":[1,224,224,3]}'
103102

104103
# If image_uri is not provided, the instance_type should not be a pipeline variable
105104
# since instance_type is used to retrieve image_uri in compile time (PySDK)
@@ -132,7 +131,6 @@ def test_conditional_pytorch_training_model_registration(
132131
framework=framework,
133132
framework_version=framework_version,
134133
nearest_model_name=nearest_model_name,
135-
data_input_configuration=data_input_configuration,
136134
)
137135

138136
model = Model(
@@ -219,7 +217,6 @@ def test_mxnet_model_registration(
219217
framework = "TENSORFLOW"
220218
framework_version = "2.9"
221219
nearest_model_name = "resnet50"
222-
data_input_configuration = '{"input_1":[1,224,224,3]}'
223220

224221
model = MXNetModel(
225222
entry_point=entry_point,
@@ -244,7 +241,6 @@ def test_mxnet_model_registration(
244241
framework=framework,
245242
framework_version=framework_version,
246243
nearest_model_name=nearest_model_name,
247-
data_input_configuration=data_input_configuration,
248244
)
249245

250246
pipeline = Pipeline(
@@ -293,7 +289,6 @@ def test_sklearn_xgboost_sip_model_registration(
293289
framework = "TENSORFLOW"
294290
framework_version = "2.9"
295291
nearest_model_name = "resnet50"
296-
data_input_configuration = '{"input_1":[1,224,224,3]}'
297292

298293
# The instance_type should not be a pipeline variable
299294
# since it is used to retrieve image_uri in compile time (PySDK)
@@ -450,7 +445,6 @@ def test_sklearn_xgboost_sip_model_registration(
450445
framework=framework,
451446
framework_version=framework_version,
452447
nearest_model_name=nearest_model_name,
453-
data_input_configuration=data_input_configuration,
454448
)
455449

456450
pipeline = Pipeline(

tests/unit/sagemaker/workflow/test_pipeline_session.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,6 @@ def test_pipeline_session_context_for_model_step_without_instance_types(
151151
framework="TENSORFLOW",
152152
framework_version="2.9",
153153
nearest_model_name="resnet50",
154-
data_input_configuration='{"input_1":[1,224,224,3]}',
155154
)
156155

157156
expected_output = {
@@ -173,9 +172,6 @@ def test_pipeline_session_context_for_model_step_without_instance_types(
173172
"Framework": "TENSORFLOW",
174173
"FrameworkVersion": "2.9",
175174
"NearestModelName": "resnet50",
176-
"ModelInput": {
177-
"DataInputConfig": '{"input_1":[1,224,224,3]}',
178-
},
179175
}
180176
],
181177
"SupportedContentTypes": ["text/csv"],

tests/unit/sagemaker/workflow/test_step_collections.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -446,7 +446,6 @@ def test_register_model_tf(estimator_tf, model_metrics, drift_check_baselines):
446446
framework="TENSORFLOW",
447447
framework_version="2.9",
448448
nearest_model_name="resnet50",
449-
data_input_configuration='{"input_1":[1,224,224,3]}',
450449
)
451450
assert ordered(register_model.request_dicts()) == ordered(
452451
[
@@ -523,7 +522,6 @@ def test_register_model_sip(estimator, model_metrics, drift_check_baselines):
523522
framework="TENSORFLOW",
524523
framework_version="2.9",
525524
nearest_model_name="resnet50",
526-
data_input_configuration='{"input_1":[1,224,224,3]}',
527525
)
528526
assert ordered(register_model.request_dicts()) == ordered(
529527
[
@@ -542,9 +540,6 @@ def test_register_model_sip(estimator, model_metrics, drift_check_baselines):
542540
"Framework": "TENSORFLOW",
543541
"FrameworkVersion": "2.9",
544542
"NearestModelName": "resnet50",
545-
"ModelInput": {
546-
"DataInputConfig": '{"input_1":[1,224,224,3]}',
547-
},
548543
},
549544
{
550545
"Image": "fakeimage2",
@@ -553,9 +548,6 @@ def test_register_model_sip(estimator, model_metrics, drift_check_baselines):
553548
"Framework": "TENSORFLOW",
554549
"FrameworkVersion": "2.9",
555550
"NearestModelName": "resnet50",
556-
"ModelInput": {
557-
"DataInputConfig": '{"input_1":[1,224,224,3]}',
558-
},
559551
},
560552
],
561553
"SupportedContentTypes": ["content_type"],
@@ -619,7 +611,6 @@ def test_register_model_with_model_repack_with_estimator(
619611
framework="TENSORFLOW",
620612
framework_version="2.9",
621613
nearest_model_name="resnet50",
622-
data_input_configuration='{"input_1":[1,224,224,3]}',
623614
)
624615

625616
request_dicts = register_model.request_dicts()

tests/unit/test_estimator.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3260,7 +3260,6 @@ def test_register_default_image_without_instance_type_args(sagemaker_session):
32603260
framework = "TENSORFLOW"
32613261
framework_version = "2.9"
32623262
nearest_model_name = "resnet50"
3263-
data_input_config = '{"input_1":[1,224,224,3]}'
32643263

32653264
estimator.register(
32663265
content_types=content_types,
@@ -3271,7 +3270,6 @@ def test_register_default_image_without_instance_type_args(sagemaker_session):
32713270
framework=framework,
32723271
framework_version=framework_version,
32733272
nearest_model_name=nearest_model_name,
3274-
data_input_configuration=data_input_config,
32753273
)
32763274
sagemaker_session.create_model.assert_not_called()
32773275

@@ -3319,7 +3317,6 @@ def test_register_inference_image(sagemaker_session):
33193317
framework = "TENSORFLOW"
33203318
framework_version = "2.9"
33213319
nearest_model_name = "resnet50"
3322-
data_input_config = '{"input_1":[1,224,224,3]}'
33233320

33243321
estimator.register(
33253322
content_types=content_types,
@@ -3333,7 +3330,6 @@ def test_register_inference_image(sagemaker_session):
33333330
framework=framework,
33343331
framework_version=framework_version,
33353332
nearest_model_name=nearest_model_name,
3336-
data_input_configuration=data_input_config,
33373333
)
33383334
sagemaker_session.create_model.assert_not_called()
33393335

0 commit comments

Comments
 (0)