Skip to content

Commit e6bcfc0

Browse files
author
Ignacio Quintero
committed
Refactor LocalSageMakerClient
The LocalSageMakerClient was not very useful beyond creating a single estimator/endpoint. Workflows such as training and later creating an endpoint were really awkward. These changes make it resemble the API a bit more and allow objects to persist across LocalSessions. This is important because most of the time the SDK classes create sessions behind the scenes.
1 parent d67d9b8 commit e6bcfc0

File tree

5 files changed

+369
-134
lines changed

5 files changed

+369
-134
lines changed

src/sagemaker/local/entities.py

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import datetime
16+
import logging
17+
import time
18+
import urllib3
19+
20+
from sagemaker.local.image import _SageMakerContainer
21+
from sagemaker.utils import get_config_value
22+
23+
logger = logging.getLogger(__name__)
24+
logger.setLevel(logging.WARNING)
25+
26+
27+
class _LocalTrainingJob(object):
    """In-memory record of a training job run through a local container.

    The job wraps a container object, runs its ``train`` method synchronously
    in :meth:`start`, and exposes the outcome through :meth:`describe`.
    """

    _STARTING = 'Starting'
    _TRAINING = 'Training'
    _COMPLETED = 'Completed'
    _states = ['Starting', 'Training', 'Completed']

    def __init__(self, container):
        """Create a job bound to ``container``.

        Args:
            container: object providing ``train(input_data_config, hyperparameters)``
                and an ``instance_count`` attribute.
        """
        self.container = container
        self.model_artifacts = None
        self.state = 'created'
        self.start_time = None
        self.end_time = None

    def start(self, input_data_config, hyperparameters):
        """Validate the input channels and run training to completion.

        Args:
            input_data_config (list): channel definitions; each must carry a
                ``DataSource`` holding either ``S3DataSource`` or ``FileDataSource``.
            hyperparameters (dict): forwarded unchanged to ``container.train``.

        Raises:
            ValueError: if a channel has neither an S3 nor a File data source.
            RuntimeError: if a channel's distribution type is not ``FullyReplicated``.
        """
        for channel in input_data_config:
            data_source = channel['DataSource']
            if data_source and 'S3DataSource' in data_source:
                data_distribution = data_source['S3DataSource']['S3DataDistributionType']
            elif data_source and 'FileDataSource' in data_source:
                data_distribution = data_source['FileDataSource']['FileDataDistributionType']
            else:
                raise ValueError('Need channel[\'DataSource\'] to have [\'S3DataSource\'] or [\'FileDataSource\']')

            if data_distribution != 'FullyReplicated':
                raise RuntimeError('DataDistribution: %s is not currently supported in Local Mode' %
                                   data_distribution)

        # Record the timestamps on start_time/end_time, the attributes that
        # describe() reports. Assigning to ``self.start`` would shadow this
        # method on the instance and leave the reported times as None.
        self.start_time = datetime.datetime.now()
        self.state = self._TRAINING

        self.model_artifacts = self.container.train(input_data_config, hyperparameters)
        self.end_time = datetime.datetime.now()
        self.state = self._COMPLETED

    def describe(self):
        """Return a dict shaped like a DescribeTrainingJob response.

        Returns:
            dict: instance count, job status, start/end timestamps and the
            model artifacts value returned by the container's ``train``.
        """
        response = {
            'ResourceConfig': {
                'InstanceCount': self.container.instance_count
            },
            'TrainingJobStatus': self.state,
            'TrainingStartTime': self.start_time,
            'TrainingEndTime': self.end_time,
            'ModelArtifacts': {
                'S3ModelArtifacts': self.model_artifacts
            }
        }
        return response
