diff --git a/setup.py b/setup.py index fd553ae76d..ddb37d767d 100644 --- a/setup.py +++ b/setup.py @@ -34,7 +34,7 @@ def read_version(): # Declare minimal set for installation required_packages = [ - "boto3>=1.13.24", + "boto3>=1.14.12", "numpy>=1.9.0", "protobuf>=3.1", "scipy>=0.19.0", diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index ac51f5a37a..d00f2a805a 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -1996,6 +1996,7 @@ def transform( experiment_config, tags, data_processing, + model_client_config=None, ): """Create an Amazon SageMaker transform job. @@ -2020,6 +2021,9 @@ def transform( data_processing(dict): A dictionary describing config for combining the input data and transformed data. For more, see https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html. + model_client_config (dict): A dictionary describing the model configuration for the + job. Dictionary contains two optional keys, + 'InvocationsTimeoutInSeconds', and 'InvocationsMaxRetries'. """ transform_request = { "TransformJobName": job_name, @@ -2050,6 +2054,9 @@ def transform( if experiment_config and len(experiment_config) > 0: transform_request["ExperimentConfig"] = experiment_config + if model_client_config and len(model_client_config) > 0: + transform_request["ModelClientConfig"] = model_client_config + LOGGER.info("Creating transform job with name: %s", job_name) LOGGER.debug("Transform request: %s", json.dumps(transform_request, indent=4)) self.sagemaker_client.create_transform_job(**transform_request) diff --git a/src/sagemaker/transformer.py b/src/sagemaker/transformer.py index 95cbb2e634..cdbdf882c9 100644 --- a/src/sagemaker/transformer.py +++ b/src/sagemaker/transformer.py @@ -120,6 +120,7 @@ def transform( output_filter=None, join_source=None, experiment_config=None, + model_client_config=None, wait=False, logs=False, ): @@ -172,6 +173,10 @@ def transform( Dictionary contains three optional keys, 'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'. (default: ``None``). + model_client_config (dict[str, str]): Model configuration. + Dictionary contains two optional keys, + 'InvocationsTimeoutInSeconds', and 'InvocationsMaxRetries'. + (default: ``None``). wait (bool): Whether the call should wait until the job completes (default: False). logs (bool): Whether to show the logs produced by the job. @@ -208,6 +213,7 @@ def transform( output_filter, join_source, experiment_config, + model_client_config, ) if wait: @@ -342,6 +348,7 @@ def start_new( output_filter, join_source, experiment_config, + model_client_config, ): """ Args: @@ -355,6 +362,7 @@ def start_new( output_filter: join_source: experiment_config: + model_client_config: """ config = _TransformJob._load_config( data, data_type, content_type, compression_type, split_type, transformer @@ -374,6 +382,7 @@ def start_new( output_config=config["output_config"], resource_config=config["resource_config"], experiment_config=experiment_config, + model_client_config=model_client_config, tags=transformer.tags, data_processing=data_processing, ) diff --git a/tests/integ/test_transformer.py b/tests/integ/test_transformer.py index 6f9b56bc09..9f7a687836 100644 --- a/tests/integ/test_transformer.py +++ b/tests/integ/test_transformer.py @@ -230,6 +230,26 @@ def test_transform_mxnet_tags( assert tags == model_tags +def test_transform_model_client_config( + mxnet_estimator, mxnet_transform_input, sagemaker_session, cpu_instance_type +): + model_client_config = {"InvocationsTimeoutInSeconds": 60, "InvocationsMaxRetries": 2} + transformer = mxnet_estimator.transformer(1, cpu_instance_type) + transformer.transform( + mxnet_transform_input, content_type="text/csv", model_client_config=model_client_config + ) + + with timeout_and_delete_model_with_transformer( + transformer, sagemaker_session, minutes=TRANSFORM_DEFAULT_TIMEOUT_MINUTES + ): + transformer.wait() + transform_job_desc = sagemaker_session.sagemaker_client.describe_transform_job( + TransformJobName=transformer.latest_transform_job.name + ) + + assert model_client_config == transform_job_desc["ModelClientConfig"] + + def test_transform_byo_estimator(sagemaker_session, cpu_instance_type): data_path = os.path.join(DATA_DIR, "one_p_mnist") pickle_args = {} if sys.version_info.major == 2 else {"encoding": "latin1"} diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 103a74af54..f6c0d84c19 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -576,6 +576,7 @@ def test_s3_input_all_arguments(): "TrialName": "dummyT", "TrialComponentDisplayName": "dummyTC", } +MODEL_CLIENT_CONFIG = {"InvocationsMaxRetries": 2, "InvocationsTimeoutInSeconds": 60} DEFAULT_EXPECTED_TRAIN_JOB_ARGS = { "OutputDataConfig": {"S3OutputPath": S3_OUTPUT}, @@ -1258,6 +1259,7 @@ def test_transform_pack_to_request(sagemaker_session): output_config=out_config, resource_config=resource_config, experiment_config=None, + model_client_config=None, tags=None, data_processing=data_processing, ) @@ -1283,6 +1285,7 @@ def test_transform_pack_to_request_with_optional_params(sagemaker_session): output_config={}, resource_config={}, experiment_config=EXPERIMENT_CONFIG, + model_client_config=MODEL_CLIENT_CONFIG, tags=TAGS, data_processing=None, ) @@ -1294,6 +1297,7 @@ def test_transform_pack_to_request_with_optional_params(sagemaker_session): assert actual_args["Environment"] == env assert actual_args["Tags"] == TAGS assert actual_args["ExperimentConfig"] == EXPERIMENT_CONFIG + assert actual_args["ModelClientConfig"] == MODEL_CLIENT_CONFIG @patch("sys.stdout", new_callable=io.BytesIO if six.PY2 else io.StringIO) diff --git a/tests/unit/test_transformer.py b/tests/unit/test_transformer.py index aeaa290827..a929a82269 100644 --- a/tests/unit/test_transformer.py +++ b/tests/unit/test_transformer.py @@ -167,6 +167,7 @@ def test_transform_with_all_params(start_new_job, transformer): "TrialName": "t", "TrialComponentDisplayName": "tc", } + model_client_config = {"InvocationsTimeoutInSeconds": 60, "InvocationsMaxRetries": 2} transformer.transform( DATA, @@ -179,6 +180,7 @@ def test_transform_with_all_params(start_new_job, transformer): output_filter=output_filter, join_source=join_source, experiment_config=experiment_config, + model_client_config=model_client_config, ) assert transformer._current_job_name == JOB_NAME @@ -194,6 +196,7 @@ def test_transform_with_all_params(start_new_job, transformer): output_filter, join_source, experiment_config, + model_client_config, ) @@ -428,6 +431,8 @@ def test_start_new(prepare_data_processing, load_config, sagemaker_session): split_type = "Line" io_filter = "$" join_source = "Input" + model_client_config = {"InvocationsTimeoutInSeconds": 60, "InvocationsMaxRetries": 2} + job = _TransformJob.start_new( transformer=transformer, data=DATA, @@ -439,6 +444,7 @@ def test_start_new(prepare_data_processing, load_config, sagemaker_session): output_filter=io_filter, join_source=join_source, experiment_config={"ExperimentName": "exp"}, + model_client_config=model_client_config, ) assert job.sagemaker_session == sagemaker_session @@ -460,6 +466,7 @@ def test_start_new(prepare_data_processing, load_config, sagemaker_session): output_config=output_config, resource_config=resource_config, experiment_config={"ExperimentName": "exp"}, + model_client_config=model_client_config, tags=tags, data_processing=prepare_data_processing.return_value, )