Skip to content

Commit 230f3f6

Browse files
iquinteroPiali Das
authored and
Piali Das
committed
Refactor LocalSageMakerClient (aws#375)
The local SageMaker client was not very useful beyond creating a single estimator/endpoint. Workflows such as training and later creating an endpoint were really awkward. These changes make it resemble the API a bit more and allow persisting objects across LocalSessions. This is important because most of the time the SDK classes create sessions behind the scenes. Includes 2 bug fixes: - Fix Hyperparameters being mandatory in local mode - Fix serving container timeout to match SageMaker
1 parent 4d85631 commit 230f3f6

File tree

6 files changed

+394
-135
lines changed

6 files changed

+394
-135
lines changed

CHANGELOG.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,11 @@
22
CHANGELOG
33
=========
44

5+
1.10.2dev
6+
=========
7+
* bug-fix: Setting health check timeout limit on local mode to 30s
8+
* bug-fix: Make Hyperparameters in local mode optional.
9+
510
1.10.1
611
======
712

src/sagemaker/local/entities.py

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import datetime
16+
import logging
17+
import time
18+
import urllib3
19+
20+
from sagemaker.local.image import _SageMakerContainer
21+
from sagemaker.utils import get_config_value
22+
23+
# Module-level logger. Its threshold is set to WARNING, so the info-level
# messages logged in this module are dropped unless the level is changed.
logger = logging.getLogger(__name__)
24+
logger.setLevel(logging.WARNING)
25+
26+
# Placeholder value returned wherever a describe() response carries an ARN field.
_UNUSED_ARN = 'local:arn-does-not-matter'
27+
# Bound on the health-check polling loop in _LocalEndpoint.serve(), which
# sleeps one second between polls.
HEALTH_CHECK_TIMEOUT_LIMIT = 30
28+
29+
30+
class _LocalTrainingJob(object):
    """Record of a training job that is run through a local container.

    The job keeps the container it trains with, the artifacts returned by
    ``container.train`` and the timestamps taken around that call, and
    exposes them through :meth:`describe`.
    """

    _STARTING = 'Starting'
    _TRAINING = 'Training'
    _COMPLETED = 'Completed'
    _states = ['Starting', 'Training', 'Completed']

    def __init__(self, container):
        """Create a job bound to ``container``; nothing is run yet.

        Args:
            container: object providing ``train(input_data_config,
                hyperparameters)`` and an ``instance_count`` attribute.
        """
        self.container = container
        self.model_artifacts = None
        self.state = 'created'
        self.start_time = None
        self.end_time = None

    def start(self, input_data_config, hyperparameters):
        """Validate the input channels, then train with the container.

        Args:
            input_data_config (list): channel dictionaries. Each must carry a
                ``'DataSource'`` entry holding either ``'S3DataSource'`` or
                ``'FileDataSource'``.
            hyperparameters: passed through unchanged to ``container.train``.

        Raises:
            ValueError: if a channel has neither supported data source.
            RuntimeError: if a channel's distribution type is not
                ``'FullyReplicated'``.
        """
        for channel in input_data_config:
            data_source = channel['DataSource']
            if data_source and 'S3DataSource' in data_source:
                data_distribution = data_source['S3DataSource']['S3DataDistributionType']
            elif data_source and 'FileDataSource' in data_source:
                data_distribution = data_source['FileDataSource']['FileDataDistributionType']
            else:
                raise ValueError('Need channel[\'DataSource\'] to have [\'S3DataSource\'] or [\'FileDataSource\']')

            if data_distribution != 'FullyReplicated':
                raise RuntimeError('DataDistribution: %s is not currently supported in Local Mode' %
                                   data_distribution)

        # Record the timestamps on start_time/end_time. Assigning to
        # ``self.start``/``self.end`` would shadow this method on the instance
        # and leave describe() reporting None for both times.
        self.start_time = datetime.datetime.now()
        self.state = self._TRAINING

        self.model_artifacts = self.container.train(input_data_config, hyperparameters)
        self.end_time = datetime.datetime.now()
        self.state = self._COMPLETED

    def describe(self):
        """Return a dictionary describing the job's current status.

        Returns:
            dict: instance count, status, start/end timestamps and the model
            artifacts returned by the container (``None`` until training ran).
        """
        response = {
            'ResourceConfig': {
                'InstanceCount': self.container.instance_count
            },
            'TrainingJobStatus': self.state,
            'TrainingStartTime': self.start_time,
            'TrainingEndTime': self.end_time,
            'ModelArtifacts': {
                'S3ModelArtifacts': self.model_artifacts
            }
        }
        return response
78+
79+
class _LocalModel(object):
80+
81+
def __init__(self, model_name, primary_container):
82+
self.model_name = model_name
83+
self.primary_container = primary_container
84+
self.creation_time = datetime.datetime.now()
85+
86+
def describe(self):
87+
response = {
88+
'ModelName': self.model_name,
89+
'CreationTime': self.creation_time,
90+
'ExecutionRoleArn': _UNUSED_ARN,
91+
'ModelArn': _UNUSED_ARN,
92+
'PrimaryContainer': self.primary_container
93+
}
94+
return response
95+
96+
97+
class _LocalEndpointConfig(object):
98+
99+
def __init__(self, config_name, production_variants):
100+
self.name = config_name
101+
self.production_variants = production_variants
102+
self.creation_time = datetime.datetime.now()
103+
104+
def describe(self):
105+
response = {
106+
'EndpointConfigName': self.name,
107+
'EndpointConfigArn': _UNUSED_ARN,
108+
'CreationTime': self.creation_time,
109+
'ProductionVariants': self.production_variants
110+
}
111+
return response
112+
113+
114+
class _LocalEndpoint(object):
115+
116+
_CREATING = 'Creating'
117+
_IN_SERVICE = 'InService'
118+
_FAILED = 'Failed'
119+
120+
def __init__(self, endpoint_name, endpoint_config_name, local_session=None):
121+
# runtime import since there is a cyclic dependency between entities and local_session
122+
from sagemaker.local import LocalSession
123+
self.local_session = local_session or LocalSession()
124+
local_client = self.local_session.sagemaker_client
125+
126+
self.name = endpoint_name
127+
self.endpoint_config = local_client.describe_endpoint_config(endpoint_config_name)
128+
self.production_variant = self.endpoint_config['ProductionVariants'][0]
129+
130+
model_name = self.production_variant['ModelName']
131+
self.primary_container = local_client.describe_model(model_name)['PrimaryContainer']
132+
133+
self.container = None
134+
self.create_time = None
135+
self.state = _LocalEndpoint._CREATING
136+
137+
def serve(self):
138+
image = self.primary_container['Image']
139+
instance_type = self.production_variant['InstanceType']
140+
instance_count = self.production_variant['InitialInstanceCount']
141+
142+
self.create_time = datetime.datetime.now()
143+
self.container = _SageMakerContainer(instance_type, instance_count, image, self.local_session)
144+
self.container.serve(self.primary_container['ModelDataUrl'], self.primary_container['Environment'])
145+
146+
i = 0
147+
http = urllib3.PoolManager()
148+
serving_port = get_config_value('local.serving_port', self.local_session.config) or 8080
149+
endpoint_url = 'http://localhost:%s/ping' % serving_port
150+
while True:
151+
i += 1
152+
if i >= HEALTH_CHECK_TIMEOUT_LIMIT:
153+
self.state = _LocalEndpoint._FAILED
154+
raise RuntimeError('Giving up, endpoint: %s didn\'t launch correctly' % self.name)
155+
156+
logger.info('Checking if endpoint is up, attempt: %s' % i)
157+
try:
158+
r = http.request('GET', endpoint_url)
159+
if r.status != 200:
160+
logger.info('Container still not up, got: %s' % r.status)
161+
else:
162+
# the container is running and it passed the healthcheck status is now InService
163+
self.state = _LocalEndpoint._IN_SERVICE
164+
return
165+
except urllib3.exceptions.RequestError:
166+
logger.info('Container still not up')
167+
168+
time.sleep(1)
169+
170+
def stop(self):
171+
if self.container:
172+
self.container.stop_serving()
173+
174+
def describe(self):
175+
response = {
176+
'EndpointConfigName': self.endpoint_config['EndpointConfigName'],
177+
'CreationTime': self.create_time,
178+
'ProductionVariants': self.endpoint_config['ProductionVariants'],
179+
'EndpointName': self.name,
180+
'EndpointArn': _UNUSED_ARN,
181+
'EndpointStatus': self.state
182+
}
183+
return response

src/sagemaker/local/image.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def train(self, input_data_config, hyperparameters):
134134
print('===== Job Complete =====')
135135
return s3_artifacts
136136

137-
def serve(self, primary_container):
137+
def serve(self, model_dir, environment):
138138
"""Host a local endpoint using docker-compose.
139139
Args:
140140
primary_container (dict): dictionary containing the container runtime settings
@@ -148,13 +148,12 @@ def serve(self, primary_container):
148148
self.container_root = self._create_tmp_folder()
149149
logger.info('creating hosting dir in {}'.format(self.container_root))
150150

151-
model_dir = primary_container['ModelDataUrl']
152151
volumes = self._prepare_serving_volumes(model_dir)
153-
env_vars = ['{}={}'.format(k, v) for k, v in primary_container['Environment'].items()]
152+
env_vars = ['{}={}'.format(k, v) for k, v in environment.items()]
154153

155154
# If the user script was passed as a file:// mount it to the container.
156-
if sagemaker.estimator.DIR_PARAM_NAME.upper() in primary_container['Environment']:
157-
script_dir = primary_container['Environment'][sagemaker.estimator.DIR_PARAM_NAME.upper()]
155+
if sagemaker.estimator.DIR_PARAM_NAME.upper() in environment:
156+
script_dir = environment[sagemaker.estimator.DIR_PARAM_NAME.upper()]
158157
parsed_uri = urlparse(script_dir)
159158
if parsed_uri.scheme == 'file':
160159
volumes.append(_Volume(parsed_uri.path, '/opt/ml/code'))

0 commit comments

Comments
 (0)