Skip to content

Commit 83ca2de

Browse files
knareshpengk19
authored andcommitted
feature: Add DataProcessing Fields for Batch Transform (aws#827)
1 parent 8b4573c commit 83ca2de

File tree

6 files changed

+79
-16
lines changed

6 files changed

+79
-16
lines changed

src/sagemaker/session.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -499,7 +499,7 @@ def stop_tuning_job(self, name):
499499
raise
500500

501501
def transform(self, job_name, model_name, strategy, max_concurrent_transforms, max_payload, env,
502-
input_config, output_config, resource_config, tags):
502+
input_config, output_config, resource_config, tags, data_processing):
503503
"""Create an Amazon SageMaker transform job.
504504
505505
Args:
@@ -514,8 +514,9 @@ def transform(self, job_name, model_name, strategy, max_concurrent_transforms, m
514514
input_config (dict): A dictionary describing the input data (and its location) for the job.
515515
output_config (dict): A dictionary describing the output location for the job.
516516
resource_config (dict): A dictionary describing the resources to complete the job.
517-
tags (list[dict]): List of tags for labeling a training job. For more, see
518-
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
517+
tags (list[dict]): List of tags for labeling a transform job.
518+
data_processing(dict): A dictionary describing config for combining the input data and transformed data.
519+
For more, see https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
519520
"""
520521
transform_request = {
521522
'TransformJobName': job_name,
@@ -540,6 +541,9 @@ def transform(self, job_name, model_name, strategy, max_concurrent_transforms, m
540541
if tags is not None:
541542
transform_request['Tags'] = tags
542543

544+
if data_processing is not None:
545+
transform_request['DataProcessing'] = data_processing
546+
543547
LOGGER.info('Creating transform job with name: {}'.format(job_name))
544548
LOGGER.debug('Transform request: {}'.format(json.dumps(transform_request, indent=4)))
545549
self.sagemaker_client.create_transform_job(**transform_request)

src/sagemaker/transformer.py

+34-4
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def __init__(self, model_name, instance_count, instance_type, strategy=None, ass
7979
self.sagemaker_session = sagemaker_session or Session()
8080

8181
def transform(self, data, data_type='S3Prefix', content_type=None, compression_type=None, split_type=None,
82-
job_name=None):
82+
job_name=None, input_filter=None, output_filter=None, join_source=None):
8383
"""Start a new transform job.
8484
8585
Args:
@@ -97,6 +97,15 @@ def transform(self, data, data_type='S3Prefix', content_type=None, compression_t
9797
split_type (str): The record delimiter for the input object (default: 'None').
9898
Valid values: 'None', 'Line', 'RecordIO', and 'TFRecord'.
9999
job_name (str): job name (default: None). If not specified, one will be generated.
100+
input_filter (str): A JSONPath to select a portion of the input to pass to the algorithm container for
101+
inference. If you omit the field, it gets the value '$', representing the entire input.
102+
Some examples: "$[1:]", "$.features"(default: None).
103+
output_filter (str): A JSONPath to select a portion of the joined/original output to return as the output.
104+
Some examples: "$[1:]", "$.prediction" (default: None).
105+
join_source (str): The source of data to be joined to the transform output. It can be set to 'Input'
106+
meaning the entire input record will be joined to the inference result.
107+
You can use OutputFilter to select the useful portion before uploading to S3. (default: None).
108+
Valid values: Input, None.
100109
"""
101110
local_mode = self.sagemaker_session.local_mode
102111
if not local_mode and not data.startswith('s3://'):
@@ -116,7 +125,7 @@ def transform(self, data, data_type='S3Prefix', content_type=None, compression_t
116125
self.output_path = 's3://{}/{}'.format(self.sagemaker_session.default_bucket(), self._current_job_name)
117126

118127
self.latest_transform_job = _TransformJob.start_new(self, data, data_type, content_type, compression_type,
119-
split_type)
128+
split_type, input_filter, output_filter, join_source)
120129

121130
def delete_model(self):
122131
"""Delete the corresponding SageMaker model for this Transformer.
@@ -214,16 +223,19 @@ def _prepare_init_params_from_job_description(cls, job_details):
214223

215224
class _TransformJob(_Job):
216225
@classmethod
217-
def start_new(cls, transformer, data, data_type, content_type, compression_type, split_type):
226+
def start_new(cls, transformer, data, data_type, content_type, compression_type,
227+
split_type, input_filter, output_filter, join_source):
218228
config = _TransformJob._load_config(data, data_type, content_type, compression_type, split_type, transformer)
229+
data_processing = _TransformJob._prepare_data_processing(input_filter, output_filter, join_source)
219230

220231
transformer.sagemaker_session.transform(job_name=transformer._current_job_name,
221232
model_name=transformer.model_name, strategy=transformer.strategy,
222233
max_concurrent_transforms=transformer.max_concurrent_transforms,
223234
max_payload=transformer.max_payload, env=transformer.env,
224235
input_config=config['input_config'],
225236
output_config=config['output_config'],
226-
resource_config=config['resource_config'], tags=transformer.tags)
237+
resource_config=config['resource_config'],
238+
tags=transformer.tags, data_processing=data_processing)
227239

228240
return cls(transformer.sagemaker_session, transformer._current_job_name)
229241

@@ -287,3 +299,21 @@ def _prepare_resource_config(instance_count, instance_type, volume_kms_key):
287299
config['VolumeKmsKeyId'] = volume_kms_key
288300

289301
return config
302+
303+
@staticmethod
304+
def _prepare_data_processing(input_filter, output_filter, join_source):
305+
config = {}
306+
307+
if input_filter is not None:
308+
config['InputFilter'] = input_filter
309+
310+
if output_filter is not None:
311+
config['OutputFilter'] = output_filter
312+
313+
if join_source is not None:
314+
config['JoinSource'] = join_source
315+
316+
if len(config) == 0:
317+
return None
318+
319+
return config

tests/integ/kms_utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def _create_kms_key(kms_client,
6868
role_arn=role_arn,
6969
sagemaker_role=sagemaker_role)
7070
else:
71-
principal = "{account_id}".format(account_id=account_id)
71+
principal = '"{account_id}"'.format(account_id=account_id)
7272

7373
response = kms_client.create_key(
7474
Policy=KEY_POLICY.format(id=POLICY_NAME, principal=principal, sagemaker_role=sagemaker_role),

tests/integ/test_transformer.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -54,15 +54,19 @@ def test_transform_mxnet(sagemaker_session, mxnet_full_version):
5454
key_prefix=transform_input_key_prefix)
5555

5656
kms_key_arn = get_or_create_kms_key(sagemaker_session)
57+
output_filter = "$"
5758

58-
transformer = _create_transformer_and_transform_job(mx, transform_input, kms_key_arn)
59+
transformer = _create_transformer_and_transform_job(mx, transform_input, kms_key_arn,
60+
input_filter=None, output_filter=output_filter,
61+
join_source=None)
5962
with timeout_and_delete_model_with_transformer(transformer, sagemaker_session,
6063
minutes=TRANSFORM_DEFAULT_TIMEOUT_MINUTES):
6164
transformer.wait()
6265

6366
job_desc = transformer.sagemaker_session.sagemaker_client.describe_transform_job(
6467
TransformJobName=transformer.latest_transform_job.name)
6568
assert kms_key_arn == job_desc['TransformResources']['VolumeKmsKeyId']
69+
assert output_filter == job_desc['DataProcessing']['OutputFilter']
6670

6771

6872
@pytest.mark.canary_quick
@@ -232,7 +236,9 @@ def test_transform_byo_estimator(sagemaker_session):
232236
assert tags == model_tags
233237

234238

235-
def _create_transformer_and_transform_job(estimator, transform_input, volume_kms_key=None):
239+
def _create_transformer_and_transform_job(estimator, transform_input, volume_kms_key=None,
240+
input_filter=None, output_filter=None, join_source=None):
236241
transformer = estimator.transformer(1, 'ml.m4.xlarge', volume_kms_key=volume_kms_key)
237-
transformer.transform(transform_input, content_type='text/csv')
242+
transformer.transform(transform_input, content_type='text/csv', input_filter=input_filter,
243+
output_filter=output_filter, join_source=join_source)
238244
return transformer

tests/unit/test_session.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -588,7 +588,7 @@ def test_transform_pack_to_request(sagemaker_session):
588588

589589
sagemaker_session.transform(job_name=JOB_NAME, model_name=model_name, strategy=None, max_concurrent_transforms=None,
590590
max_payload=None, env=None, input_config=in_config, output_config=out_config,
591-
resource_config=resource_config, tags=None)
591+
resource_config=resource_config, tags=None, data_processing=None)
592592

593593
_, _, actual_args = sagemaker_session.sagemaker_client.method_calls[0]
594594
assert actual_args == expected_args
@@ -603,7 +603,7 @@ def test_transform_pack_to_request_with_optional_params(sagemaker_session):
603603
sagemaker_session.transform(job_name=JOB_NAME, model_name='my-model', strategy=strategy,
604604
max_concurrent_transforms=max_concurrent_transforms,
605605
env=env, max_payload=max_payload, input_config={}, output_config={},
606-
resource_config={}, tags=TAGS)
606+
resource_config={}, tags=TAGS, data_processing=None)
607607

608608
_, _, actual_args = sagemaker_session.sagemaker_client.method_calls[0]
609609
assert actual_args['BatchStrategy'] == strategy

tests/unit/test_transformer.py

+26-3
Original file line numberDiff line numberDiff line change
@@ -98,13 +98,18 @@ def test_transform_with_all_params(start_new_job, transformer):
9898
content_type = 'text/csv'
9999
compression = 'Gzip'
100100
split = 'Line'
101+
input_filter = "$.feature"
102+
output_filter = "$['sagemaker_output', 'id']"
103+
join_source = "Input"
101104

102105
transformer.transform(DATA, S3_DATA_TYPE, content_type=content_type, compression_type=compression, split_type=split,
103-
job_name=JOB_NAME)
106+
job_name=JOB_NAME, input_filter=input_filter, output_filter=output_filter,
107+
join_source=join_source)
104108

105109
assert transformer._current_job_name == JOB_NAME
106110
assert transformer.output_path == OUTPUT_PATH
107-
start_new_job.assert_called_once_with(transformer, DATA, S3_DATA_TYPE, content_type, compression, split)
111+
start_new_job.assert_called_once_with(transformer, DATA, S3_DATA_TYPE, content_type, compression,
112+
split, input_filter, output_filter, join_source)
108113

109114

110115
@patch('sagemaker.transformer.name_from_base')
@@ -300,7 +305,8 @@ def test_start_new(transformer, sagemaker_session):
300305
transformer._current_job_name = JOB_NAME
301306

302307
job = _TransformJob(sagemaker_session, JOB_NAME)
303-
started_job = job.start_new(transformer, DATA, S3_DATA_TYPE, None, None, None)
308+
started_job = job.start_new(transformer, DATA, S3_DATA_TYPE, None, None, None,
309+
None, None, None)
304310

305311
assert started_job.sagemaker_session == sagemaker_session
306312
sagemaker_session.transform.assert_called_once()
@@ -392,6 +398,23 @@ def test_prepare_resource_config():
392398
assert config == {'InstanceCount': INSTANCE_COUNT, 'InstanceType': INSTANCE_TYPE, 'VolumeKmsKeyId': KMS_KEY_ID}
393399

394400

401+
def test_data_processing_config():
402+
actual_config = _TransformJob._prepare_data_processing("$", None, None)
403+
assert actual_config == {'InputFilter': "$"}
404+
405+
actual_config = _TransformJob._prepare_data_processing(None, "$", None)
406+
assert actual_config == {'OutputFilter': "$"}
407+
408+
actual_config = _TransformJob._prepare_data_processing(None, None, "Input")
409+
assert actual_config == {'JoinSource': "Input"}
410+
411+
actual_config = _TransformJob._prepare_data_processing("$[0]", "$[1]", "Input")
412+
assert actual_config == {'InputFilter': "$[0]", 'OutputFilter': "$[1]", 'JoinSource': "Input"}
413+
414+
actual_config = _TransformJob._prepare_data_processing(None, None, None)
415+
assert actual_config is None
416+
417+
395418
def test_transform_job_wait(sagemaker_session):
396419
job = _TransformJob(sagemaker_session, JOB_NAME)
397420
job.wait()

0 commit comments

Comments
 (0)