@@ -79,7 +79,7 @@ def __init__(self, model_name, instance_count, instance_type, strategy=None, ass
79
79
self .sagemaker_session = sagemaker_session or Session ()
80
80
81
81
def transform (self , data , data_type = 'S3Prefix' , content_type = None , compression_type = None , split_type = None ,
82
- job_name = None ):
82
+ job_name = None , input_filter = None , output_filter = None , join_source = None ):
83
83
"""Start a new transform job.
84
84
85
85
Args:
@@ -97,6 +97,15 @@ def transform(self, data, data_type='S3Prefix', content_type=None, compression_t
97
97
split_type (str): The record delimiter for the input object (default: 'None').
98
98
Valid values: 'None', 'Line', 'RecordIO', and 'TFRecord'.
99
99
job_name (str): job name (default: None). If not specified, one will be generated.
100
+ input_filter (str): A JSONPath to select a portion of the input to pass to the algorithm container for
101
+ inference. If you omit the field, it gets the value '$', representing the entire input.
102
+ Some examples: "$[1:]", "$.features" (default: None).
103
+ output_filter (str): A JSONPath to select a portion of the joined/original output to return as the output.
104
+ Some examples: "$[1:]", "$.prediction" (default: None).
105
+ join_source (str): The source of data to be joined to the transform output. It can be set to 'Input'
106
+ meaning the entire input record will be joined to the inference result.
107
+ You can use OutputFilter to select the useful portion before uploading to S3. (default: None).
108
+ Valid values: Input, None.
100
109
"""
101
110
local_mode = self .sagemaker_session .local_mode
102
111
if not local_mode and not data .startswith ('s3://' ):
@@ -116,7 +125,7 @@ def transform(self, data, data_type='S3Prefix', content_type=None, compression_t
116
125
self .output_path = 's3://{}/{}' .format (self .sagemaker_session .default_bucket (), self ._current_job_name )
117
126
118
127
self .latest_transform_job = _TransformJob .start_new (self , data , data_type , content_type , compression_type ,
119
- split_type )
128
+ split_type , input_filter , output_filter , join_source )
120
129
121
130
def delete_model (self ):
122
131
"""Delete the corresponding SageMaker model for this Transformer.
@@ -214,16 +223,19 @@ def _prepare_init_params_from_job_description(cls, job_details):
214
223
215
224
class _TransformJob (_Job ):
216
225
@classmethod
217
def start_new(cls, transformer, data, data_type, content_type, compression_type,
              split_type, input_filter=None, output_filter=None, join_source=None):
    """Start a transform job through the transformer's SageMaker session.

    Args:
        transformer: Object whose ``sagemaker_session``, ``_current_job_name``,
            ``model_name``, ``strategy``, ``max_concurrent_transforms``,
            ``max_payload``, ``env`` and ``tags`` attributes are read here.
        data: Forwarded unchanged to ``_load_config``.
        data_type: Forwarded unchanged to ``_load_config``.
        content_type: Forwarded unchanged to ``_load_config``.
        compression_type: Forwarded unchanged to ``_load_config``.
        split_type: Forwarded unchanged to ``_load_config``.
        input_filter: Forwarded to ``_prepare_data_processing`` (default: None).
        output_filter: Forwarded to ``_prepare_data_processing`` (default: None).
        join_source: Forwarded to ``_prepare_data_processing`` (default: None).

    Returns:
        An instance of ``cls`` built from the transformer's session and
        current job name.
    """
    # The three data-processing arguments default to None so that callers
    # written against the earlier six-argument signature keep working.
    config = _TransformJob._load_config(data, data_type, content_type, compression_type, split_type, transformer)
    data_processing = _TransformJob._prepare_data_processing(input_filter, output_filter, join_source)

    transformer.sagemaker_session.transform(job_name=transformer._current_job_name,
                                            model_name=transformer.model_name, strategy=transformer.strategy,
                                            max_concurrent_transforms=transformer.max_concurrent_transforms,
                                            max_payload=transformer.max_payload, env=transformer.env,
                                            input_config=config['input_config'],
                                            output_config=config['output_config'],
                                            resource_config=config['resource_config'],
                                            tags=transformer.tags, data_processing=data_processing)

    return cls(transformer.sagemaker_session, transformer._current_job_name)
229
241
@@ -287,3 +299,21 @@ def _prepare_resource_config(instance_count, instance_type, volume_kms_key):
287
299
config ['VolumeKmsKeyId' ] = volume_kms_key
288
300
289
301
return config
302
+
303
+ @staticmethod
304
+ def _prepare_data_processing (input_filter , output_filter , join_source ):
305
+ config = {}
306
+
307
+ if input_filter is not None :
308
+ config ['InputFilter' ] = input_filter
309
+
310
+ if output_filter is not None :
311
+ config ['OutputFilter' ] = output_filter
312
+
313
+ if join_source is not None :
314
+ config ['JoinSource' ] = join_source
315
+
316
+ if len (config ) == 0 :
317
+ return None
318
+
319
+ return config
0 commit comments