
Commit b096cd1

Add support for additional libraries in the Estimator (aws#498)
* Support for dependencies in the estimators
Parent: 53a43f6

22 files changed: +485 -149 lines

CHANGELOG.rst

Lines changed: 3 additions & 2 deletions
@@ -2,9 +2,10 @@
 CHANGELOG
 =========

-1.15.1.dev
-==========
+1.15.1dev
+=========

+* feature: Estimators: dependencies attribute allows export of additional libraries into the container
 * feature: Add APIs to export Airflow transform and deploy config

 1.15.0
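
At a glance, the new attribute lets an estimator ship extra library folders alongside its training code. A minimal sketch, assuming an MXNet training script at src/train.py and a local library folder my/libs/common; the role ARN, instance settings, and framework version below are placeholders, not part of this commit:

    from sagemaker.mxnet import MXNet

    # Placeholder role/instance values for illustration only.
    estimator = MXNet(entry_point='train.py',
                      source_dir='src',                 # packaged exactly as before
                      dependencies=['my/libs/common'],  # new: extra library folder shipped with the code
                      role='arn:aws:iam::111111111111:role/SageMakerRole',
                      train_instance_count=1,
                      train_instance_type='ml.m4.xlarge',
                      framework_version='1.2.1')

During fit(), the listed folders are tarred together with source_dir and extracted next to train.py inside the container.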

src/sagemaker/chainer/README.rst

Lines changed: 17 additions & 0 deletions
@@ -149,6 +149,23 @@ The following are optional arguments. When you create a ``Chainer`` object, you
   other training source code dependencies including the entry point
   file. Structure within this directory will be preserved when training
   on SageMaker.
+- ``dependencies (list[str])`` A list of paths to directories (absolute or relative) with
+  any additional libraries that will be exported to the container (default: []).
+  The library folders will be copied to SageMaker in the same folder where the entry point is copied.
+  If ``source_dir`` points to S3, no code will be uploaded and the S3 location will be used
+  instead. Example:
+
+  The following call
+  >>> Chainer(entry_point='train.py', dependencies=['my/libs/common', 'virtual-env'])
+  results in the following inside the container:
+
+  >>> $ ls
+
+  >>> opt/ml/code
+  >>> ├── train.py
+  >>> ├── common
+  >>> └── virtual-env
+
 - ``hyperparameters`` Hyperparameters that will be used for training.
   Will be made accessible as a dict[str, str] to the training code on
   SageMaker. For convenience, accepts other types besides str, but
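
Because 'my/libs/common' lands as 'common' next to train.py, the entry point can import it like any sibling package. A minimal sketch of such a train.py, assuming the folder is a Python package (contains an __init__.py) and a hypothetical helper module common/metrics.py; the framework containers run the entry point from /opt/ml/code, so sibling folders resolve as imports:

    # train.py -- executed from /opt/ml/code, so sibling folders are importable
    from common import metrics  # 'common' arrived via dependencies=['my/libs/common']


    def main():
        print('using helper module:', metrics.__name__)


    if __name__ == '__main__':
        main()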

src/sagemaker/chainer/estimator.py

Lines changed: 1 addition & 1 deletion
@@ -133,7 +133,7 @@ def create_model(self, model_server_workers=None, role=None, vpc_config_override
                             py_version=self.py_version, framework_version=self.framework_version,
                             model_server_workers=model_server_workers, image=self.image_name,
                             sagemaker_session=self.sagemaker_session,
-                            vpc_config=self.get_vpc_config(vpc_config_override))
+                            vpc_config=self.get_vpc_config(vpc_config_override), dependencies=self.dependencies)

     @classmethod
     def _prepare_init_params_from_job_description(cls, job_details, model_channel_name=None):

src/sagemaker/estimator.py

Lines changed: 20 additions & 2 deletions
@@ -637,7 +637,7 @@ class Framework(EstimatorBase):
     LAUNCH_PS_ENV_NAME = 'sagemaker_parameter_server_enabled'

     def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cloudwatch_metrics=False,
-                 container_log_level=logging.INFO, code_location=None, image_name=None, **kwargs):
+                 container_log_level=logging.INFO, code_location=None, image_name=None, dependencies=None, **kwargs):
         """Base class initializer. Subclasses which override ``__init__`` should invoke ``super()``

         Args:
@@ -646,6 +646,22 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cl
             source_dir (str): Path (absolute or relative) to a directory with any other training
                 source code dependencies aside from the entry point file (default: None). Structure within this
                 directory is preserved when training on Amazon SageMaker.
+            dependencies (list[str]): A list of paths to directories (absolute or relative) with
+                any additional libraries that will be exported to the container (default: []).
+                The library folders will be copied to SageMaker in the same folder where the entry point is copied.
+                Example:
+
+                The following call
+                >>> Estimator(entry_point='train.py', dependencies=['my/libs/common', 'virtual-env'])
+                results in the following inside the container:
+
+                >>> $ ls
+
+                >>> opt/ml/code
+                >>> ├── train.py
+                >>> ├── common
+                >>> └── virtual-env
+
             hyperparameters (dict): Hyperparameters that will be used for training (default: None).
                 The hyperparameters are made accessible as a dict[str, str] to the training code on SageMaker.
                 For convenience, this accepts other types for keys and values, but ``str()`` will be called
@@ -663,6 +679,7 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cl
         """
         super(Framework, self).__init__(**kwargs)
         self.source_dir = source_dir
+        self.dependencies = dependencies or []
         self.entry_point = entry_point
         if enable_cloudwatch_metrics:
             warnings.warn('enable_cloudwatch_metrics is now deprecated and will be removed in the future.',
@@ -729,7 +746,8 @@ def _stage_user_code_in_s3(self):
                                                   bucket=code_bucket,
                                                   s3_key_prefix=code_s3_prefix,
                                                   script=self.entry_point,
-                                                  directory=self.source_dir)
+                                                  directory=self.source_dir,
+                                                  dependencies=self.dependencies)

     def _model_source_dir(self):
         """Get the appropriate value to pass as source_dir to model constructor on deploying

src/sagemaker/fw_utils.py

Lines changed: 29 additions & 18 deletions
@@ -14,11 +14,15 @@

 import os
 import re
+import shutil
+import tempfile
 from collections import namedtuple
 from six.moves.urllib.parse import urlparse

 import sagemaker.utils

+_TAR_SOURCE_FILENAME = 'source.tar.gz'
+
 UploadedCode = namedtuple('UserCode', ['s3_prefix', 'script_name'])
 """sagemaker.fw_utils.UserCode: An object containing the S3 prefix and script name.

@@ -107,7 +111,7 @@ def validate_source_dir(script, directory):
     return True


-def tar_and_upload_dir(session, bucket, s3_key_prefix, script, directory):
+def tar_and_upload_dir(session, bucket, s3_key_prefix, script, directory, dependencies=None):
     """Pack and upload source files to S3 only if directory is empty or local.

     Note:
@@ -118,31 +122,38 @@ def tar_and_upload_dir(session, bucket, s3_key_prefix, script, directory):
         bucket (str): S3 bucket to which the compressed file is uploaded.
         s3_key_prefix (str): Prefix for the S3 key.
         script (str): Script filename.
-        directory (str): Directory containing the source file. If it starts with "s3://", no action is taken.
+        directory (str or None): Directory containing the source file. If it starts with
+            "s3://", no action is taken.
+        dependencies (List[str]): A list of paths to directories (absolute or relative)
+            containing additional libraries that will be copied into
+            /opt/ml/lib

     Returns:
         sagemaker.fw_utils.UserCode: An object with the S3 bucket and key (S3 prefix) and script name.
     """
-    if directory:
-        if directory.lower().startswith("s3://"):
-            return UploadedCode(s3_prefix=directory, script_name=os.path.basename(script))
-        else:
-            script_name = script
-            source_files = [os.path.join(directory, name) for name in os.listdir(directory)]
+    dependencies = dependencies or []
+    key = '%s/sourcedir.tar.gz' % s3_key_prefix
+
+    if directory and directory.lower().startswith('s3://'):
+        return UploadedCode(s3_prefix=directory, script_name=os.path.basename(script))
     else:
-        # If no directory is specified, the script parameter needs to be a valid relative path.
-        os.path.exists(script)
-        script_name = os.path.basename(script)
-        source_files = [script]
+        tmp = tempfile.mkdtemp()
+
+        try:
+            source_files = _list_files_to_compress(script, directory) + dependencies
+            tar_file = sagemaker.utils.create_tar_file(source_files, os.path.join(tmp, _TAR_SOURCE_FILENAME))
+
+            session.resource('s3').Object(bucket, key).upload_file(tar_file)
+        finally:
+            shutil.rmtree(tmp)

-        s3 = session.resource('s3')
-        key = '{}/{}'.format(s3_key_prefix, 'sourcedir.tar.gz')
+        script_name = script if directory else os.path.basename(script)
+        return UploadedCode(s3_prefix='s3://%s/%s' % (bucket, key), script_name=script_name)

-        tar_file = sagemaker.utils.create_tar_file(source_files)
-        s3.Object(bucket, key).upload_file(tar_file)
-        os.remove(tar_file)

-        return UploadedCode(s3_prefix='s3://{}/{}'.format(bucket, key), script_name=script_name)
+
+def _list_files_to_compress(script, directory):
+    basedir = directory if directory else os.path.dirname(script)
+    return [os.path.join(basedir, name) for name in os.listdir(basedir)]


 def framework_name_from_image(image_name):
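
To make the new packaging path concrete, here is a standalone sketch that mirrors what tar_and_upload_dir now does before the upload: gather the source files, append the dependency folders, and tar everything into a temporary archive with each entry stored under its basename (the layout the README examples above describe). It uses only the standard library and skips the S3 step, so pack_locally and its helper are local to this sketch, not part of fw_utils:

    import os
    import shutil
    import tarfile
    import tempfile


    def _list_files(script, directory):
        # Mirrors _list_files_to_compress: everything in source_dir, or the script's own folder.
        basedir = directory if directory else os.path.dirname(script)
        return [os.path.join(basedir, name) for name in os.listdir(basedir)]


    def pack_locally(script, directory=None, dependencies=None):
        dependencies = dependencies or []
        tmp = tempfile.mkdtemp()
        try:
            tar_path = os.path.join(tmp, 'source.tar.gz')
            with tarfile.open(tar_path, mode='w:gz') as tar:
                for path in _list_files(script, directory) + dependencies:
                    # Only the basename is kept, so 'my/libs/common' is stored as 'common'.
                    tar.add(path, arcname=os.path.basename(path))
            with tarfile.open(tar_path) as tar:
                return tar.getnames()
        finally:
            shutil.rmtree(tmp)


    # pack_locally('src/train.py', directory='src', dependencies=['my/libs/common'])
    # -> e.g. ['train.py', 'common', 'common/__init__.py', 'common/metrics.py']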

src/sagemaker/model.py

Lines changed: 36 additions & 16 deletions
@@ -16,10 +16,10 @@

 import sagemaker

-from sagemaker.local import LocalSession
-from sagemaker.fw_utils import tar_and_upload_dir, parse_s3_url, model_code_key_prefix
-from sagemaker.session import Session
-from sagemaker.utils import name_from_image, get_config_value
+from sagemaker import local
+from sagemaker import fw_utils
+from sagemaker import session
+from sagemaker import utils


 class Model(object):
@@ -96,12 +96,12 @@ def deploy(self, initial_instance_count, instance_type, endpoint_name=None, tags
         """
         if not self.sagemaker_session:
             if instance_type in ('local', 'local_gpu'):
-                self.sagemaker_session = LocalSession()
+                self.sagemaker_session = local.LocalSession()
             else:
-                self.sagemaker_session = Session()
+                self.sagemaker_session = session.Session()

         container_def = self.prepare_container_def(instance_type)
-        self.name = self.name or name_from_image(container_def['Image'])
+        self.name = self.name or utils.name_from_image(container_def['Image'])
         self.sagemaker_session.create_model(self.name, self.role, container_def, vpc_config=self.vpc_config)
         production_variant = sagemaker.production_variant(self.name, instance_type, initial_instance_count)
         self.endpoint_name = endpoint_name or self.name
@@ -127,7 +127,7 @@ class FrameworkModel(Model):

     def __init__(self, model_data, image, role, entry_point, source_dir=None, predictor_cls=None, env=None, name=None,
                  enable_cloudwatch_metrics=False, container_log_level=logging.INFO, code_location=None,
-                 sagemaker_session=None, **kwargs):
+                 sagemaker_session=None, dependencies=None, **kwargs):
         """Initialize a ``FrameworkModel``.

         Args:
@@ -140,6 +140,23 @@ def __init__(self, model_data, image, role, entry_point, source_dir=None, predic
                 source code dependencies aside from the entry point file (default: None). Structure within this
                 directory will be preserved when training on SageMaker.
                 If the directory points to S3, no code will be uploaded and the S3 location will be used instead.
+            dependencies (list[str]): A list of paths to directories (absolute or relative) with
+                any additional libraries that will be exported to the container (default: []).
+                The library folders will be copied to SageMaker in the same folder where the entry point is copied.
+                If ``source_dir`` points to S3, no code will be uploaded and the S3 location will be used
+                instead. Example:
+
+                The following call
+                >>> Estimator(entry_point='train.py', dependencies=['my/libs/common', 'virtual-env'])
+                results in the following inside the container:
+
+                >>> $ ls
+
+                >>> opt/ml/code
+                >>> ├── train.py
+                >>> ├── common
+                >>> └── virtual-env
+
             predictor_cls (callable[string, sagemaker.session.Session]): A function to call to create
                 a predictor (default: None). If not None, ``deploy`` will return the result of invoking
                 this function on the created endpoint name.
@@ -160,10 +177,11 @@ def __init__(self, model_data, image, role, entry_point, source_dir=None, predic
                                              sagemaker_session=sagemaker_session, **kwargs)
         self.entry_point = entry_point
         self.source_dir = source_dir
+        self.dependencies = dependencies or []
         self.enable_cloudwatch_metrics = enable_cloudwatch_metrics
         self.container_log_level = container_log_level
         if code_location:
-            self.bucket, self.key_prefix = parse_s3_url(code_location)
+            self.bucket, self.key_prefix = fw_utils.parse_s3_url(code_location)
         else:
            self.bucket, self.key_prefix = None, None
         self.uploaded_code = None
@@ -179,22 +197,24 @@ def prepare_container_def(self, instance_type):  # pylint disable=unused-argumen
         Returns:
             dict[str, str]: A container definition object usable with the CreateModel API.
         """
-        deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, self.image)
+        deploy_key_prefix = fw_utils.model_code_key_prefix(self.key_prefix, self.name, self.image)
         self._upload_code(deploy_key_prefix)
         deploy_env = dict(self.env)
         deploy_env.update(self._framework_env_vars())
         return sagemaker.container_def(self.image, self.model_data, deploy_env)

     def _upload_code(self, key_prefix):
-        local_code = get_config_value('local.local_code', self.sagemaker_session.config)
+        local_code = utils.get_config_value('local.local_code', self.sagemaker_session.config)
         if self.sagemaker_session.local_mode and local_code:
             self.uploaded_code = None
         else:
-            self.uploaded_code = tar_and_upload_dir(session=self.sagemaker_session.boto_session,
-                                                    bucket=self.bucket or self.sagemaker_session.default_bucket(),
-                                                    s3_key_prefix=key_prefix,
-                                                    script=self.entry_point,
-                                                    directory=self.source_dir)
+            bucket = self.bucket or self.sagemaker_session.default_bucket()
+            self.uploaded_code = fw_utils.tar_and_upload_dir(session=self.sagemaker_session.boto_session,
+                                                             bucket=bucket,
+                                                             s3_key_prefix=key_prefix,
+                                                             script=self.entry_point,
+                                                             directory=self.source_dir,
+                                                             dependencies=self.dependencies)

     def _framework_env_vars(self):
         if self.uploaded_code:
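
The same list now survives into hosting: the framework estimators' create_model() calls (see the chainer, mxnet, and pytorch estimator diffs) pass self.dependencies to the FrameworkModel, and _upload_code() repacks those folders next to the inference entry point. A minimal sketch with placeholder S3 and role values:

    from sagemaker.mxnet import MXNetModel

    # Placeholder model artifact and role; illustration only.
    model = MXNetModel(model_data='s3://my-bucket/model.tar.gz',
                       role='arn:aws:iam::111111111111:role/SageMakerRole',
                       entry_point='inference.py',
                       dependencies=['my/libs/common'])
    print(model.dependencies)   # ['my/libs/common'] -- repacked beside inference.py on deploy()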

src/sagemaker/mxnet/README.rst

Lines changed: 17 additions & 0 deletions
@@ -271,6 +271,23 @@ The following are optional arguments. When you create an ``MXNet`` object, you c
   other training source code dependencies including the entry point
   file. Structure within this directory will be preserved when training
   on SageMaker.
+- ``dependencies (list[str])`` A list of paths to directories (absolute or relative) with
+  any additional libraries that will be exported to the container (default: []).
+  The library folders will be copied to SageMaker in the same folder where the entry point is copied.
+  If ``source_dir`` points to S3, no code will be uploaded and the S3 location will be used
+  instead. Example:
+
+  The following call
+  >>> MXNet(entry_point='train.py', dependencies=['my/libs/common', 'virtual-env'])
+  results in the following inside the container:
+
+  >>> $ ls
+
+  >>> opt/ml/code
+  >>> ├── train.py
+  >>> ├── common
+  >>> └── virtual-env
+
 - ``hyperparameters`` Hyperparameters that will be used for training.
   Will be made accessible as a dict[str, str] to the training code on
   SageMaker. For convenience, accepts other types besides str, but

src/sagemaker/mxnet/estimator.py

Lines changed: 1 addition & 1 deletion
@@ -115,7 +115,7 @@ def create_model(self, model_server_workers=None, role=None, vpc_config_override
                          container_log_level=self.container_log_level, code_location=self.code_location,
                          py_version=self.py_version, framework_version=self.framework_version, image=self.image_name,
                          model_server_workers=model_server_workers, sagemaker_session=self.sagemaker_session,
-                         vpc_config=self.get_vpc_config(vpc_config_override))
+                         vpc_config=self.get_vpc_config(vpc_config_override), dependencies=self.dependencies)

     @classmethod
     def _prepare_init_params_from_job_description(cls, job_details, model_channel_name=None):

src/sagemaker/pytorch/README.rst

Lines changed: 17 additions & 0 deletions
@@ -175,6 +175,23 @@ The following are optional arguments. When you create a ``PyTorch`` object, you
   other training source code dependencies including the entry point
   file. Structure within this directory will be preserved when training
   on SageMaker.
+- ``dependencies (list[str])`` A list of paths to directories (absolute or relative) with
+  any additional libraries that will be exported to the container (default: []).
+  The library folders will be copied to SageMaker in the same folder where the entry point is copied.
+  If ``source_dir`` points to S3, no code will be uploaded and the S3 location will be used
+  instead. Example:
+
+  The following call
+  >>> PyTorch(entry_point='train.py', dependencies=['my/libs/common', 'virtual-env'])
+  results in the following inside the container:
+
+  >>> $ ls
+
+  >>> opt/ml/code
+  >>> ├── train.py
+  >>> ├── common
+  >>> └── virtual-env
+
 - ``hyperparameters`` Hyperparameters that will be used for training.
   Will be made accessible as a dict[str, str] to the training code on
   SageMaker. For convenience, accepts other types besides strings, but

src/sagemaker/pytorch/estimator.py

Lines changed: 1 addition & 1 deletion
@@ -96,7 +96,7 @@ def create_model(self, model_server_workers=None, role=None, vpc_config_override
                            container_log_level=self.container_log_level, code_location=self.code_location,
                            py_version=self.py_version, framework_version=self.framework_version, image=self.image_name,
                            model_server_workers=model_server_workers, sagemaker_session=self.sagemaker_session,
-                           vpc_config=self.get_vpc_config(vpc_config_override))
+                           vpc_config=self.get_vpc_config(vpc_config_override), dependencies=self.dependencies)

     @classmethod
     def _prepare_init_params_from_job_description(cls, job_details, model_channel_name=None):

src/sagemaker/tensorflow/README.rst

Lines changed: 17 additions & 0 deletions
@@ -409,6 +409,23 @@ you can specify these as keyword arguments.
   other training source code dependencies including the entry point
   file. Structure within this directory will be preserved when training
   on SageMaker.
+- ``dependencies (list[str])`` A list of paths to directories (absolute or relative) with
+  any additional libraries that will be exported to the container (default: []).
+  The library folders will be copied to SageMaker in the same folder where the entry point is copied.
+  If ``source_dir`` points to S3, no code will be uploaded and the S3 location will be used
+  instead. Example:
+
+  The following call
+  >>> TensorFlow(entry_point='train.py', dependencies=['my/libs/common', 'virtual-env'])
+  results in the following inside the container:
+
+  >>> $ ls
+
+  >>> opt/ml/code
+  >>> ├── train.py
+  >>> ├── common
+  >>> └── virtual-env
+
 - ``requirements_file (str)`` Path to a ``requirements.txt`` file. The path should
   be within and relative to ``source_dir``. This is a file containing a list of items to be
   installed using pip install. Details on the format can be found in the
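
The new option complements ``requirements_file`` rather than replacing it: ``dependencies`` copies library folders you already have on disk next to the entry point, while ``requirements_file`` lists packages for the container to pip install. A minimal sketch combining the two, with placeholder paths, role, and instance settings:

    from sagemaker.tensorflow import TensorFlow

    estimator = TensorFlow(entry_point='train.py',
                           source_dir='src',                      # contains train.py and requirements.txt
                           requirements_file='requirements.txt',  # relative to source_dir, installed via pip
                           dependencies=['my/libs/common'],       # copied as-is next to train.py
                           role='arn:aws:iam::111111111111:role/SageMakerRole',
                           training_steps=1000,
                           evaluation_steps=100,
                           train_instance_count=1,
                           train_instance_type='ml.p2.xlarge')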
