@@ -96,6 +96,64 @@ def test_transformer_fails_without_model():
96
96
)
97
97
98
98
99
+ def test_transformer_init (sagemaker_session ):
100
+ transformer = Transformer (
101
+ MODEL_NAME , INSTANCE_COUNT , INSTANCE_TYPE , sagemaker_session = sagemaker_session
102
+ )
103
+
104
+ assert transformer .model_name == MODEL_NAME
105
+ assert transformer .instance_count == INSTANCE_COUNT
106
+ assert transformer .instance_type == INSTANCE_TYPE
107
+ assert transformer .sagemaker_session == sagemaker_session
108
+
109
+ assert transformer ._current_job_name is None
110
+ assert transformer .latest_transform_job is None
111
+ assert transformer ._reset_output_path is False
112
+
113
+
114
+ def test_transformer_init_optional_params (sagemaker_session ):
115
+ strategy = "MultiRecord"
116
+ assemble_with = "Line"
117
+ accept = "text/csv"
118
+ max_concurrent_transforms = 100
119
+ max_payload = 100
120
+ tags = {"Key" : "foo" , "Value" : "bar" }
121
+ env = {"FOO" : "BAR" }
122
+
123
+ transformer = Transformer (
124
+ MODEL_NAME ,
125
+ INSTANCE_COUNT ,
126
+ INSTANCE_TYPE ,
127
+ strategy = strategy ,
128
+ assemble_with = assemble_with ,
129
+ output_path = OUTPUT_PATH ,
130
+ output_kms_key = KMS_KEY_ID ,
131
+ accept = accept ,
132
+ max_concurrent_transforms = max_concurrent_transforms ,
133
+ max_payload = max_payload ,
134
+ tags = tags ,
135
+ env = env ,
136
+ base_transform_job_name = JOB_NAME ,
137
+ sagemaker_session = sagemaker_session ,
138
+ volume_kms_key = KMS_KEY_ID ,
139
+ )
140
+
141
+ assert transformer .model_name == MODEL_NAME
142
+ assert transformer .strategy == strategy
143
+ assert transformer .env == env
144
+ assert transformer .output_path == OUTPUT_PATH
145
+ assert transformer .output_kms_key == KMS_KEY_ID
146
+ assert transformer .accept == accept
147
+ assert transformer .assemble_with == assemble_with
148
+ assert transformer .instance_count == INSTANCE_COUNT
149
+ assert transformer .instance_type == INSTANCE_TYPE
150
+ assert transformer .volume_kms_key == KMS_KEY_ID
151
+ assert transformer .max_concurrent_transforms == max_concurrent_transforms
152
+ assert transformer .max_payload == max_payload
153
+ assert transformer .tags == tags
154
+ assert transformer .base_transform_job_name == JOB_NAME
155
+
156
+
99
157
@patch ("sagemaker.transformer._TransformJob.start_new" )
100
158
def test_transform_with_all_params (start_new_job , transformer ):
101
159
content_type = "text/csv"
@@ -333,29 +391,78 @@ def test_prepare_init_params_from_job_description_all_keys(transformer):
333
391
334
392
335
393
# _TransformJob tests
336
- def test_start_new (transformer , sagemaker_session ):
394
+ @patch ("sagemaker.transformer._TransformJob._load_config" )
395
+ @patch ("sagemaker.transformer._TransformJob._prepare_data_processing" )
396
+ def test_start_new (prepare_data_processing , load_config , sagemaker_session ):
397
+ input_config = "input"
398
+ output_config = "output"
399
+ resource_config = "resource"
400
+ load_config .return_value = {
401
+ "input_config" : input_config ,
402
+ "output_config" : output_config ,
403
+ "resource_config" : resource_config ,
404
+ }
405
+
406
+ strategy = "MultiRecord"
407
+ max_concurrent_transforms = 100
408
+ max_payload = 100
409
+ tags = {"Key" : "foo" , "Value" : "bar" }
410
+ env = {"FOO" : "BAR" }
411
+
412
+ transformer = Transformer (
413
+ MODEL_NAME ,
414
+ INSTANCE_COUNT ,
415
+ INSTANCE_TYPE ,
416
+ strategy = strategy ,
417
+ output_path = OUTPUT_PATH ,
418
+ max_concurrent_transforms = max_concurrent_transforms ,
419
+ max_payload = max_payload ,
420
+ tags = tags ,
421
+ env = env ,
422
+ sagemaker_session = sagemaker_session ,
423
+ )
337
424
transformer ._current_job_name = JOB_NAME
338
425
339
- job = _TransformJob (sagemaker_session , JOB_NAME )
340
- started_job = job .start_new (
341
- transformer ,
342
- DATA ,
343
- S3_DATA_TYPE ,
344
- None ,
345
- None ,
346
- None ,
347
- None ,
348
- None ,
349
- None ,
350
- {"ExperimentName" : "exp" },
426
+ content_type = "text/csv"
427
+ compression_type = "Gzip"
428
+ split_type = "Line"
429
+ io_filter = "$"
430
+ join_source = "Input"
431
+ job = _TransformJob .start_new (
432
+ transformer = transformer ,
433
+ data = DATA ,
434
+ data_type = S3_DATA_TYPE ,
435
+ content_type = content_type ,
436
+ compression_type = compression_type ,
437
+ split_type = split_type ,
438
+ input_filter = io_filter ,
439
+ output_filter = io_filter ,
440
+ join_source = join_source ,
441
+ experiment_config = {"ExperimentName" : "exp" },
351
442
)
352
443
353
- assert started_job .sagemaker_session == sagemaker_session
354
- sagemaker_session . transform . assert_called_once ()
444
+ assert job .sagemaker_session == sagemaker_session
445
+ assert job . job_name == JOB_NAME
355
446
356
- called_args = sagemaker_session .transform .call_args
447
+ load_config .assert_called_with (
448
+ DATA , S3_DATA_TYPE , content_type , compression_type , split_type , transformer
449
+ )
450
+ prepare_data_processing .assert_called_with (io_filter , io_filter , join_source )
357
451
358
- assert called_args [1 ]["experiment_config" ] == {"ExperimentName" : "exp" }
452
+ sagemaker_session .transform .assert_called_with (
453
+ job_name = JOB_NAME ,
454
+ model_name = MODEL_NAME ,
455
+ strategy = strategy ,
456
+ max_concurrent_transforms = max_concurrent_transforms ,
457
+ max_payload = max_payload ,
458
+ env = env ,
459
+ input_config = input_config ,
460
+ output_config = output_config ,
461
+ resource_config = resource_config ,
462
+ experiment_config = {"ExperimentName" : "exp" },
463
+ tags = tags ,
464
+ data_processing = prepare_data_processing .return_value ,
465
+ )
359
466
360
467
361
468
def test_load_config (transformer ):
0 commit comments