10
10
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
11
# ANY KIND, either express or implied. See the License for the specific
12
12
# language governing permissions and limitations under the License.
13
+ from __future__ import absolute_import
13
14
import logging
14
15
15
16
import json
44
45
@pytest .fixture (name = 'sagemaker_session' )
45
46
def fixture_sagemaker_session ():
46
47
boto_mock = Mock (name = 'boto_session' , region_name = REGION )
47
- ims = Mock (name = 'sagemaker_session' , boto_session = boto_mock )
48
- ims .sagemaker_client .describe_training_job = Mock (return_value = {'ModelArtifacts' :
49
- {'S3ModelArtifacts' : 's3://m/m.tar.gz' }})
50
- ims .default_bucket = Mock (name = 'default_bucket' , return_value = BUCKET_NAME )
51
- ims .expand_role = Mock (name = "expand_role" , return_value = ROLE )
52
- return ims
48
+ session = Mock (name = 'sagemaker_session' , boto_session = boto_mock ,
49
+ boto_region_name = REGION , config = None , local_mode = False )
50
+
51
+ describe = {'ModelArtifacts' : {'S3ModelArtifacts' : 's3://m/m.tar.gz' }}
52
+ session .sagemaker_client .describe_training_job = Mock (return_value = describe )
53
+ session .default_bucket = Mock (name = 'default_bucket' , return_value = BUCKET_NAME )
54
+ session .expand_role = Mock (name = "expand_role" , return_value = ROLE )
55
+ return session
53
56
54
57
55
58
def _get_full_cpu_image_uri (version , py_version = PYTHON_VERSION ):
@@ -75,39 +78,42 @@ def _pytorch_estimator(sagemaker_session, framework_version=defaults.PYTORCH_VER
75
78
76
79
77
80
def _create_train_job (version ):
78
- return {'image' : _get_full_cpu_image_uri ( version ),
79
- 'input_mode ' : 'File' ,
80
- 'input_config ' : [{
81
- 'ChannelName ' : 'training' ,
82
- 'DataSource ' : {
83
- 'S3DataSource ' : {
84
- 'S3DataDistributionType ' : 'FullyReplicated' ,
85
- 'S3DataType ' : 'S3Prefix'
86
- }
81
+ return {
82
+ 'image ' : _get_full_cpu_image_uri ( version ) ,
83
+ 'input_mode ' : 'File' ,
84
+ 'input_config ' : [{
85
+ 'ChannelName ' : 'training' ,
86
+ 'DataSource ' : {
87
+ 'S3DataSource ' : {
88
+ 'S3DataDistributionType ' : 'FullyReplicated' ,
89
+ 'S3DataType' : 'S3Prefix'
87
90
}
88
- }],
89
- 'role' : ROLE ,
90
- 'job_name' : JOB_NAME ,
91
- 'output_config' : {
92
- 'S3OutputPath' : 's3://{}/' .format (BUCKET_NAME ),
93
- },
94
- 'resource_config' : {
95
- 'InstanceType' : 'ml.c4.4xlarge' ,
96
- 'InstanceCount' : 1 ,
97
- 'VolumeSizeInGB' : 30 ,
98
- },
99
- 'hyperparameters' : {
100
- 'sagemaker_program' : json .dumps ('dummy_script.py' ),
101
- 'sagemaker_enable_cloudwatch_metrics' : 'false' ,
102
- 'sagemaker_container_log_level' : str (logging .INFO ),
103
- 'sagemaker_job_name' : json .dumps (JOB_NAME ),
104
- 'sagemaker_submit_directory' :
105
- json .dumps ('s3://{}/{}/source/sourcedir.tar.gz' .format (BUCKET_NAME , JOB_NAME )),
106
- 'sagemaker_region' : '"us-west-2"'
107
- },
108
- 'stop_condition' : {
109
- 'MaxRuntimeInSeconds' : 24 * 60 * 60
110
- }}
91
+ }
92
+ }],
93
+ 'role' : ROLE ,
94
+ 'job_name' : JOB_NAME ,
95
+ 'output_config' : {
96
+ 'S3OutputPath' : 's3://{}/' .format (BUCKET_NAME ),
97
+ },
98
+ 'resource_config' : {
99
+ 'InstanceType' : 'ml.c4.4xlarge' ,
100
+ 'InstanceCount' : 1 ,
101
+ 'VolumeSizeInGB' : 30 ,
102
+ },
103
+ 'hyperparameters' : {
104
+ 'sagemaker_program' : json .dumps ('dummy_script.py' ),
105
+ 'sagemaker_enable_cloudwatch_metrics' : 'false' ,
106
+ 'sagemaker_container_log_level' : str (logging .INFO ),
107
+ 'sagemaker_job_name' : json .dumps (JOB_NAME ),
108
+ 'sagemaker_submit_directory' :
109
+ json .dumps ('s3://{}/{}/source/sourcedir.tar.gz' .format (BUCKET_NAME , JOB_NAME )),
110
+ 'sagemaker_region' : '"us-west-2"'
111
+ },
112
+ 'stop_condition' : {
113
+ 'MaxRuntimeInSeconds' : 24 * 60 * 60
114
+ },
115
+ 'tags' : None
116
+ }
111
117
112
118
113
119
def test_create_model (sagemaker_session , pytorch_version ):
0 commit comments