74+
75+
76+
class _LocalModel(object):
    """Local stand-in for a model: a name plus its primary container definition."""

    def __init__(self, model_name, primary_container):
        """Store the model name and container definition and stamp the creation time.

        Args:
            model_name (str): name the model is registered under.
            primary_container (dict): container definition, returned as-is by ``describe``.
        """
        self.creation_time = datetime.datetime.now()
        self.primary_container = primary_container
        self.model_name = model_name

    def describe(self):
        """Return a dict shaped like a DescribeModel response.

        Returns:
            dict: model name, creation time, placeholder ARNs and the primary container.
        """
        # ARNs have no meaning locally, so fixed placeholder strings are reported.
        return {
            'ModelName': self.model_name,
            'CreationTime': self.creation_time,
            'ExecutionRoleArn': 'local:arn-does-not-matter',
            'ModelArn': 'local:arn-does-not-matter',
            'PrimaryContainer': self.primary_container
        }
92+
93+
94+
class _LocalEndpointConfig(object):
    """Local stand-in for an endpoint configuration: a name plus its production variants."""

    def __init__(self, config_name, production_variants):
        """Store the config name and variants and stamp the creation time.

        Args:
            config_name (str): name the endpoint config is registered under.
            production_variants (list): variant definitions, returned as-is by ``describe``.
        """
        self.creation_time = datetime.datetime.now()
        self.production_variants = production_variants
        self.name = config_name

    def describe(self):
        """Return a dict shaped like a DescribeEndpointConfig response.

        Returns:
            dict: config name, placeholder ARN, creation time and production variants.
        """
        # The ARN has no meaning locally, so a fixed placeholder string is reported.
        return {
            'EndpointConfigName': self.name,
            'EndpointConfigArn': 'local:arn-does-not-matter',
            'CreationTime': self.creation_time,
            'ProductionVariants': self.production_variants
        }
109+
110+
111+
class _LocalEndpoint(object):
    """Local stand-in for an endpoint served from a container on this machine."""

    _CREATING = 'Creating'
    _IN_SERVICE = 'InService'
    _FAILED = 'Failed'

    def __init__(self, endpoint_name, endpoint_config_name):
        """Look up the endpoint config and its model; no container is started yet.

        Args:
            endpoint_name (str): name of the endpoint.
            endpoint_config_name (str): name of an endpoint config already known
                to ``LocalSagemakerClient``.
        """
        # Imported inside the method rather than at module level -- presumably
        # to avoid a circular import with the local session module; confirm.
        from sagemaker.local import LocalSagemakerClient
        client = LocalSagemakerClient()

        self.name = endpoint_name
        self.endpoint_config = client.describe_endpoint_config(endpoint_config_name)
        # Only the first production variant is used.
        self.production_variant = self.endpoint_config['ProductionVariants'][0]

        described_model = client.describe_model(self.production_variant['ModelName'])
        self.primary_container = described_model['PrimaryContainer']

        self.container = None
        self.create_time = None
        self.state = _LocalEndpoint._CREATING

    def serve(self, sagemaker_session):
        """Start the serving container and poll its ``/ping`` route until it answers 200.

        Args:
            sagemaker_session: session handed to the container; its ``config`` is
                read for ``local.serving_port`` (8080 when unset).

        Raises:
            RuntimeError: if no health check succeeds within nine attempts.
        """
        self.create_time = datetime.datetime.now()
        self.container = _SageMakerContainer(
            self.production_variant['InstanceType'],
            self.production_variant['InitialInstanceCount'],
            self.primary_container['Image'],
            sagemaker_session)
        self.container.serve(self.primary_container['ModelDataUrl'], self.primary_container['Environment'])

        pool = urllib3.PoolManager()
        serving_port = get_config_value('local.serving_port', sagemaker_session.config) or 8080
        ping_url = 'http://localhost:%s/ping' % serving_port

        # Nine attempts, each followed by a one-second pause on failure.
        for attempt in range(1, 10):
            logger.info('Checking if endpoint is up, attempt: %s' % attempt)
            try:
                reply = pool.request('GET', ping_url)
                if reply.status == 200:
                    # Container is running and passed the health check.
                    self.state = _LocalEndpoint._IN_SERVICE
                    return
                logger.info('Container still not up, got: %s' % reply.status)
            except urllib3.exceptions.RequestError:
                logger.info('Container still not up')

            time.sleep(1)

        self.state = _LocalEndpoint._FAILED
        raise RuntimeError('Giving up, endpoint: %s didn\'t launch correctly' % self.name)

    def stop(self):
        """Stop the serving container if one was started."""
        if not self.container:
            return
        self.container.stop_serving()

    def describe(self):
        """Return a dict shaped like a DescribeEndpoint response.

        Returns:
            dict: config name, creation time, production variants, endpoint name,
            placeholder ARN and current status.
        """
        config = self.endpoint_config
        return {
            'EndpointConfigName': config['EndpointConfigName'],
            'CreationTime': self.create_time,
            'ProductionVariants': config['ProductionVariants'],
            'EndpointName': self.name,
            'EndpointArn': 'local:arn-does-not-matter',
            'EndpointStatus': self.state
        }

