From 6ab1027379004a2511a63eac7abbb0f79e79416b Mon Sep 17 00:00:00 2001 From: Lauren Yu <6631887+laurenyu@users.noreply.github.com> Date: Tue, 14 Apr 2020 11:21:30 -0700 Subject: [PATCH] infra: improve unit tests for creating Transformers and transform jobs --- tests/unit/test_transformer.py | 141 +++++++++++++++++++++++++++++---- 1 file changed, 124 insertions(+), 17 deletions(-) diff --git a/tests/unit/test_transformer.py b/tests/unit/test_transformer.py index 8a5741892d..aeaa290827 100644 --- a/tests/unit/test_transformer.py +++ b/tests/unit/test_transformer.py @@ -96,6 +96,64 @@ def test_transformer_fails_without_model(): ) +def test_transformer_init(sagemaker_session): + transformer = Transformer( + MODEL_NAME, INSTANCE_COUNT, INSTANCE_TYPE, sagemaker_session=sagemaker_session + ) + + assert transformer.model_name == MODEL_NAME + assert transformer.instance_count == INSTANCE_COUNT + assert transformer.instance_type == INSTANCE_TYPE + assert transformer.sagemaker_session == sagemaker_session + + assert transformer._current_job_name is None + assert transformer.latest_transform_job is None + assert transformer._reset_output_path is False + + +def test_transformer_init_optional_params(sagemaker_session): + strategy = "MultiRecord" + assemble_with = "Line" + accept = "text/csv" + max_concurrent_transforms = 100 + max_payload = 100 + tags = {"Key": "foo", "Value": "bar"} + env = {"FOO": "BAR"} + + transformer = Transformer( + MODEL_NAME, + INSTANCE_COUNT, + INSTANCE_TYPE, + strategy=strategy, + assemble_with=assemble_with, + output_path=OUTPUT_PATH, + output_kms_key=KMS_KEY_ID, + accept=accept, + max_concurrent_transforms=max_concurrent_transforms, + max_payload=max_payload, + tags=tags, + env=env, + base_transform_job_name=JOB_NAME, + sagemaker_session=sagemaker_session, + volume_kms_key=KMS_KEY_ID, + ) + + assert transformer.model_name == MODEL_NAME + assert transformer.strategy == strategy + assert transformer.env == env + assert transformer.output_path == OUTPUT_PATH + assert transformer.output_kms_key == KMS_KEY_ID + assert transformer.accept == accept + assert transformer.assemble_with == assemble_with + assert transformer.instance_count == INSTANCE_COUNT + assert transformer.instance_type == INSTANCE_TYPE + assert transformer.volume_kms_key == KMS_KEY_ID + assert transformer.max_concurrent_transforms == max_concurrent_transforms + assert transformer.max_payload == max_payload + assert transformer.tags == tags + assert transformer.base_transform_job_name == JOB_NAME + + @patch("sagemaker.transformer._TransformJob.start_new") def test_transform_with_all_params(start_new_job, transformer): content_type = "text/csv" @@ -333,29 +391,78 @@ def test_prepare_init_params_from_job_description_all_keys(transformer): # _TransformJob tests -def test_start_new(transformer, sagemaker_session): +@patch("sagemaker.transformer._TransformJob._load_config") +@patch("sagemaker.transformer._TransformJob._prepare_data_processing") +def test_start_new(prepare_data_processing, load_config, sagemaker_session): + input_config = "input" + output_config = "output" + resource_config = "resource" + load_config.return_value = { + "input_config": input_config, + "output_config": output_config, + "resource_config": resource_config, + } + + strategy = "MultiRecord" + max_concurrent_transforms = 100 + max_payload = 100 + tags = {"Key": "foo", "Value": "bar"} + env = {"FOO": "BAR"} + + transformer = Transformer( + MODEL_NAME, + INSTANCE_COUNT, + INSTANCE_TYPE, + strategy=strategy, + output_path=OUTPUT_PATH, + max_concurrent_transforms=max_concurrent_transforms, + max_payload=max_payload, + tags=tags, + env=env, + sagemaker_session=sagemaker_session, + ) transformer._current_job_name = JOB_NAME - job = _TransformJob(sagemaker_session, JOB_NAME) - started_job = job.start_new( - transformer, - DATA, - S3_DATA_TYPE, - None, - None, - None, - None, - None, - None, - {"ExperimentName": "exp"}, + content_type = "text/csv" + compression_type = "Gzip" + split_type = "Line" + io_filter = "$" + join_source = "Input" + job = _TransformJob.start_new( + transformer=transformer, + data=DATA, + data_type=S3_DATA_TYPE, + content_type=content_type, + compression_type=compression_type, + split_type=split_type, + input_filter=io_filter, + output_filter=io_filter, + join_source=join_source, + experiment_config={"ExperimentName": "exp"}, ) - assert started_job.sagemaker_session == sagemaker_session - sagemaker_session.transform.assert_called_once() + assert job.sagemaker_session == sagemaker_session + assert job.job_name == JOB_NAME - called_args = sagemaker_session.transform.call_args + load_config.assert_called_with( + DATA, S3_DATA_TYPE, content_type, compression_type, split_type, transformer + ) + prepare_data_processing.assert_called_with(io_filter, io_filter, join_source) - assert called_args[1]["experiment_config"] == {"ExperimentName": "exp"} + sagemaker_session.transform.assert_called_with( + job_name=JOB_NAME, + model_name=MODEL_NAME, + strategy=strategy, + max_concurrent_transforms=max_concurrent_transforms, + max_payload=max_payload, + env=env, + input_config=input_config, + output_config=output_config, + resource_config=resource_config, + experiment_config={"ExperimentName": "exp"}, + tags=tags, + data_processing=prepare_data_processing.return_value, + ) def test_load_config(transformer):