Skip to content

Commit bca132d

Browse files
authored
infra: improve unit tests for creating Transformers and transform jobs (#1408)
1 parent 68020e7 commit bca132d

File tree

1 file changed

+124
-17
lines changed

1 file changed

+124
-17
lines changed

tests/unit/test_transformer.py

Lines changed: 124 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,64 @@ def test_transformer_fails_without_model():
9696
)
9797

9898

99+
def test_transformer_init(sagemaker_session):
100+
transformer = Transformer(
101+
MODEL_NAME, INSTANCE_COUNT, INSTANCE_TYPE, sagemaker_session=sagemaker_session
102+
)
103+
104+
assert transformer.model_name == MODEL_NAME
105+
assert transformer.instance_count == INSTANCE_COUNT
106+
assert transformer.instance_type == INSTANCE_TYPE
107+
assert transformer.sagemaker_session == sagemaker_session
108+
109+
assert transformer._current_job_name is None
110+
assert transformer.latest_transform_job is None
111+
assert transformer._reset_output_path is False
112+
113+
114+
def test_transformer_init_optional_params(sagemaker_session):
115+
strategy = "MultiRecord"
116+
assemble_with = "Line"
117+
accept = "text/csv"
118+
max_concurrent_transforms = 100
119+
max_payload = 100
120+
tags = {"Key": "foo", "Value": "bar"}
121+
env = {"FOO": "BAR"}
122+
123+
transformer = Transformer(
124+
MODEL_NAME,
125+
INSTANCE_COUNT,
126+
INSTANCE_TYPE,
127+
strategy=strategy,
128+
assemble_with=assemble_with,
129+
output_path=OUTPUT_PATH,
130+
output_kms_key=KMS_KEY_ID,
131+
accept=accept,
132+
max_concurrent_transforms=max_concurrent_transforms,
133+
max_payload=max_payload,
134+
tags=tags,
135+
env=env,
136+
base_transform_job_name=JOB_NAME,
137+
sagemaker_session=sagemaker_session,
138+
volume_kms_key=KMS_KEY_ID,
139+
)
140+
141+
assert transformer.model_name == MODEL_NAME
142+
assert transformer.strategy == strategy
143+
assert transformer.env == env
144+
assert transformer.output_path == OUTPUT_PATH
145+
assert transformer.output_kms_key == KMS_KEY_ID
146+
assert transformer.accept == accept
147+
assert transformer.assemble_with == assemble_with
148+
assert transformer.instance_count == INSTANCE_COUNT
149+
assert transformer.instance_type == INSTANCE_TYPE
150+
assert transformer.volume_kms_key == KMS_KEY_ID
151+
assert transformer.max_concurrent_transforms == max_concurrent_transforms
152+
assert transformer.max_payload == max_payload
153+
assert transformer.tags == tags
154+
assert transformer.base_transform_job_name == JOB_NAME
155+
156+
99157
@patch("sagemaker.transformer._TransformJob.start_new")
100158
def test_transform_with_all_params(start_new_job, transformer):
101159
content_type = "text/csv"
@@ -333,29 +391,78 @@ def test_prepare_init_params_from_job_description_all_keys(transformer):
333391

334392

335393
# _TransformJob tests
336-
def test_start_new(transformer, sagemaker_session):
394+
@patch("sagemaker.transformer._TransformJob._load_config")
395+
@patch("sagemaker.transformer._TransformJob._prepare_data_processing")
396+
def test_start_new(prepare_data_processing, load_config, sagemaker_session):
397+
input_config = "input"
398+
output_config = "output"
399+
resource_config = "resource"
400+
load_config.return_value = {
401+
"input_config": input_config,
402+
"output_config": output_config,
403+
"resource_config": resource_config,
404+
}
405+
406+
strategy = "MultiRecord"
407+
max_concurrent_transforms = 100
408+
max_payload = 100
409+
tags = {"Key": "foo", "Value": "bar"}
410+
env = {"FOO": "BAR"}
411+
412+
transformer = Transformer(
413+
MODEL_NAME,
414+
INSTANCE_COUNT,
415+
INSTANCE_TYPE,
416+
strategy=strategy,
417+
output_path=OUTPUT_PATH,
418+
max_concurrent_transforms=max_concurrent_transforms,
419+
max_payload=max_payload,
420+
tags=tags,
421+
env=env,
422+
sagemaker_session=sagemaker_session,
423+
)
337424
transformer._current_job_name = JOB_NAME
338425

339-
job = _TransformJob(sagemaker_session, JOB_NAME)
340-
started_job = job.start_new(
341-
transformer,
342-
DATA,
343-
S3_DATA_TYPE,
344-
None,
345-
None,
346-
None,
347-
None,
348-
None,
349-
None,
350-
{"ExperimentName": "exp"},
426+
content_type = "text/csv"
427+
compression_type = "Gzip"
428+
split_type = "Line"
429+
io_filter = "$"
430+
join_source = "Input"
431+
job = _TransformJob.start_new(
432+
transformer=transformer,
433+
data=DATA,
434+
data_type=S3_DATA_TYPE,
435+
content_type=content_type,
436+
compression_type=compression_type,
437+
split_type=split_type,
438+
input_filter=io_filter,
439+
output_filter=io_filter,
440+
join_source=join_source,
441+
experiment_config={"ExperimentName": "exp"},
351442
)
352443

353-
assert started_job.sagemaker_session == sagemaker_session
354-
sagemaker_session.transform.assert_called_once()
444+
assert job.sagemaker_session == sagemaker_session
445+
assert job.job_name == JOB_NAME
355446

356-
called_args = sagemaker_session.transform.call_args
447+
load_config.assert_called_with(
448+
DATA, S3_DATA_TYPE, content_type, compression_type, split_type, transformer
449+
)
450+
prepare_data_processing.assert_called_with(io_filter, io_filter, join_source)
357451

358-
assert called_args[1]["experiment_config"] == {"ExperimentName": "exp"}
452+
sagemaker_session.transform.assert_called_with(
453+
job_name=JOB_NAME,
454+
model_name=MODEL_NAME,
455+
strategy=strategy,
456+
max_concurrent_transforms=max_concurrent_transforms,
457+
max_payload=max_payload,
458+
env=env,
459+
input_config=input_config,
460+
output_config=output_config,
461+
resource_config=resource_config,
462+
experiment_config={"ExperimentName": "exp"},
463+
tags=tags,
464+
data_processing=prepare_data_processing.return_value,
465+
)
359466

360467

361468
def test_load_config(transformer):

0 commit comments

Comments
 (0)