src/sagemaker/local/image.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def train(self, input_data_config, hyperparameters):
134134
print('===== Job Complete =====')
135135
return s3_artifacts
136136

137-
def serve(self, primary_container):
137+
def serve(self, model_dir, environment):
138138
"""Host a local endpoint using docker-compose.
139139
Args:
140140
primary_container (dict): dictionary containing the container runtime settings
@@ -148,13 +148,12 @@ def serve(self, primary_container):
148148
self.container_root = self._create_tmp_folder()
149149
logger.info('creating hosting dir in {}'.format(self.container_root))
150150

151-
model_dir = primary_container['ModelDataUrl']
152151
volumes = self._prepare_serving_volumes(model_dir)
153-
env_vars = ['{}={}'.format(k, v) for k, v in primary_container['Environment'].items()]
152+
env_vars = ['{}={}'.format(k, v) for k, v in environment.items()]
154153

155154
# If the user script was passed as a file:// mount it to the container.
156-
if sagemaker.estimator.DIR_PARAM_NAME.upper() in primary_container['Environment']:
157-
script_dir = primary_container['Environment'][sagemaker.estimator.DIR_PARAM_NAME.upper()]
155+
if sagemaker.estimator.DIR_PARAM_NAME.upper() in environment:
156+
script_dir = environment[sagemaker.estimator.DIR_PARAM_NAME.upper()]
158157
parsed_uri = urlparse(script_dir)
159158
if parsed_uri.scheme == 'file':
160159
volumes.append(_Volume(parsed_uri.path, '/opt/ml/code'))

src/sagemaker/local/local_session.py

Lines changed: 62 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,15 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15-
import datetime
1615
import logging
1716
import platform
18-
import time
1917

2018
import boto3
2119
import urllib3
2220
from botocore.exceptions import ClientError
2321

2422
from sagemaker.local.image import _SageMakerContainer
23+
from sagemaker.local.entities import _LocalEndpointConfig, _LocalEndpoint, _LocalModel, _LocalTrainingJob
2524
from sagemaker.session import Session
2625
from sagemaker.utils import get_config_value
2726

@@ -37,42 +36,44 @@ class LocalSagemakerClient(object):
3736
3837
Implements the methods with the same signature as the boto SageMakerClient.
3938
"""
39+
40+
_training_jobs = {}
41+
_models = {}
42+
_endpoint_configs = {}
43+
_endpoints = {}
44+
4045
def __init__(self, sagemaker_session=None):
4146
"""Initialize a LocalSageMakerClient.
4247
4348
Args:
4449
sagemaker_session (sagemaker.session.Session): a session to use to read configurations
4550
from, and use its boto client.
4651
"""
47-
self.train_container = None
4852
self.serve_container = None
4953
self.sagemaker_session = sagemaker_session or LocalSession()
5054
self.s3_model_artifacts = None
51-
self.model_name = None
52-
self.primary_container = None
53-
self.role_arn = None
5455
self.created_endpoint = False
5556

