Skip to content

Commit 7b2dc1a

Browse files
committed
feature: Add ModelClientConfig Fields for Batch Transform
1 parent a66c48d commit 7b2dc1a

File tree

5 files changed

+47
-0
lines changed

5 files changed

+47
-0
lines changed

src/sagemaker/session.py

+7
Original file line numberDiff line numberDiff line change
@@ -1980,6 +1980,7 @@ def transform(
19801980
output_config,
19811981
resource_config,
19821982
experiment_config,
1983+
model_client_config,
19831984
tags,
19841985
data_processing,
19851986
):
@@ -2002,6 +2003,9 @@ def transform(
20022003
experiment_config (dict): A dictionary describing the experiment configuration for the
20032004
job. Dictionary contains three optional keys,
20042005
'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
2006+
model_client_config (dict): A dictionary describing the model configuration for the
2007+
job. Dictionary contains two optional keys,
2008+
'InvocationsTimeoutInSeconds', and 'InvocationsMaxRetries'.
20052009
tags (list[dict]): List of tags for labeling a transform job.
20062010
data_processing(dict): A dictionary describing config for combining the input data and
20072011
transformed data. For more, see
@@ -2036,6 +2040,9 @@ def transform(
20362040
if experiment_config and len(experiment_config) > 0:
20372041
transform_request["ExperimentConfig"] = experiment_config
20382042

2043+
if model_client_config and len(model_client_config) > 0:
2044+
transform_request["ModelClientConfig"] = model_client_config
2045+
20392046
LOGGER.info("Creating transform job with name: %s", job_name)
20402047
LOGGER.debug("Transform request: %s", json.dumps(transform_request, indent=4))
20412048
self.sagemaker_client.create_transform_job(**transform_request)

src/sagemaker/transformer.py

+9
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ def transform(
120120
output_filter=None,
121121
join_source=None,
122122
experiment_config=None,
123+
model_client_config=None,
123124
wait=False,
124125
logs=False,
125126
):
@@ -171,6 +172,10 @@ def transform(
171172
Dictionary contains three optional keys,
172173
'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
173174
(default: ``None``).
175+
model_client_config (dict[str, str]): Model configuration.
176+
Dictionary contains two optional keys,
177+
'InvocationsTimeoutInSeconds', and 'InvocationsMaxRetries'.
178+
(default: ``None``).
174179
wait (bool): Whether the call should wait until the job completes
175180
(default: False).
176181
logs (bool): Whether to show the logs produced by the job.
@@ -207,6 +212,7 @@ def transform(
207212
output_filter,
208213
join_source,
209214
experiment_config,
215+
model_client_config,
210216
)
211217

212218
if wait:
@@ -341,6 +347,7 @@ def start_new(
341347
output_filter,
342348
join_source,
343349
experiment_config,
350+
model_client_config,
344351
):
345352
"""
346353
Args:
@@ -354,6 +361,7 @@ def start_new(
354361
output_filter:
355362
join_source:
356363
experiment_config:
364+
model_client_config:
357365
"""
358366
config = _TransformJob._load_config(
359367
data, data_type, content_type, compression_type, split_type, transformer
@@ -373,6 +381,7 @@ def start_new(
373381
output_config=config["output_config"],
374382
resource_config=config["resource_config"],
375383
experiment_config=experiment_config,
384+
model_client_config=model_client_config,
376385
tags=transformer.tags,
377386
data_processing=data_processing,
378387
)

tests/integ/test_transformer.py

+20
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,26 @@ def test_transform_mxnet_tags(
229229
assert tags == model_tags
230230

231231

232+
def test_transform_model_client_config(
233+
mxnet_estimator, mxnet_transform_input, sagemaker_session, cpu_instance_type
234+
):
235+
model_client_config = {"InvocationsTimeoutInSeconds": 60, "InvocationsMaxRetries": 2}
236+
transformer = mxnet_estimator.transformer(1, cpu_instance_type)
237+
transformer.transform(
238+
mxnet_transform_input, content_type="text/csv", model_client_config=model_client_config
239+
)
240+
241+
with timeout_and_delete_model_with_transformer(
242+
transformer, sagemaker_session, minutes=TRANSFORM_DEFAULT_TIMEOUT_MINUTES
243+
):
244+
transformer.wait()
245+
transform_job_desc = sagemaker_session.sagemaker_client.describe_transform_job(
246+
TransformJobName=transformer.latest_transform_job.name
247+
)
248+
249+
assert model_client_config == transform_job_desc["ModelClientConfig"]
250+
251+
232252
def test_transform_byo_estimator(sagemaker_session, cpu_instance_type):
233253
data_path = os.path.join(DATA_DIR, "one_p_mnist")
234254
pickle_args = {} if sys.version_info.major == 2 else {"encoding": "latin1"}

tests/unit/test_session.py

+4
Original file line numberDiff line numberDiff line change
@@ -562,6 +562,7 @@ def test_s3_input_all_arguments():
562562
"TrialName": "dummyT",
563563
"TrialComponentDisplayName": "dummyTC",
564564
}
565+
MODEL_CLIENT_CONFIG = {"InvocationsMaxRetries": 2, "InvocationsTimeoutInSeconds": 60}
565566

566567
DEFAULT_EXPECTED_TRAIN_JOB_ARGS = {
567568
"OutputDataConfig": {"S3OutputPath": S3_OUTPUT},
@@ -1244,6 +1245,7 @@ def test_transform_pack_to_request(sagemaker_session):
12441245
output_config=out_config,
12451246
resource_config=resource_config,
12461247
experiment_config=None,
1248+
model_client_config=None,
12471249
tags=None,
12481250
data_processing=data_processing,
12491251
)
@@ -1269,6 +1271,7 @@ def test_transform_pack_to_request_with_optional_params(sagemaker_session):
12691271
output_config={},
12701272
resource_config={},
12711273
experiment_config=EXPERIMENT_CONFIG,
1274+
model_client_config=MODEL_CLIENT_CONFIG,
12721275
tags=TAGS,
12731276
data_processing=None,
12741277
)
@@ -1280,6 +1283,7 @@ def test_transform_pack_to_request_with_optional_params(sagemaker_session):
12801283
assert actual_args["Environment"] == env
12811284
assert actual_args["Tags"] == TAGS
12821285
assert actual_args["ExperimentConfig"] == EXPERIMENT_CONFIG
1286+
assert actual_args["ModelClientConfig"] == MODEL_CLIENT_CONFIG
12831287

12841288

12851289
@patch("sys.stdout", new_callable=io.BytesIO if six.PY2 else io.StringIO)

tests/unit/test_transformer.py

+7
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ def test_transform_with_all_params(start_new_job, transformer):
167167
"TrialName": "t",
168168
"TrialComponentDisplayName": "tc",
169169
}
170+
model_client_config = {"InvocationsTimeoutInSeconds": 60, "InvocationsMaxRetries": 2}
170171

171172
transformer.transform(
172173
DATA,
@@ -179,6 +180,7 @@ def test_transform_with_all_params(start_new_job, transformer):
179180
output_filter=output_filter,
180181
join_source=join_source,
181182
experiment_config=experiment_config,
183+
model_client_config=model_client_config,
182184
)
183185

184186
assert transformer._current_job_name == JOB_NAME
@@ -194,6 +196,7 @@ def test_transform_with_all_params(start_new_job, transformer):
194196
output_filter,
195197
join_source,
196198
experiment_config,
199+
model_client_config,
197200
)
198201

199202

@@ -428,6 +431,8 @@ def test_start_new(prepare_data_processing, load_config, sagemaker_session):
428431
split_type = "Line"
429432
io_filter = "$"
430433
join_source = "Input"
434+
model_client_config = {"InvocationsTimeoutInSeconds": 60, "InvocationsMaxRetries": 2}
435+
431436
job = _TransformJob.start_new(
432437
transformer=transformer,
433438
data=DATA,
@@ -439,6 +444,7 @@ def test_start_new(prepare_data_processing, load_config, sagemaker_session):
439444
output_filter=io_filter,
440445
join_source=join_source,
441446
experiment_config={"ExperimentName": "exp"},
447+
model_client_config=model_client_config,
442448
)
443449

444450
assert job.sagemaker_session == sagemaker_session
@@ -460,6 +466,7 @@ def test_start_new(prepare_data_processing, load_config, sagemaker_session):
460466
output_config=output_config,
461467
resource_config=resource_config,
462468
experiment_config={"ExperimentName": "exp"},
469+
model_client_config=model_client_config,
463470
tags=tags,
464471
data_processing=prepare_data_processing.return_value,
465472
)

0 commit comments

Comments
 (0)