Skip to content

feature: Add ModelClientConfig Fields for Batch Transform #1523

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Jul 9, 2020
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def read_version():

# Declare minimal set for installation
required_packages = [
"boto3>=1.13.24",
"boto3>=1.14.12",
"numpy>=1.9.0",
"protobuf>=3.1",
"scipy>=0.19.0",
Expand Down
7 changes: 7 additions & 0 deletions src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -1996,6 +1996,7 @@ def transform(
experiment_config,
tags,
data_processing,
model_client_config=None,
):
"""Create an Amazon SageMaker transform job.

Expand All @@ -2020,6 +2021,9 @@ def transform(
data_processing (dict): A dictionary describing config for combining the input data and
transformed data. For more, see
https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_DataProcessing.html.
model_client_config (dict): A dictionary describing the model configuration for the
job. Dictionary contains two optional keys,
'InvocationsTimeoutInSeconds', and 'InvocationsMaxRetries'.
"""
transform_request = {
"TransformJobName": job_name,
Expand Down Expand Up @@ -2050,6 +2054,9 @@ def transform(
if experiment_config and len(experiment_config) > 0:
transform_request["ExperimentConfig"] = experiment_config

if model_client_config and len(model_client_config) > 0:
transform_request["ModelClientConfig"] = model_client_config

LOGGER.info("Creating transform job with name: %s", job_name)
LOGGER.debug("Transform request: %s", json.dumps(transform_request, indent=4))
self.sagemaker_client.create_transform_job(**transform_request)
Expand Down
9 changes: 9 additions & 0 deletions src/sagemaker/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def transform(
output_filter=None,
join_source=None,
experiment_config=None,
model_client_config=None,
wait=False,
logs=False,
):
Expand Down Expand Up @@ -172,6 +173,10 @@ def transform(
Dictionary contains three optional keys,
'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
(default: ``None``).
model_client_config (dict[str, int]): Model configuration.
Dictionary contains two optional keys,
'InvocationsTimeoutInSeconds' and 'InvocationsMaxRetries'.
(default: ``None``).
wait (bool): Whether the call should wait until the job completes
(default: False).
logs (bool): Whether to show the logs produced by the job.
Expand Down Expand Up @@ -208,6 +213,7 @@ def transform(
output_filter,
join_source,
experiment_config,
model_client_config,
)

if wait:
Expand Down Expand Up @@ -342,6 +348,7 @@ def start_new(
output_filter,
join_source,
experiment_config,
model_client_config,
):
"""
Args:
Expand All @@ -355,6 +362,7 @@ def start_new(
output_filter:
join_source:
experiment_config:
model_client_config:
"""
config = _TransformJob._load_config(
data, data_type, content_type, compression_type, split_type, transformer
Expand All @@ -374,6 +382,7 @@ def start_new(
output_config=config["output_config"],
resource_config=config["resource_config"],
experiment_config=experiment_config,
model_client_config=model_client_config,
tags=transformer.tags,
data_processing=data_processing,
)
Expand Down
20 changes: 20 additions & 0 deletions tests/integ/test_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,26 @@ def test_transform_mxnet_tags(
assert tags == model_tags


def test_transform_model_client_config(
    mxnet_estimator, mxnet_transform_input, sagemaker_session, cpu_instance_type
):
    """Check that ``model_client_config`` round-trips through a transform job.

    A transform job is started with an explicit ``model_client_config`` and,
    once ``transformer.wait()`` returns, the job is described through the
    session's SageMaker client. The ``ModelClientConfig`` entry of that
    description must equal the dictionary that was passed in.
    """
    expected_config = {"InvocationsTimeoutInSeconds": 60, "InvocationsMaxRetries": 2}

    transformer = mxnet_estimator.transformer(1, cpu_instance_type)
    transformer.transform(
        mxnet_transform_input,
        content_type="text/csv",
        model_client_config=expected_config,
    )

    # The wait and the describe call both run under the project's
    # timeout/cleanup helper, which is defined outside this view.
    guard = timeout_and_delete_model_with_transformer(
        transformer, sagemaker_session, minutes=TRANSFORM_DEFAULT_TIMEOUT_MINUTES
    )
    with guard:
        transformer.wait()

        describe_job = sagemaker_session.sagemaker_client.describe_transform_job
        job_name = transformer.latest_transform_job.name
        description = describe_job(TransformJobName=job_name)

        assert expected_config == description["ModelClientConfig"]


def test_transform_byo_estimator(sagemaker_session, cpu_instance_type):
data_path = os.path.join(DATA_DIR, "one_p_mnist")
pickle_args = {} if sys.version_info.major == 2 else {"encoding": "latin1"}
Expand Down
4 changes: 4 additions & 0 deletions tests/unit/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,7 @@ def test_s3_input_all_arguments():
"TrialName": "dummyT",
"TrialComponentDisplayName": "dummyTC",
}
MODEL_CLIENT_CONFIG = {"InvocationsMaxRetries": 2, "InvocationsTimeoutInSeconds": 60}

DEFAULT_EXPECTED_TRAIN_JOB_ARGS = {
"OutputDataConfig": {"S3OutputPath": S3_OUTPUT},
Expand Down Expand Up @@ -1258,6 +1259,7 @@ def test_transform_pack_to_request(sagemaker_session):
output_config=out_config,
resource_config=resource_config,
experiment_config=None,
model_client_config=None,
tags=None,
data_processing=data_processing,
)
Expand All @@ -1283,6 +1285,7 @@ def test_transform_pack_to_request_with_optional_params(sagemaker_session):
output_config={},
resource_config={},
experiment_config=EXPERIMENT_CONFIG,
model_client_config=MODEL_CLIENT_CONFIG,
tags=TAGS,
data_processing=None,
)
Expand All @@ -1294,6 +1297,7 @@ def test_transform_pack_to_request_with_optional_params(sagemaker_session):
assert actual_args["Environment"] == env
assert actual_args["Tags"] == TAGS
assert actual_args["ExperimentConfig"] == EXPERIMENT_CONFIG
assert actual_args["ModelClientConfig"] == MODEL_CLIENT_CONFIG


@patch("sys.stdout", new_callable=io.BytesIO if six.PY2 else io.StringIO)
Expand Down
7 changes: 7 additions & 0 deletions tests/unit/test_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ def test_transform_with_all_params(start_new_job, transformer):
"TrialName": "t",
"TrialComponentDisplayName": "tc",
}
model_client_config = {"InvocationsTimeoutInSeconds": 60, "InvocationsMaxRetries": 2}

transformer.transform(
DATA,
Expand All @@ -179,6 +180,7 @@ def test_transform_with_all_params(start_new_job, transformer):
output_filter=output_filter,
join_source=join_source,
experiment_config=experiment_config,
model_client_config=model_client_config,
)

assert transformer._current_job_name == JOB_NAME
Expand All @@ -194,6 +196,7 @@ def test_transform_with_all_params(start_new_job, transformer):
output_filter,
join_source,
experiment_config,
model_client_config,
)


Expand Down Expand Up @@ -428,6 +431,8 @@ def test_start_new(prepare_data_processing, load_config, sagemaker_session):
split_type = "Line"
io_filter = "$"
join_source = "Input"
model_client_config = {"InvocationsTimeoutInSeconds": 60, "InvocationsMaxRetries": 2}

job = _TransformJob.start_new(
transformer=transformer,
data=DATA,
Expand All @@ -439,6 +444,7 @@ def test_start_new(prepare_data_processing, load_config, sagemaker_session):
output_filter=io_filter,
join_source=join_source,
experiment_config={"ExperimentName": "exp"},
model_client_config=model_client_config,
)

assert job.sagemaker_session == sagemaker_session
Expand All @@ -460,6 +466,7 @@ def test_start_new(prepare_data_processing, load_config, sagemaker_session):
output_config=output_config,
resource_config=resource_config,
experiment_config={"ExperimentName": "exp"},
model_client_config=model_client_config,
tags=tags,
data_processing=prepare_data_processing.return_value,
)
Expand Down