56-
def create_training_job(self, TrainingJobName, AlgorithmSpecification, RoleArn, InputDataConfig, OutputDataConfig,
57-
ResourceConfig, StoppingCondition, HyperParameters, Tags=None):
58-
59-
self.train_container = _SageMakerContainer(ResourceConfig['InstanceType'], ResourceConfig['InstanceCount'],
60-
AlgorithmSpecification['TrainingImage'], self.sagemaker_session)
61-
62-
for channel in InputDataConfig:
63-
64-
if channel['DataSource'] and 'S3DataSource' in channel['DataSource']:
65-
data_distribution = channel['DataSource']['S3DataSource']['S3DataDistributionType']
66-
elif channel['DataSource'] and 'FileDataSource' in channel['DataSource']:
67-
data_distribution = channel['DataSource']['FileDataSource']['FileDataDistributionType']
68-
else:
69-
raise ValueError('Need channel[\'DataSource\'] to have [\'S3DataSource\'] or [\'FileDataSource\']')
57+
def create_training_job(self, TrainingJobName, AlgorithmSpecification, InputDataConfig, OutputDataConfig,
58+
ResourceConfig, HyperParameters, *args, **kwargs):
59+
"""
60+
Create a training job in Local Mode
61+
Args:
62+
TrainingJobName (str): local training job name.
63+
AlgorithmSpecification (dict): Identifies the training algorithm to use.
64+
InputDataConfig (dict): Describes the training dataset and the location where it is stored.
65+
OutputDataConfig (dict): Identifies the location where you want to save the results of model training.
66+
ResourceConfig (dict): Identifies the resources to use for local model training.
67+
HyperParameters (dict): Specify these algorithm-specific parameters to influence the quality of the final
68+
model.
69+
"""
7070

71-
if data_distribution != 'FullyReplicated':
72-
raise RuntimeError("DataDistribution: %s is not currently supported in Local Mode" %
73-
data_distribution)
71+
container = _SageMakerContainer(ResourceConfig['InstanceType'], ResourceConfig['InstanceCount'],
72+
AlgorithmSpecification['TrainingImage'], self.sagemaker_session)
73+
train_job = _LocalTrainingJob(container)
74+
train_job.start(InputDataConfig, HyperParameters)
7475

75-
self.s3_model_artifacts = self.train_container.train(InputDataConfig, HyperParameters)
76+
LocalSagemakerClient._training_jobs[TrainingJobName] = train_job
7677

7778
def describe_training_job(self, TrainingJobName):
7879
"""Describe a local training job.
@@ -83,63 +84,55 @@ def describe_training_job(self, TrainingJobName):
8384
Returns: (dict) DescribeTrainingJob Response.
8485
8586
"""
86-
response = {'ResourceConfig': {'InstanceCount': self.train_container.instance_count},
87-
'TrainingJobStatus': 'Completed',
88-
'TrainingStartTime': datetime.datetime.now(),
89-
'TrainingEndTime': datetime.datetime.now(),
90-
'ModelArtifacts': {'S3ModelArtifacts': self.s3_model_artifacts}
91-
}
92-
return response
93-
94-
def create_model(self, ModelName, PrimaryContainer, ExecutionRoleArn):
95-
self.model_name = ModelName
96-
self.primary_container = PrimaryContainer
97-
self.role_arn = ExecutionRoleArn
87+
if TrainingJobName not in LocalSagemakerClient._training_jobs:
88+
error_response = {'Error': {'Code': 'ValidationException', 'Message': 'Could not find local training job'}}
89+
raise ClientError(error_response, 'describe_training_job')
90+
else:
91+
return LocalSagemakerClient._training_jobs[TrainingJobName].describe()
92+
93+
def create_model(self, ModelName, PrimaryContainer, *args, **kwargs):
94+
"""Create a Local Model Object
95+
96+
Args:
97+
ModelName (str): the Model Name
98+
PrimaryContainer (dict): a SageMaker primary container definition
99+
"""
100+
LocalSagemakerClient._models[ModelName] = _LocalModel(ModelName, PrimaryContainer)
101+
102+
def describe_model(self, ModelName):
103+
if ModelName not in LocalSagemakerClient._models:
104+
error_response = {'Error': {'Code': 'ValidationException', 'Message': 'Could not find local model'}}
105+
raise ClientError(error_response, 'describe_model')
106+
else:
107+
return LocalSagemakerClient._models[ModelName].describe()
98108

