Skip to content

Commit 0a655e8

Browse files
Merge pull request aws#9 from aws/refactor_job
Add job base class
2 parents 9b9272b + 1f230ac commit 0a655e8

File tree

2 files changed

+278
-0
lines changed

2 files changed

+278
-0
lines changed

src/sagemaker/job.py

+112
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from abc import abstractmethod
14+
from six import string_types
15+
16+
from sagemaker.session import s3_input
17+
18+
19+
class _Job(object):
20+
"""Handle creating, starting and waiting for Amazon SageMaker jobs to finish.
21+
22+
This class shouldn't be directly instantiated.
23+
24+
Subclasses must define a way to create, start and wait for an Amazon SageMaker job.
25+
"""
26+
27+
def __init__(self, sagemaker_session, job_name):
28+
self.sagemaker_session = sagemaker_session
29+
self.job_name = job_name
30+
31+
@abstractmethod
32+
def start_new(cls, estimator, inputs):
33+
"""Create a new Amazon SageMaker job from the estimator.
34+
35+
Args:
36+
estimator (sagemaker.estimator.EstimatorBase): Estimator object created by the user.
37+
inputs (str): Parameters used when called :meth:`~sagemaker.estimator.EstimatorBase.fit`.
38+
39+
Returns:
40+
sagemaker.job: Constructed object that captures all information about the started job.
41+
"""
42+
pass
43+
44+
@abstractmethod
45+
def wait(self):
46+
"""Wait for the Amazon SageMaker job to finish.
47+
"""
48+
pass
49+
50+
@staticmethod
51+
def _load_config(inputs, estimator):
52+
input_config = _Job._format_inputs_to_input_config(inputs)
53+
role = estimator.sagemaker_session.expand_role(estimator.role)
54+
output_config = _Job._prepare_output_config(estimator.output_path, estimator.output_kms_key)
55+
resource_config = _Job._prepare_resource_config(estimator.train_instance_count,
56+
estimator.train_instance_type,
57+
estimator.train_volume_size)
58+
stopping_condition = _Job._prepare_stopping_condition(estimator.train_max_run)
59+
60+
return input_config, role, output_config, resource_config, stopping_condition
61+
62+
@staticmethod
63+
def _format_inputs_to_input_config(inputs):
64+
input_dict = {}
65+
if isinstance(inputs, string_types):
66+
input_dict['training'] = _Job._format_s3_uri_input(inputs)
67+
elif isinstance(inputs, s3_input):
68+
input_dict['training'] = inputs
69+
elif isinstance(inputs, dict):
70+
for k, v in inputs.items():
71+
input_dict[k] = _Job._format_s3_uri_input(v)
72+
else:
73+
raise ValueError('Cannot format input {}. Expecting one of str, dict or s3_input'.format(inputs))
74+
75+
channels = []
76+
for channel_name, channel_s3_input in input_dict.items():
77+
channel_config = channel_s3_input.config.copy()
78+
channel_config['ChannelName'] = channel_name
79+
channels.append(channel_config)
80+
return channels
81+
82+
@staticmethod
83+
def _format_s3_uri_input(input):
84+
if isinstance(input, str):
85+
if not input.startswith('s3://'):
86+
raise ValueError('Training input data must be a valid S3 URI and must start with "s3://"')
87+
return s3_input(input)
88+
if isinstance(input, s3_input):
89+
return input
90+
else:
91+
raise ValueError('Cannot format input {}. Expecting one of str or s3_input'.format(input))
92+
93+
@staticmethod
94+
def _prepare_output_config(s3_path, kms_key_id):
95+
config = {'S3OutputPath': s3_path}
96+
if kms_key_id is not None:
97+
config['KmsKeyId'] = kms_key_id
98+
return config
99+
100+
@staticmethod
101+
def _prepare_resource_config(instance_count, instance_type, volume_size):
102+
return {'InstanceCount': instance_count,
103+
'InstanceType': instance_type,
104+
'VolumeSizeInGB': volume_size}
105+
106+
@staticmethod
107+
def _prepare_stopping_condition(max_run):
108+
return {'MaxRuntimeInSeconds': max_run}
109+
110+
@property
111+
def name(self):
112+
return self.job_name

tests/unit/test_job.py

+166
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
import pytest
14+
from mock import Mock
15+
16+
from sagemaker.estimator import Estimator
17+
from sagemaker.job import _Job
18+
from sagemaker.session import s3_input
19+
20+
# S3 locations used by the tests.
BUCKET_NAME = 's3://mybucket/train'
S3_OUTPUT_PATH = 's3://bucket/prefix'

# Estimator construction arguments.
IMAGE_NAME = 'fakeimage'
ROLE = 'DummyRole'
INSTANCE_COUNT = 1
INSTANCE_TYPE = 'c4.4xlarge'
VOLUME_SIZE = 1
MAX_RUNTIME = 1

# Name given to the _Job instance under test.
JOB_NAME = 'fakejob'
29+
30+
31+
@pytest.fixture()
def estimator(sagemaker_session):
    """Build an ``Estimator`` from the module constants and the session fixture."""
    positional_args = (IMAGE_NAME, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, VOLUME_SIZE, MAX_RUNTIME)
    return Estimator(*positional_args,
                     output_path=S3_OUTPUT_PATH,
                     sagemaker_session=sagemaker_session)
35+
36+
37+
@pytest.fixture()
def job(sagemaker_session):
    """Provide a ``_Job`` named ``JOB_NAME`` that holds the session fixture."""
    base_job = _Job(sagemaker_session, JOB_NAME)
    return base_job
