12
12
# language governing permissions and limitations under the License.
13
13
from __future__ import absolute_import
14
14
15
- import datetime
16
15
import logging
17
16
import platform
18
- import time
19
17
20
18
import boto3
21
19
import urllib3
22
20
from botocore .exceptions import ClientError
23
21
24
22
from sagemaker .local .image import _SageMakerContainer
23
+ from sagemaker .local .entities import _LocalEndpointConfig , _LocalEndpoint , _LocalModel , _LocalTrainingJob
25
24
from sagemaker .session import Session
26
25
from sagemaker .utils import get_config_value
27
26
@@ -37,42 +36,44 @@ class LocalSagemakerClient(object):
37
36
38
37
Implements the methods with the same signature as the boto SageMakerClient.
39
38
"""
39
+
40
+ _training_jobs = {}
41
+ _models = {}
42
+ _endpoint_configs = {}
43
+ _endpoints = {}
44
+
40
45
def __init__ (self , sagemaker_session = None ):
41
46
"""Initialize a LocalSageMakerClient.
42
47
43
48
Args:
44
49
sagemaker_session (sagemaker.session.Session): a session to use to read configurations
45
50
from, and use its boto client.
46
51
"""
47
- self .train_container = None
48
52
self .serve_container = None
49
53
self .sagemaker_session = sagemaker_session or LocalSession ()
50
54
self .s3_model_artifacts = None
51
- self .model_name = None
52
- self .primary_container = None
53
- self .role_arn = None
54
55
self .created_endpoint = False
55
56
56
- def create_training_job (self , TrainingJobName , AlgorithmSpecification , RoleArn , InputDataConfig , OutputDataConfig ,
57
- ResourceConfig , StoppingCondition , HyperParameters , Tags = None ):
58
-
59
- self .train_container = _SageMakerContainer (ResourceConfig ['InstanceType' ], ResourceConfig ['InstanceCount' ],
60
- AlgorithmSpecification ['TrainingImage' ], self .sagemaker_session )
61
-
62
- for channel in InputDataConfig :
63
-
64
- if channel ['DataSource' ] and 'S3DataSource' in channel ['DataSource' ]:
65
- data_distribution = channel ['DataSource' ]['S3DataSource' ]['S3DataDistributionType' ]
66
- elif channel ['DataSource' ] and 'FileDataSource' in channel ['DataSource' ]:
67
- data_distribution = channel ['DataSource' ]['FileDataSource' ]['FileDataDistributionType' ]
68
- else :
69
- raise ValueError ('Need channel[\' DataSource\' ] to have [\' S3DataSource\' ] or [\' FileDataSource\' ]' )
57
+ def create_training_job (self , TrainingJobName , AlgorithmSpecification , InputDataConfig , OutputDataConfig ,
58
+ ResourceConfig , HyperParameters , * args , ** kwargs ):
59
+ """
60
+ Create a training job in Local Mode
61
+ Args:
62
+ TrainingJobName (str): local training job name.
63
+ AlgorithmSpecification (dict): Identifies the training algorithm to use.
64
+ InputDataConfig (dict): Describes the training dataset and the location where it is stored.
65
+ OutputDataConfig (dict): Identifies the location where you want to save the results of model training.
66
+ ResourceConfig (dict): Identifies the resources to use for local model traininig.
67
+ HyperParameters (dict): Specify these algorithm-specific parameters to influence the quality of the final
68
+ model.
69
+ """
70
70
71
- if data_distribution != 'FullyReplicated' :
72
- raise RuntimeError ("DataDistribution: %s is not currently supported in Local Mode" %
73
- data_distribution )
71
+ container = _SageMakerContainer (ResourceConfig ['InstanceType' ], ResourceConfig ['InstanceCount' ],
72
+ AlgorithmSpecification ['TrainingImage' ], self .sagemaker_session )
73
+ train_job = _LocalTrainingJob (container )
74
+ train_job .start (InputDataConfig , HyperParameters )
74
75
75
- self . s3_model_artifacts = self . train_container . train ( InputDataConfig , HyperParameters )
76
+ LocalSagemakerClient . _training_jobs [ TrainingJobName ] = train_job
76
77
77
78
def describe_training_job (self , TrainingJobName ):
78
79
"""Describe a local training job.
@@ -83,63 +84,55 @@ def describe_training_job(self, TrainingJobName):
83
84
Returns: (dict) DescribeTrainingJob Response.
84
85
85
86
"""
86
- response = {'ResourceConfig' : {'InstanceCount' : self .train_container .instance_count },
87
- 'TrainingJobStatus' : 'Completed' ,
88
- 'TrainingStartTime' : datetime .datetime .now (),
89
- 'TrainingEndTime' : datetime .datetime .now (),
90
- 'ModelArtifacts' : {'S3ModelArtifacts' : self .s3_model_artifacts }
91
- }
92
- return response
93
-
94
- def create_model (self , ModelName , PrimaryContainer , ExecutionRoleArn ):
95
- self .model_name = ModelName
96
- self .primary_container = PrimaryContainer
97
- self .role_arn = ExecutionRoleArn
87
+ if TrainingJobName not in LocalSagemakerClient ._training_jobs :
88
+ error_response = {'Error' : {'Code' : 'ValidationException' , 'Message' : 'Could not find local training job' }}
89
+ raise ClientError (error_response , 'describe_training_job' )
90
+ else :
91
+ return LocalSagemakerClient ._training_jobs [TrainingJobName ].describe ()
92
+
93
+ def create_model (self , ModelName , PrimaryContainer , * args , ** kwargs ):
94
+ """Create a Local Model Object
95
+
96
+ Args:
97
+ ModelName (str): the Model Name
98
+ PrimaryContainer (dict): a SageMaker primary container definition
99
+ """
100
+ LocalSagemakerClient ._models [ModelName ] = _LocalModel (ModelName , PrimaryContainer )
101
+
102
+ def describe_model (self , ModelName ):
103
+ if ModelName not in LocalSagemakerClient ._models :
104
+ error_response = {'Error' : {'Code' : 'ValidationException' , 'Message' : 'Could not find local model' }}
105
+ raise ClientError (error_response , 'describe_model' )
106
+ else :
107
+ return LocalSagemakerClient ._models [ModelName ].describe ()
98
108
99
109
def describe_endpoint_config (self , EndpointConfigName ):
100
- if self . created_endpoint :
101
- return True
110
+ if EndpointConfigName in LocalSagemakerClient . _endpoint_configs :
111
+ return LocalSagemakerClient . _endpoint_configs [ EndpointConfigName ]. describe ()
102
112
else :
103
- error_response = {'Error' : {'Code' : 'ValidationException' , 'Message' : 'Could not find endpoint' }}
113
+ error_response = {'Error' : {
114
+ 'Code' : 'ValidationException' , 'Message' : 'Could not find local endpoint config' }}
104
115
raise ClientError (error_response , 'describe_endpoint_config' )
105
116
106
117
def create_endpoint_config (self , EndpointConfigName , ProductionVariants ):
107
- self .variants = ProductionVariants
118
+ LocalSagemakerClient ._endpoint_configs [EndpointConfigName ] = _LocalEndpointConfig (
119
+ EndpointConfigName , ProductionVariants )
108
120
109
121
def describe_endpoint (self , EndpointName ):
110
- return {'EndpointStatus' : 'InService' }
122
+ if EndpointName not in LocalSagemakerClient ._endpoints :
123
+ error_response = {'Error' : {'Code' : 'ValidationException' , 'Message' : 'Could not find local endpoint' }}
124
+ raise ClientError (error_response , 'describe_endpoint' )
125
+ else :
126
+ return LocalSagemakerClient ._endpoints [EndpointName ].describe ()
111
127
112
128
def create_endpoint (self , EndpointName , EndpointConfigName ):
113
- instance_type = self .variants [0 ]['InstanceType' ]
114
- instance_count = self .variants [0 ]['InitialInstanceCount' ]
115
- self .serve_container = _SageMakerContainer (instance_type , instance_count ,
116
- self .primary_container ['Image' ], self .sagemaker_session )
117
- self .serve_container .serve (self .primary_container )
118
- self .created_endpoint = True
119
-
120
- i = 0
121
- http = urllib3 .PoolManager ()
122
- serving_port = get_config_value ('local.serving_port' , self .sagemaker_session .config ) or 8080
123
- endpoint_url = "http://localhost:%s/ping" % serving_port
124
- while True :
125
- i += 1
126
- if i >= 10 :
127
- raise RuntimeError ("Giving up, endpoint: %s didn't launch correctly" % EndpointName )
128
-
129
- logger .info ("Checking if endpoint is up, attempt: %s" % i )
130
- try :
131
- r = http .request ('GET' , endpoint_url )
132
- if r .status != 200 :
133
- logger .info ("Container still not up, got: %s" % r .status )
134
- else :
135
- return
136
- except urllib3 .exceptions .RequestError :
137
- logger .info ("Container still not up" )
138
-
139
- time .sleep (1 )
129
+ endpoint = _LocalEndpoint (EndpointName , EndpointConfigName )
130
+ LocalSagemakerClient ._endpoints [EndpointName ] = endpoint
131
+ endpoint .serve (self .sagemaker_session )
140
132
141
133
def delete_endpoint (self , EndpointName ):
142
- self .serve_container .stop_serving ()
134
+ if EndpointName in LocalSagemakerClient ._endpoints :
135
+ LocalSagemakerClient ._endpoints [EndpointName ].stop ()
143
136
144
137
145
138
class LocalSagemakerRuntimeClient (object ):
0 commit comments