Skip to content

Commit e4f71ee

Browse files
authored
feature: Add ModelClientConfig Fields for Batch Transform (#1523)
1 parent 78b683d commit e4f71ee

File tree

6 files changed

+48
-1
lines changed

6 files changed

+48
-1
lines changed

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def read_version():
3434

3535
# Declare minimal set for installation
3636
required_packages = [
37-
"boto3>=1.13.24",
37+
"boto3>=1.14.12",
3838
"numpy>=1.9.0",
3939
"protobuf>=3.1",
4040
"scipy>=0.19.0",

src/sagemaker/session.py

+7
Original file line numberDiff line numberDiff line change
@@ -1996,6 +1996,7 @@ def transform(
19961996
experiment_config,
19971997
tags,
19981998
data_processing,
1999+
model_client_config=None,
19992000
):
20002001
"""Create an Amazon SageMaker transform job.
20012002
@@ -2020,6 +2021,9 @@ def transform(
20202021
data_processing(dict): A dictionary describing config for combining the input data and
20212022
transformed data. For more, see
20222023
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
2024+
model_client_config (dict): A dictionary describing the model configuration for the
2025+
job. Dictionary contains two optional keys,
2026+
'InvocationsTimeoutInSeconds', and 'InvocationsMaxRetries'.
20232027
"""
20242028
transform_request = {
20252029
"TransformJobName": job_name,
@@ -2050,6 +2054,9 @@ def transform(
20502054
if experiment_config and len(experiment_config) > 0:
20512055
transform_request["ExperimentConfig"] = experiment_config
20522056

2057+
if model_client_config and len(model_client_config) > 0:
2058+
transform_request["ModelClientConfig"] = model_client_config
2059+
20532060
LOGGER.info("Creating transform job with name: %s", job_name)
20542061
LOGGER.debug("Transform request: %s", json.dumps(transform_request, indent=4))
20552062
self.sagemaker_client.create_transform_job(**transform_request)

src/sagemaker/transformer.py

+9
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ def transform(
120120
output_filter=None,
121121
join_source=None,
122122
experiment_config=None,
123+
model_client_config=None,
123124
wait=False,
124125
logs=False,
125126
):
@@ -172,6 +173,10 @@ def transform(
172173
Dictionary contains three optional keys,
173174
'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
174175
(default: ``None``).
176+
model_client_config (dict[str, str]): Model configuration.
177+
Dictionary contains two optional keys,
178+
'InvocationsTimeoutInSeconds', and 'InvocationsMaxRetries'.
179+
(default: ``None``).
175180
wait (bool): Whether the call should wait until the job completes
176181
(default: False).
177182
logs (bool): Whether to show the logs produced by the job.
@@ -208,6 +213,7 @@ def transform(
208213
output_filter,
209214
join_source,
210215
experiment_config,
216+
model_client_config,
211217
)
212218

213219
if wait:
@@ -342,6 +348,7 @@ def start_new(
342348
output_filter,
343349
join_source,
344350
experiment_config,
351+
model_client_config,
345352
):
346353
"""
347354
Args:
@@ -355,6 +362,7 @@ def start_new(
355362
output_filter:
356363
join_source:
357364
experiment_config:
365+
model_client_config:
358366
"""
359367
config = _TransformJob._load_config(
360368
data, data_type, content_type, compression_type, split_type, transformer
@@ -374,6 +382,7 @@ def start_new(
374382
output_config=config["output_config"],
375383
resource_config=config["resource_config"],
376384
experiment_config=experiment_config,
385+
model_client_config=model_client_config,
377386
tags=transformer.tags,
378387
data_processing=data_processing,
379388
)

tests/integ/test_transformer.py

+20
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,26 @@ def test_transform_mxnet_tags(
230230
assert tags == model_tags
231231

232232

233+
def test_transform_model_client_config(
234+
mxnet_estimator, mxnet_transform_input, sagemaker_session, cpu_instance_type
235+
):
236+
model_client_config = {"InvocationsTimeoutInSeconds": 60, "InvocationsMaxRetries": 2}
237+
transformer = mxnet_estimator.transformer(1, cpu_instance_type)
238+
transformer.transform(
239+
mxnet_transform_input, content_type="text/csv", model_client_config=model_client_config
240+
)
241+
242+
with timeout_and_delete_model_with_transformer(
243+
transformer, sagemaker_session, minutes=TRANSFORM_DEFAULT_TIMEOUT_MINUTES
244+
):
245+
transformer.wait()
246+
transform_job_desc = sagemaker_session.sagemaker_client.describe_transform_job(
247+
TransformJobName=transformer.latest_transform_job.name
248+
)
249+
250+
assert model_client_config == transform_job_desc["ModelClientConfig"]
251+
252+
233253
def test_transform_byo_estimator(sagemaker_session, cpu_instance_type):
234254
data_path = os.path.join(DATA_DIR, "one_p_mnist")
235255
pickle_args = {} if sys.version_info.major == 2 else {"encoding": "latin1"}

tests/unit/test_session.py

+4
Original file line numberDiff line numberDiff line change
@@ -576,6 +576,7 @@ def test_s3_input_all_arguments():
576576
"TrialName": "dummyT",
577577
"TrialComponentDisplayName": "dummyTC",
578578
}
579+
MODEL_CLIENT_CONFIG = {"InvocationsMaxRetries": 2, "InvocationsTimeoutInSeconds": 60}
579580

580581
DEFAULT_EXPECTED_TRAIN_JOB_ARGS = {
581582
"OutputDataConfig": {"S3OutputPath": S3_OUTPUT},
@@ -1258,6 +1259,7 @@ def test_transform_pack_to_request(sagemaker_session):
12581259
output_config=out_config,
12591260
resource_config=resource_config,
12601261
experiment_config=None,
1262+
model_client_config=None,
12611263
tags=None,
12621264
data_processing=data_processing,
12631265
)
@@ -1283,6 +1285,7 @@ def test_transform_pack_to_request_with_optional_params(sagemaker_session):
12831285
output_config={},
12841286
resource_config={},
12851287
experiment_config=EXPERIMENT_CONFIG,
1288+
model_client_config=MODEL_CLIENT_CONFIG,
12861289
tags=TAGS,
12871290
data_processing=None,
12881291
)
@@ -1294,6 +1297,7 @@ def test_transform_pack_to_request_with_optional_params(sagemaker_session):
12941297
assert actual_args["Environment"] == env
12951298
assert actual_args["Tags"] == TAGS
12961299
assert actual_args["ExperimentConfig"] == EXPERIMENT_CONFIG
1300+
assert actual_args["ModelClientConfig"] == MODEL_CLIENT_CONFIG
12971301

12981302

12991303
@patch("sys.stdout", new_callable=io.BytesIO if six.PY2 else io.StringIO)

tests/unit/test_transformer.py

+7
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ def test_transform_with_all_params(start_new_job, transformer):
167167
"TrialName": "t",
168168
"TrialComponentDisplayName": "tc",
169169
}
170+
model_client_config = {"InvocationsTimeoutInSeconds": 60, "InvocationsMaxRetries": 2}
170171

171172
transformer.transform(
172173
DATA,
@@ -179,6 +180,7 @@ def test_transform_with_all_params(start_new_job, transformer):
179180
output_filter=output_filter,
180181
join_source=join_source,
181182
experiment_config=experiment_config,
183+
model_client_config=model_client_config,
182184
)
183185

184186
assert transformer._current_job_name == JOB_NAME
@@ -194,6 +196,7 @@ def test_transform_with_all_params(start_new_job, transformer):
194196
output_filter,
195197
join_source,
196198
experiment_config,
199+
model_client_config,
197200
)
198201

199202

@@ -428,6 +431,8 @@ def test_start_new(prepare_data_processing, load_config, sagemaker_session):
428431
split_type = "Line"
429432
io_filter = "$"
430433
join_source = "Input"
434+
model_client_config = {"InvocationsTimeoutInSeconds": 60, "InvocationsMaxRetries": 2}
435+
431436
job = _TransformJob.start_new(
432437
transformer=transformer,
433438
data=DATA,
@@ -439,6 +444,7 @@ def test_start_new(prepare_data_processing, load_config, sagemaker_session):
439444
output_filter=io_filter,
440445
join_source=join_source,
441446
experiment_config={"ExperimentName": "exp"},
447+
model_client_config=model_client_config,
442448
)
443449

444450
assert job.sagemaker_session == sagemaker_session
@@ -460,6 +466,7 @@ def test_start_new(prepare_data_processing, load_config, sagemaker_session):
460466
output_config=output_config,
461467
resource_config=resource_config,
462468
experiment_config={"ExperimentName": "exp"},
469+
model_client_config=model_client_config,
463470
tags=tags,
464471
data_processing=prepare_data_processing.return_value,
465472
)

0 commit comments

Comments
 (0)