diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index ed5b3c5e75..32cad0dc12 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -749,32 +749,55 @@ def update_container_with_inference_params( dict: dict with inference recommender params """ - if ( - framework is not None - and framework_version is not None - and nearest_model_name is not None - and data_input_configuration is not None - ): + if framework is not None and framework_version is not None and nearest_model_name is not None: if container_list is not None: for obj in container_list: - obj.update( - { - "Framework": framework, - "FrameworkVersion": framework_version, - "NearestModelName": nearest_model_name, - "ModelInput": { - "DataInputConfig": data_input_configuration, - }, - } + construct_container_object( + obj, data_input_configuration, framework, framework_version, nearest_model_name ) if container_obj is not None: - container_obj.update( - { - "Framework": framework, - "FrameworkVersion": framework_version, - "NearestModelName": nearest_model_name, - "ModelInput": { - "DataInputConfig": data_input_configuration, - }, - } + construct_container_object( + container_obj, + data_input_configuration, + framework, + framework_version, + nearest_model_name, ) + + +def construct_container_object( + obj, data_input_configuration, framework, framework_version, nearest_model_name +): + """Function to construct container object. + + Args: + framework (str): Machine learning framework of the model package container image + (default: None). + framework_version (str): Framework version of the Model Package Container Image + (default: None). + nearest_model_name (str): Name of a pre-trained machine learning benchmarked by + Amazon SageMaker Inference Recommender (default: None). + data_input_configuration (str): Input object for the model (default: None). + container_obj (dict): object to be updated. + container_list (list): list to be updated. + + Returns: + dict: container object + """ + + obj.update( + { + "Framework": framework, + "FrameworkVersion": framework_version, + "NearestModelName": nearest_model_name, + } + ) + + if data_input_configuration is not None: + obj.update( + { + "ModelInput": { + "DataInputConfig": data_input_configuration, + }, + } + ) diff --git a/tests/integ/sagemaker/workflow/test_model_create_and_registration.py b/tests/integ/sagemaker/workflow/test_model_create_and_registration.py index d0f617a266..56611fb696 100644 --- a/tests/integ/sagemaker/workflow/test_model_create_and_registration.py +++ b/tests/integ/sagemaker/workflow/test_model_create_and_registration.py @@ -99,7 +99,6 @@ def test_conditional_pytorch_training_model_registration( framework = "TENSORFLOW" framework_version = "2.9" nearest_model_name = "resnet50" - data_input_configuration = '{"input_1":[1,224,224,3]}' # If image_uri is not provided, the instance_type should not be a pipeline variable # since instance_type is used to retrieve image_uri in compile time (PySDK) @@ -132,7 +131,6 @@ def test_conditional_pytorch_training_model_registration( framework=framework, framework_version=framework_version, nearest_model_name=nearest_model_name, - data_input_configuration=data_input_configuration, ) model = Model( @@ -219,7 +217,6 @@ def test_mxnet_model_registration( framework = "TENSORFLOW" framework_version = "2.9" nearest_model_name = "resnet50" - data_input_configuration = '{"input_1":[1,224,224,3]}' model = MXNetModel( entry_point=entry_point, @@ -244,7 +241,6 @@ def test_mxnet_model_registration( framework=framework, framework_version=framework_version, nearest_model_name=nearest_model_name, - data_input_configuration=data_input_configuration, ) pipeline = Pipeline( @@ -293,7 +289,6 @@ def test_sklearn_xgboost_sip_model_registration( framework = "TENSORFLOW" framework_version = "2.9" nearest_model_name = "resnet50" - data_input_configuration = '{"input_1":[1,224,224,3]}' # The instance_type should not be a pipeline variable # since it is used to retrieve image_uri in compile time (PySDK) @@ -450,7 +445,6 @@ def test_sklearn_xgboost_sip_model_registration( framework=framework, framework_version=framework_version, nearest_model_name=nearest_model_name, - data_input_configuration=data_input_configuration, ) pipeline = Pipeline( diff --git a/tests/unit/sagemaker/workflow/test_pipeline_session.py b/tests/unit/sagemaker/workflow/test_pipeline_session.py index 90a9116c07..fd05992a9f 100644 --- a/tests/unit/sagemaker/workflow/test_pipeline_session.py +++ b/tests/unit/sagemaker/workflow/test_pipeline_session.py @@ -151,7 +151,6 @@ def test_pipeline_session_context_for_model_step_without_instance_types( framework="TENSORFLOW", framework_version="2.9", nearest_model_name="resnet50", - data_input_configuration='{"input_1":[1,224,224,3]}', ) expected_output = { @@ -173,9 +172,6 @@ def test_pipeline_session_context_for_model_step_without_instance_types( "Framework": "TENSORFLOW", "FrameworkVersion": "2.9", "NearestModelName": "resnet50", - "ModelInput": { - "DataInputConfig": '{"input_1":[1,224,224,3]}', - }, } ], "SupportedContentTypes": ["text/csv"], diff --git a/tests/unit/sagemaker/workflow/test_step_collections.py b/tests/unit/sagemaker/workflow/test_step_collections.py index 4aa55fd068..fd84bf4b77 100644 --- a/tests/unit/sagemaker/workflow/test_step_collections.py +++ b/tests/unit/sagemaker/workflow/test_step_collections.py @@ -446,7 +446,6 @@ def test_register_model_tf(estimator_tf, model_metrics, drift_check_baselines): framework="TENSORFLOW", framework_version="2.9", nearest_model_name="resnet50", - data_input_configuration='{"input_1":[1,224,224,3]}', ) assert ordered(register_model.request_dicts()) == ordered( [ @@ -523,7 +522,6 @@ def test_register_model_sip(estimator, model_metrics, drift_check_baselines): framework="TENSORFLOW", framework_version="2.9", nearest_model_name="resnet50", - data_input_configuration='{"input_1":[1,224,224,3]}', ) assert ordered(register_model.request_dicts()) == ordered( [ @@ -542,9 +540,6 @@ def test_register_model_sip(estimator, model_metrics, drift_check_baselines): "Framework": "TENSORFLOW", "FrameworkVersion": "2.9", "NearestModelName": "resnet50", - "ModelInput": { - "DataInputConfig": '{"input_1":[1,224,224,3]}', - }, }, { "Image": "fakeimage2", @@ -553,9 +548,6 @@ def test_register_model_sip(estimator, model_metrics, drift_check_baselines): "Framework": "TENSORFLOW", "FrameworkVersion": "2.9", "NearestModelName": "resnet50", - "ModelInput": { - "DataInputConfig": '{"input_1":[1,224,224,3]}', - }, }, ], "SupportedContentTypes": ["content_type"], @@ -619,7 +611,6 @@ def test_register_model_with_model_repack_with_estimator( framework="TENSORFLOW", framework_version="2.9", nearest_model_name="resnet50", - data_input_configuration='{"input_1":[1,224,224,3]}', ) request_dicts = register_model.request_dicts() diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index d402a509fc..859cdb941f 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -3260,7 +3260,6 @@ def test_register_default_image_without_instance_type_args(sagemaker_session): framework = "TENSORFLOW" framework_version = "2.9" nearest_model_name = "resnet50" - data_input_config = '{"input_1":[1,224,224,3]}' estimator.register( content_types=content_types, @@ -3271,7 +3270,6 @@ def test_register_default_image_without_instance_type_args(sagemaker_session): framework=framework, framework_version=framework_version, nearest_model_name=nearest_model_name, - data_input_configuration=data_input_config, ) sagemaker_session.create_model.assert_not_called() @@ -3319,7 +3317,6 @@ def test_register_inference_image(sagemaker_session): framework = "TENSORFLOW" framework_version = "2.9" nearest_model_name = "resnet50" - data_input_config = '{"input_1":[1,224,224,3]}' estimator.register( content_types=content_types, @@ -3333,7 +3330,6 @@ def test_register_inference_image(sagemaker_session): framework=framework, framework_version=framework_version, nearest_model_name=nearest_model_name, - data_input_configuration=data_input_config, ) sagemaker_session.create_model.assert_not_called()