|
13 | 13 | from __future__ import absolute_import
|
14 | 14 |
|
15 | 15 | import pytest
|
16 |
| -from mock import Mock, patch |
| 16 | +from mock import Mock |
17 | 17 | from sagemaker import AutoML, AutoMLJob, AutoMLInput, CandidateEstimator
|
18 | 18 |
|
19 | 19 | MODEL_DATA = "s3://bucket/model.tar.gz"
|
20 | 20 | MODEL_IMAGE = "mi"
|
21 | 21 | ENTRY_POINT = "blah.py"
|
22 | 22 |
|
23 |
| -TIMESTAMP = "2017-11-06-14:14:15.671" |
24 | 23 | BUCKET_NAME = "mybucket"
|
25 | 24 | INSTANCE_COUNT = 1
|
26 | 25 | INSTANCE_TYPE = "ml.c5.2xlarge"
|
|
32 | 31 | DEFAULT_OUTPUT_PATH = "s3://{}/".format(BUCKET_NAME)
|
33 | 32 | LOCAL_DATA_PATH = "file://data"
|
34 | 33 | DEFAULT_MAX_CANDIDATES = 500
|
35 |
| -DEFAULT_JOB_NAME = "automl-{}".format(TIMESTAMP) |
36 | 34 |
|
37 | 35 | JOB_NAME = "default-job-name"
|
38 | 36 | JOB_NAME_2 = "banana-auto-ml-job"
|
@@ -283,38 +281,34 @@ def test_auto_ml_additional_optional_params(sagemaker_session):
|
283 | 281 | }
|
284 | 282 |
|
285 | 283 |
|
286 |
| -@patch("time.strftime", return_value=TIMESTAMP) |
287 |
| -def test_auto_ml_default_fit(strftime, sagemaker_session): |
| 284 | +def test_auto_ml_default_fit(sagemaker_session): |
288 | 285 | auto_ml = AutoML(
|
289 | 286 | role=ROLE, target_attribute_name=TARGET_ATTRIBUTE_NAME, sagemaker_session=sagemaker_session
|
290 | 287 | )
|
291 | 288 | inputs = DEFAULT_S3_INPUT_DATA
|
292 | 289 | auto_ml.fit(inputs)
|
293 | 290 | sagemaker_session.auto_ml.assert_called_once()
|
294 | 291 | _, args = sagemaker_session.auto_ml.call_args
|
295 |
| - assert args == { |
296 |
| - "input_config": [ |
297 |
| - { |
298 |
| - "DataSource": { |
299 |
| - "S3DataSource": {"S3DataType": "S3Prefix", "S3Uri": DEFAULT_S3_INPUT_DATA} |
300 |
| - }, |
301 |
| - "TargetAttributeName": TARGET_ATTRIBUTE_NAME, |
302 |
| - } |
303 |
| - ], |
304 |
| - "output_config": {"S3OutputPath": DEFAULT_OUTPUT_PATH}, |
305 |
| - "auto_ml_job_config": { |
306 |
| - "CompletionCriteria": {"MaxCandidates": DEFAULT_MAX_CANDIDATES}, |
307 |
| - "SecurityConfig": { |
308 |
| - "EnableInterContainerTrafficEncryption": ENCRYPT_INTER_CONTAINER_TRAFFIC |
| 292 | + assert args["input_config"] == [ |
| 293 | + { |
| 294 | + "DataSource": { |
| 295 | + "S3DataSource": {"S3DataType": "S3Prefix", "S3Uri": DEFAULT_S3_INPUT_DATA} |
309 | 296 | },
|
| 297 | + "TargetAttributeName": TARGET_ATTRIBUTE_NAME, |
| 298 | + } |
| 299 | + ] |
| 300 | + assert args["output_config"] == {"S3OutputPath": DEFAULT_OUTPUT_PATH} |
| 301 | + assert args["auto_ml_job_config"] == { |
| 302 | + "CompletionCriteria": {"MaxCandidates": DEFAULT_MAX_CANDIDATES}, |
| 303 | + "SecurityConfig": { |
| 304 | + "EnableInterContainerTrafficEncryption": ENCRYPT_INTER_CONTAINER_TRAFFIC |
310 | 305 | },
|
311 |
| - "role": ROLE, |
312 |
| - "job_name": DEFAULT_JOB_NAME, |
313 |
| - "problem_type": None, |
314 |
| - "job_objective": None, |
315 |
| - "generate_candidate_definitions_only": GENERATE_CANDIDATE_DEFINITIONS_ONLY, |
316 |
| - "tags": None, |
317 | 306 | }
|
| 307 | + assert args["role"] == ROLE |
| 308 | + assert args["problem_type"] is None |
| 309 | + assert args["job_objective"] is None |
| 310 | + assert args["generate_candidate_definitions_only"] == GENERATE_CANDIDATE_DEFINITIONS_ONLY |
| 311 | + assert args["tags"] is None |
318 | 312 |
|
319 | 313 |
|
320 | 314 | def test_auto_ml_local_input(sagemaker_session):
|
|
0 commit comments