99109
def describe_endpoint_config(self, EndpointConfigName):
100-
if self.created_endpoint:
101-
return True
110+
if EndpointConfigName in LocalSagemakerClient._endpoint_configs:
111+
return LocalSagemakerClient._endpoint_configs[EndpointConfigName].describe()
102112
else:
103-
error_response = {'Error': {'Code': 'ValidationException', 'Message': 'Could not find endpoint'}}
113+
error_response = {'Error': {
114+
'Code': 'ValidationException', 'Message': 'Could not find local endpoint config'}}
104115
raise ClientError(error_response, 'describe_endpoint_config')
105116

106117
def create_endpoint_config(self, EndpointConfigName, ProductionVariants):
107-
self.variants = ProductionVariants
118+
LocalSagemakerClient._endpoint_configs[EndpointConfigName] = _LocalEndpointConfig(
119+
EndpointConfigName, ProductionVariants)
108120

109121
def describe_endpoint(self, EndpointName):
110-
return {'EndpointStatus': 'InService'}
122+
if EndpointName not in LocalSagemakerClient._endpoints:
123+
error_response = {'Error': {'Code': 'ValidationException', 'Message': 'Could not find local endpoint'}}
124+
raise ClientError(error_response, 'describe_endpoint')
125+
else:
126+
return LocalSagemakerClient._endpoints[EndpointName].describe()
111127

112128
def create_endpoint(self, EndpointName, EndpointConfigName):
113-
instance_type = self.variants[0]['InstanceType']
114-
instance_count = self.variants[0]['InitialInstanceCount']
115-
self.serve_container = _SageMakerContainer(instance_type, instance_count,
116-
self.primary_container['Image'], self.sagemaker_session)
117-
self.serve_container.serve(self.primary_container)
118-
self.created_endpoint = True
119-
120-
i = 0
121-
http = urllib3.PoolManager()
122-
serving_port = get_config_value('local.serving_port', self.sagemaker_session.config) or 8080
123-
endpoint_url = "http://localhost:%s/ping" % serving_port
124-
while True:
125-
i += 1
126-
if i >= 10:
127-
raise RuntimeError("Giving up, endpoint: %s didn't launch correctly" % EndpointName)
128-
129-
logger.info("Checking if endpoint is up, attempt: %s" % i)
130-
try:
131-
r = http.request('GET', endpoint_url)
132-
if r.status != 200:
133-
logger.info("Container still not up, got: %s" % r.status)
134-
else:
135-
return
136-
except urllib3.exceptions.RequestError:
137-
logger.info("Container still not up")
138-
139-
time.sleep(1)
129+
endpoint = _LocalEndpoint(EndpointName, EndpointConfigName)
130+
LocalSagemakerClient._endpoints[EndpointName] = endpoint
131+
endpoint.serve(self.sagemaker_session)
140132

141133
def delete_endpoint(self, EndpointName):
142-
self.serve_container.stop_serving()
134+
if EndpointName in LocalSagemakerClient._endpoints:
135+
LocalSagemakerClient._endpoints[EndpointName].stop()
143136

144137

145138
class LocalSagemakerRuntimeClient(object):

0 commit comments

Comments
 (0)