40+
41+
42+
@pytest.fixture()
def sagemaker_session():
    """Return a ``Mock`` session whose ``expand_role`` always returns ``ROLE``."""
    session = Mock(name='sagemaker_session', boto_session=Mock(name='boto_session'))
    session.expand_role = Mock(name='expand_role', return_value=ROLE)
    return session
49+
50+
51+
def test_load_config(job, estimator):
    """``_load_config`` returns the five configuration pieces built from the estimator."""
    loaded = job._load_config(s3_input(BUCKET_NAME), estimator)
    input_config, role, output_config, resource_config, stopping_condition = loaded

    s3_data_source = input_config[0]['DataSource']['S3DataSource']
    assert s3_data_source['S3Uri'] == BUCKET_NAME
    assert role == ROLE
    assert output_config['S3OutputPath'] == S3_OUTPUT_PATH
    assert 'KmsKeyId' not in output_config
    assert resource_config['InstanceCount'] == INSTANCE_COUNT
    assert resource_config['InstanceType'] == INSTANCE_TYPE
    assert resource_config['VolumeSizeInGB'] == VOLUME_SIZE
    assert stopping_condition['MaxRuntimeInSeconds'] == MAX_RUNTIME
65+
66+
67+
def test_format_inputs_to_input_config_string(job):
    """A plain S3 URI string produces a channel that points at the same URI."""
    s3_uri = BUCKET_NAME

    first_channel = job._format_inputs_to_input_config(s3_uri)[0]

    assert first_channel['DataSource']['S3DataSource']['S3Uri'] == s3_uri
73+
74+
75+
def test_format_inputs_to_input_config_s3_input(job):
    """An ``s3_input`` produces a channel carrying the URI from its own config."""
    channel_input = s3_input(BUCKET_NAME)

    channels = job._format_inputs_to_input_config(channel_input)

    actual_uri = channels[0]['DataSource']['S3DataSource']['S3Uri']
    expected_uri = channel_input.config['DataSource']['S3DataSource']['S3Uri']
    assert actual_uri == expected_uri
82+
83+
84+
def test_format_inputs_to_input_config_dict(job):
    """A dict of channel name to S3 URI produces a channel with that URI."""
    channel_map = {'train': BUCKET_NAME}

    channels = job._format_inputs_to_input_config(channel_map)

    s3_data_source = channels[0]['DataSource']['S3DataSource']
    assert s3_data_source['S3Uri'] == channel_map['train']
90+
91+
92+
def test_format_inputs_to_input_config_exception(job):
    """An input that is not a string, dict or ``s3_input`` raises ``ValueError``."""
    unsupported_input = 1

    with pytest.raises(ValueError):
        job._format_inputs_to_input_config(unsupported_input)
97+
98+
99+
def test_format_s3_uri_input_string(job):
    """A valid S3 URI string is wrapped into an object whose config carries that URI."""
    s3_uri = BUCKET_NAME

    formatted = job._format_s3_uri_input(s3_uri)

    s3_data_source = formatted.config['DataSource']['S3DataSource']
    assert s3_data_source['S3Uri'] == s3_uri
105+
106+
107+
def test_format_s3_uri_input_string_exception(job):
    """A string that lacks the ``s3://`` prefix raises ``ValueError``."""
    uri_without_scheme = 'mybucket/train'

    with pytest.raises(ValueError):
        job._format_s3_uri_input(uri_without_scheme)
112+
113+
114+
def test_format_s3_uri_input(job):
    """An ``s3_input`` comes back with the same URI in its config."""
    original_input = s3_input(BUCKET_NAME)

    formatted = job._format_s3_uri_input(original_input)

    actual_uri = formatted.config['DataSource']['S3DataSource']['S3Uri']
    expected_uri = original_input.config['DataSource']['S3DataSource']['S3Uri']
    assert actual_uri == expected_uri
121+
122+
123+
def test_format_s3_uri_input_exception(job):
    """A value that is neither a string nor an ``s3_input`` raises ``ValueError``."""
    unsupported_input = 1

    with pytest.raises(ValueError):
        job._format_s3_uri_input(unsupported_input)
128+
129+
130+
def test_prepare_output_config(job):
    """Both the S3 path and the KMS key end up in the output config."""
    kms_key = 'kms_key'

    output_config = job._prepare_output_config(BUCKET_NAME, kms_key)

    assert output_config['S3OutputPath'] == BUCKET_NAME
    assert output_config['KmsKeyId'] == kms_key
137+
138+
139+
def test_prepare_output_config_kms_key_none(job):
    """With a ``None`` KMS key the output config has no ``'KmsKeyId'`` entry."""
    output_path = BUCKET_NAME

    output_config = job._prepare_output_config(output_path, None)

    assert output_config['S3OutputPath'] == output_path
    assert 'KmsKeyId' not in output_config
147+
148+
149+
def test_prepare_resource_config(job):
    """Instance count, instance type and volume size are stored under their keys."""
    built = job._prepare_resource_config(INSTANCE_COUNT, INSTANCE_TYPE, VOLUME_SIZE)

    expected_pairs = (
        ('InstanceCount', INSTANCE_COUNT),
        ('InstanceType', INSTANCE_TYPE),
        ('VolumeSizeInGB', VOLUME_SIZE),
    )
    for key, expected_value in expected_pairs:
        assert built[key] == expected_value
155+
156+
157+
def test_prepare_stopping_condition(job):
    """The maximum runtime is stored under ``'MaxRuntimeInSeconds'``."""
    runtime_limit = 1

    condition = job._prepare_stopping_condition(runtime_limit)

    assert condition['MaxRuntimeInSeconds'] == runtime_limit
163+
164+
165+
def test_name(job):
    """The ``name`` property returns the job name given at construction."""
    reported_name = job.name
    assert reported_name == JOB_NAME

0 commit comments

Comments
 (0)