Skip to content

Commit fda0e0e

Browse files
authored
Allow Local Mode to work with a local training script. (#178)
This change works when ~/.sagemaker/config.yaml has local: local_code: True It depends on the container image supporting a local training script instead of an s3 location.
1 parent c1f1ab9 commit fda0e0e

29 files changed

+400
-134
lines changed

src/sagemaker/estimator.py

Lines changed: 53 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,13 @@
1414

1515
import json
1616
import logging
17+
import os
1718
from abc import ABCMeta
1819
from abc import abstractmethod
1920
from six import with_metaclass, string_types
2021

21-
from sagemaker.fw_utils import tar_and_upload_dir
22-
from sagemaker.fw_utils import parse_s3_url
23-
from sagemaker.fw_utils import UploadedCode
24-
from sagemaker.local.local_session import LocalSession, file_input
22+
from sagemaker.fw_utils import tar_and_upload_dir, parse_s3_url, UploadedCode, validate_source_dir
23+
from sagemaker.local import LocalSession, file_input
2524

2625
from sagemaker.model import Model
2726
from sagemaker.model import (SCRIPT_PARAM_NAME, DIR_PARAM_NAME, CLOUDWATCH_METRICS_PARAM_NAME,
@@ -30,7 +29,7 @@
3029
from sagemaker.predictor import RealTimePredictor
3130
from sagemaker.session import Session
3231
from sagemaker.session import s3_input
33-
from sagemaker.utils import base_name_from_image, name_from_base
32+
from sagemaker.utils import base_name_from_image, name_from_base, get_config_value
3433

3534

3635
class EstimatorBase(with_metaclass(ABCMeta, object)):
@@ -83,13 +82,10 @@ def __init__(self, role, train_instance_count, train_instance_type,
8382
self.input_mode = input_mode
8483

8584
if self.train_instance_type in ('local', 'local_gpu'):
86-
self.local_mode = True
8785
if self.train_instance_type == 'local_gpu' and self.train_instance_count > 1:
8886
raise RuntimeError("Distributed Training in Local GPU is not supported")
89-
9087
self.sagemaker_session = sagemaker_session or LocalSession()
9188
else:
92-
self.local_mode = False
9389
self.sagemaker_session = sagemaker_session or Session()
9490

9591
self.base_job_name = base_job_name
@@ -158,9 +154,14 @@ def fit(self, inputs, wait=True, logs=True, job_name=None):
158154
base_name = self.base_job_name or base_name_from_image(self.train_image())
159155
self._current_job_name = name_from_base(base_name)
160156

161-
# if output_path was specified we use it otherwise initialize here
157+
# if output_path was specified we use it otherwise initialize here.
158+
# For Local Mode with local_code=True we don't need an explicit output_path
162159
if self.output_path is None:
163-
self.output_path = 's3://{}/'.format(self.sagemaker_session.default_bucket())
160+
local_code = get_config_value('local.local_code', self.sagemaker_session.config)
161+
if self.sagemaker_session.local_mode and local_code:
162+
self.output_path = ''
163+
else:
164+
self.output_path = 's3://{}/'.format(self.sagemaker_session.default_bucket())
164165

165166
self.latest_training_job = _TrainingJob.start_new(self, inputs)
166167
if wait:
@@ -323,7 +324,7 @@ def start_new(cls, estimator, inputs):
323324
sagemaker.estimator.Framework: Constructed object that captures all information about the started job.
324325
"""
325326

326-
local_mode = estimator.local_mode
327+
local_mode = estimator.sagemaker_session.local_mode
327328

328329
# Allow file:// input only in local mode
329330
if isinstance(inputs, str) and inputs.startswith('file://'):
@@ -604,27 +605,54 @@ def fit(self, inputs, wait=True, logs=True, job_name=None):
604605
base_name = self.base_job_name or base_name_from_image(self.train_image())
605606
self._current_job_name = name_from_base(base_name)
606607

608+
# validate source dir will raise a ValueError if there is something wrong with the
609+
# source directory. We are intentionally not handling it because this is a critical error.
610+
if self.source_dir and not self.source_dir.lower().startswith('s3://'):
611+
validate_source_dir(self.entry_point, self.source_dir)
612+
613+
# if we are in local mode with local_code=True. We want the container to just
614+
# mount the source dir instead of uploading to S3.
615+
local_code = get_config_value('local.local_code', self.sagemaker_session.config)
616+
if self.sagemaker_session.local_mode and local_code:
617+
# if there is no source dir, use the directory containing the entry point.
618+
if self.source_dir is None:
619+
self.source_dir = os.path.dirname(self.entry_point)
620+
self.entry_point = os.path.basename(self.entry_point)
621+
622+
code_dir = 'file://' + self.source_dir
623+
script = self.entry_point
624+
else:
625+
self.uploaded_code = self._stage_user_code_in_s3()
626+
code_dir = self.uploaded_code.s3_prefix
627+
script = self.uploaded_code.script_name
628+
629+
# Modify hyperparameters in-place to point to the right code directory and script URIs
630+
self._hyperparameters[DIR_PARAM_NAME] = code_dir
631+
self._hyperparameters[SCRIPT_PARAM_NAME] = script
632+
self._hyperparameters[CLOUDWATCH_METRICS_PARAM_NAME] = self.enable_cloudwatch_metrics
633+
self._hyperparameters[CONTAINER_LOG_LEVEL_PARAM_NAME] = self.container_log_level
634+
self._hyperparameters[JOB_NAME_PARAM_NAME] = self._current_job_name
635+
self._hyperparameters[SAGEMAKER_REGION_PARAM_NAME] = self.sagemaker_session.boto_region_name
636+
super(Framework, self).fit(inputs, wait, logs, self._current_job_name)
637+
638+
def _stage_user_code_in_s3(self):
639+
""" Upload the user training script to s3 and return the location.
640+
641+
Returns: s3 uri
642+
643+
"""
607644
if self.code_location is None:
608645
code_bucket = self.sagemaker_session.default_bucket()
609646
code_s3_prefix = '{}/source'.format(self._current_job_name)
610647
else:
611648
code_bucket, key_prefix = parse_s3_url(self.code_location)
612649
code_s3_prefix = '{}/{}/source'.format(key_prefix, self._current_job_name)
613650

614-
self.uploaded_code = tar_and_upload_dir(session=self.sagemaker_session.boto_session,
615-
bucket=code_bucket,
616-
s3_key_prefix=code_s3_prefix,
617-
script=self.entry_point,
618-
directory=self.source_dir)
619-
620-
# Modify hyperparameters in-place to add the URLs to the uploaded code.
621-
self._hyperparameters[DIR_PARAM_NAME] = self.uploaded_code.s3_prefix
622-
self._hyperparameters[SCRIPT_PARAM_NAME] = self.uploaded_code.script_name
623-
self._hyperparameters[CLOUDWATCH_METRICS_PARAM_NAME] = self.enable_cloudwatch_metrics
624-
self._hyperparameters[CONTAINER_LOG_LEVEL_PARAM_NAME] = self.container_log_level
625-
self._hyperparameters[JOB_NAME_PARAM_NAME] = self._current_job_name
626-
self._hyperparameters[SAGEMAKER_REGION_PARAM_NAME] = self.sagemaker_session.boto_session.region_name
627-
super(Framework, self).fit(inputs, wait, logs, self._current_job_name)
651+
return tar_and_upload_dir(session=self.sagemaker_session.boto_session,
652+
bucket=code_bucket,
653+
s3_key_prefix=code_s3_prefix,
654+
script=self.entry_point,
655+
directory=self.source_dir)
628656

629657
def hyperparameters(self):
630658
"""Return the hyperparameters as a dictionary to use for training.

src/sagemaker/fw_utils.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,23 @@ def create_image_uri(region, framework, instance_type, framework_version, py_ver
6868
.format(account, region, framework, tag)
6969

7070

71+
def validate_source_dir(script, directory):
72+
"""Validate that the source directory exists and it contains the user script
73+
74+
Args:
75+
script (str): Script filename.
76+
directory (str): Directory containing the source file.
77+
78+
Raises:
79+
ValueError: If ``directory`` does not exist, is not a directory, or does not contain ``script``.
80+
"""
81+
if directory:
82+
if not os.path.isfile(os.path.join(directory, script)):
83+
raise ValueError('No file named "{}" was found in directory "{}".'.format(script, directory))
84+
85+
return True
86+
87+
7188
def tar_and_upload_dir(session, bucket, s3_key_prefix, script, directory):
7289
"""Pack and upload source files to S3 only if directory is empty or local.
7390
@@ -83,21 +100,13 @@ def tar_and_upload_dir(session, bucket, s3_key_prefix, script, directory):
83100
84101
Returns:
85102
sagemaker.fw_utils.UserCode: An object with the S3 bucket and key (S3 prefix) and script name.
86-
87-
Raises:
88-
ValueError: If ``directory`` does not exist, is not a directory, or does not contain ``script``.
89103
"""
90104
if directory:
91105
if directory.lower().startswith("s3://"):
92106
return UploadedCode(s3_prefix=directory, script_name=os.path.basename(script))
93-
if not os.path.exists(directory):
94-
raise ValueError('"{}" does not exist.'.format(directory))
95-
if not os.path.isdir(directory):
96-
raise ValueError('"{}" is not a directory.'.format(directory))
97-
if script not in os.listdir(directory):
98-
raise ValueError('No file named "{}" was found in directory "{}".'.format(script, directory))
99-
script_name = script
100-
source_files = [os.path.join(directory, name) for name in os.listdir(directory)]
107+
else:
108+
script_name = script
109+
source_files = [os.path.join(directory, name) for name in os.listdir(directory)]
101110
else:
102111
# If no directory is specified, the script parameter needs to be a valid relative path.
103112
os.path.exists(script)

src/sagemaker/local/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,7 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13+
from .local_session import (file_input, LocalSession, LocalSagemakerRuntimeClient,
14+
LocalSagemakerClient)
15+
16+
__all__ = [file_input, LocalSession, LocalSagemakerClient, LocalSagemakerRuntimeClient]

src/sagemaker/local/image.py

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@
2929

3030
import yaml
3131

32+
import sagemaker
33+
from sagemaker.utils import get_config_value
34+
3235
CONTAINER_PREFIX = "algo"
3336
DOCKER_COMPOSE_FILENAME = 'docker-compose.yaml'
3437

@@ -68,11 +71,6 @@ def __init__(self, instance_type, instance_count, image, sagemaker_session=None)
6871
self.hosts = ['{}-{}-{}'.format(CONTAINER_PREFIX, i, suffix) for i in range(1, self.instance_count + 1)]
6972
self.container_root = None
7073
self.container = None
71-
# set the local config. This is optional and will use reasonable defaults
72-
# if not present.
73-
self.local_config = None
74-
if self.sagemaker_session.config and 'local' in self.sagemaker_session.config:
75-
self.local_config = self.sagemaker_session.config['local']
7674

7775
def train(self, input_data_config, hyperparameters):
7876
"""Run a training job locally using docker-compose.
@@ -85,6 +83,10 @@ def train(self, input_data_config, hyperparameters):
8583
"""
8684
self.container_root = self._create_tmp_folder()
8785
os.mkdir(os.path.join(self.container_root, 'output'))
86+
# A shared directory for all the containers. It is only mounted if the training script is
87+
# Local.
88+
shared_dir = os.path.join(self.container_root, 'shared')
89+
os.mkdir(shared_dir)
8890

8991
data_dir = self._create_tmp_folder()
9092
volumes = []
@@ -116,6 +118,14 @@ def train(self, input_data_config, hyperparameters):
116118
else:
117119
raise ValueError('Unknown URI scheme {}'.format(parsed_uri.scheme))
118120

121+
# If the training script directory is a local directory, mount it to the container.
122+
training_dir = json.loads(hyperparameters[sagemaker.estimator.DIR_PARAM_NAME])
123+
parsed_uri = urlparse(training_dir)
124+
if parsed_uri.scheme == 'file':
125+
volumes.append(_Volume(parsed_uri.path, '/opt/ml/code'))
126+
# Also mount a directory that all the containers can access.
127+
volumes.append(_Volume(shared_dir, '/opt/ml/shared'))
128+
119129
# Create the configuration files for each container that we will create
120130
# Each container will map the additional local volumes (if any).
121131
for host in self.hosts:
@@ -135,6 +145,7 @@ def train(self, input_data_config, hyperparameters):
135145
# lots of data downloaded from S3. This doesn't delete any local
136146
# data that was just mounted to the container.
137147
_delete_tree(data_dir)
148+
_delete_tree(shared_dir)
138149
# Also free the container config files.
139150
for host in self.hosts:
140151
container_config_path = os.path.join(self.container_root, host)
@@ -171,7 +182,16 @@ def serve(self, primary_container):
171182

172183
_ecr_login_if_needed(self.sagemaker_session.boto_session, self.image)
173184

174-
self._generate_compose_file('serve', additional_env_vars=env_vars)
185+
# If the user script was passed as a file:// mount it to the container.
186+
script_dir = primary_container['Environment'][sagemaker.estimator.DIR_PARAM_NAME.upper()]
187+
parsed_uri = urlparse(script_dir)
188+
volumes = []
189+
if parsed_uri.scheme == 'file':
190+
volumes.append(_Volume(parsed_uri.path, '/opt/ml/code'))
191+
192+
self._generate_compose_file('serve',
193+
additional_env_vars=env_vars,
194+
additional_volumes=volumes)
175195
compose_command = self._compose()
176196
self.container = _HostingContainer(compose_command)
177197
self.container.up()
@@ -366,8 +386,9 @@ def _create_docker_host(self, host, environment, optml_subdirs, command, volumes
366386
}
367387
}
368388

369-
serving_port = 8080 if self.local_config is None else self.local_config.get('serving_port', 8080)
370389
if command == 'serve':
390+
serving_port = get_config_value('local.serving_port',
391+
self.sagemaker_session.config) or 8080
371392
host_config.update({
372393
'ports': [
373394
'%s:8080' % serving_port
@@ -377,9 +398,9 @@ def _create_docker_host(self, host, environment, optml_subdirs, command, volumes
377398
return host_config
378399

379400
def _create_tmp_folder(self):
380-
root_dir = None
381-
if self.local_config and 'container_root' in self.local_config:
382-
root_dir = os.path.abspath(self.local_config['container_root'])
401+
root_dir = get_config_value('local.container_root', self.sagemaker_session.config)
402+
if root_dir:
403+
root_dir = os.path.abspath(root_dir)
383404

384405
dir = tempfile.mkdtemp(dir=root_dir)
385406

@@ -565,6 +586,10 @@ def _ecr_login_if_needed(boto_session, image):
565586
if _check_output('docker images -q %s' % image).strip():
566587
return
567588

589+
if not boto_session:
590+
raise RuntimeError('A boto session is required to login to ECR.'
591+
'Please pull the image: %s manually.' % image)
592+
568593
ecr = boto_session.client('ecr')
569594
auth = ecr.get_authorization_token(registryIds=[image.split('.')[0]])
570595
authorization_data = auth['authorizationData'][0]

src/sagemaker/local/local_session.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,13 @@
1515
import platform
1616
import time
1717

18+
import boto3
1819
import urllib3
1920
from botocore.exceptions import ClientError
2021

2122
from sagemaker.local.image import _SageMakerContainer
2223
from sagemaker.session import Session
24+
from sagemaker.utils import get_config_value
2325

2426
logger = logging.getLogger(__name__)
2527
logger.setLevel(logging.WARNING)
@@ -115,9 +117,7 @@ def create_endpoint(self, EndpointName, EndpointConfigName):
115117

116118
i = 0
117119
http = urllib3.PoolManager()
118-
serving_port = 8080
119-
if self.sagemaker_session.config and 'local' in self.sagemaker_session.config:
120-
serving_port = self.sagemaker_session.config['local'].get('serving_port', 8080)
120+
serving_port = get_config_value('local.serving_port', self.sagemaker_session.config) or 8080
121121
endpoint_url = "http://localhost:%s/ping" % serving_port
122122
while True:
123123
i += 1
@@ -153,8 +153,8 @@ def __init__(self, config=None):
153153
"""
154154
self.http = urllib3.PoolManager()
155155
self.serving_port = 8080
156-
if config and 'local' in config:
157-
self.serving_port = config['local'].get('serving_port', 8080)
156+
self.config = config
157+
self.serving_port = get_config_value('local.serving_port', config) or 8080
158158

159159
def invoke_endpoint(self, Body, EndpointName, ContentType, Accept):
160160
url = "http://localhost:%s/invocations" % self.serving_port
@@ -171,8 +171,19 @@ def __init__(self, boto_session=None):
171171

172172
if platform.system() == 'Windows':
173173
logger.warning("Windows Support for Local Mode is Experimental")
174+
175+
def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client):
176+
"""Initialize this Local SageMaker Session."""
177+
178+
self.boto_session = boto_session or boto3.Session()
179+
self._region_name = self.boto_session.region_name
180+
181+
if self._region_name is None:
182+
raise ValueError('Must setup local AWS configuration with a region supported by SageMaker.')
183+
174184
self.sagemaker_client = LocalSagemakerClient(self)
175185
self.sagemaker_runtime_client = LocalSagemakerRuntimeClient(self.config)
186+
self.local_mode = True
176187

177188
def logs_for_job(self, job_name, wait=False, poll=5):
178189
# override logs_for_job() as it doesn't need to perform any action

0 commit comments

Comments
 (0)