Skip to content

Commit aa62d3d

Browse files
committed
Support for lib dirs in the estimators
1 parent e37ac12 commit aa62d3d

File tree

16 files changed

+392
-51
lines changed

16 files changed

+392
-51
lines changed

CHANGELOG.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,11 @@
22
CHANGELOG
33
=========
44

5+
1.15.1dev
6+
=========
7+
8+
* feature: Estimators: lib_dirs attribute allows export of additional libraries into the container
9+
510
1.15.0
611
======
712

src/sagemaker/chainer/README.rst

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,23 @@ The following are optional arguments. When you create a ``Chainer`` object, you
149149
other training source code dependencies including the entry point
150150
file. Structure within this directory will be preserved when training
151151
on SageMaker.
152+
- ``lib_dirs (list[str])`` A list of paths to directories (absolute or relative) with
153+
any additional libraries that will be exported to the container (default: []).
154+
The library folders will be copied to SageMaker in the same folder where the entrypoint is copied.
155+
If the ```source_dir``` points to S3, code will be uploaded and the S3 location will be used
156+
instead. Example:
157+
158+
The following call
159+
>>> Estimator(entry_point='train.py', lib_dirs=['my/libs/common', 'virtual-env'])
160+
results in the following inside the container:
161+
162+
>>> $ ls
163+
164+
>>> opt/ml/code
165+
>>> ├── train.py
166+
>>> ├── common
167+
>>> └── virtual-env
168+
152169
- ``hyperparameters`` Hyperparameters that will be used for training.
153170
Will be made accessible as a dict[str, str] to the training code on
154171
SageMaker. For convenience, accepts other types besides str, but

src/sagemaker/chainer/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ def create_model(self, model_server_workers=None, role=None, vpc_config_override
133133
py_version=self.py_version, framework_version=self.framework_version,
134134
model_server_workers=model_server_workers, image=self.image_name,
135135
sagemaker_session=self.sagemaker_session,
136-
vpc_config=self.get_vpc_config(vpc_config_override))
136+
vpc_config=self.get_vpc_config(vpc_config_override), lib_dirs=self.lib_dirs)
137137

138138
@classmethod
139139
def _prepare_init_params_from_job_description(cls, job_details, model_channel_name=None):

src/sagemaker/estimator.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -632,7 +632,7 @@ class Framework(EstimatorBase):
632632
LAUNCH_PS_ENV_NAME = 'sagemaker_parameter_server_enabled'
633633

634634
def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cloudwatch_metrics=False,
635-
container_log_level=logging.INFO, code_location=None, image_name=None, **kwargs):
635+
container_log_level=logging.INFO, code_location=None, image_name=None, lib_dirs=None, **kwargs):
636636
"""Base class initializer. Subclasses which override ``__init__`` should invoke ``super()``
637637
638638
Args:
@@ -641,6 +641,22 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cl
641641
source_dir (str): Path (absolute or relative) to a directory with any other training
642642
source code dependencies aside from tne entry point file (default: None). Structure within this
643643
directory are preserved when training on Amazon SageMaker.
644+
lib_dirs (list[str]): A list of paths to directories (absolute or relative) with
645+
any additional libraries that will be exported to the container (default: []).
646+
The library folders will be copied to SageMaker in the same folder where the entrypoint is copied.
647+
Example:
648+
649+
The following call
650+
>>> Estimator(entry_point='train.py', lib_dirs=['my/libs/common', 'virtual-env'])
651+
results in the following inside the container:
652+
653+
>>> $ ls
654+
655+
>>> opt/ml/code
656+
>>> ├── train.py
657+
>>> ├── common
658+
>>> └── virtual-env
659+
644660
hyperparameters (dict): Hyperparameters that will be used for training (default: None).
645661
The hyperparameters are made accessible as a dict[str, str] to the training code on SageMaker.
646662
For convenience, this accepts other types for keys and values, but ``str()`` will be called
@@ -658,6 +674,7 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cl
658674
"""
659675
super(Framework, self).__init__(**kwargs)
660676
self.source_dir = source_dir
677+
self.lib_dirs = lib_dirs or []
661678
self.entry_point = entry_point
662679
if enable_cloudwatch_metrics:
663680
warnings.warn('enable_cloudwatch_metrics is now deprecated and will be removed in the future.',
@@ -724,7 +741,8 @@ def _stage_user_code_in_s3(self):
724741
bucket=code_bucket,
725742
s3_key_prefix=code_s3_prefix,
726743
script=self.entry_point,
727-
directory=self.source_dir)
744+
directory=self.source_dir,
745+
lib_dirs=self.lib_dirs)
728746

729747
def _model_source_dir(self):
730748
"""Get the appropriate value to pass as source_dir to model constructor on deploying

src/sagemaker/fw_utils.py

Lines changed: 39 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,15 @@
1414

1515
import os
1616
import re
17+
import shutil
18+
import tempfile
1719
from collections import namedtuple
1820
from six.moves.urllib.parse import urlparse
1921

2022
import sagemaker.utils
2123

24+
_TAR_SOURCE_FILENAME = 'source.tar.gz'
25+
2226
UploadedCode = namedtuple('UserCode', ['s3_prefix', 'script_name'])
2327
"""sagemaker.fw_utils.UserCode: An object containing the S3 prefix and script name.
2428
@@ -107,7 +111,7 @@ def validate_source_dir(script, directory):
107111
return True
108112

109113

110-
def tar_and_upload_dir(session, bucket, s3_key_prefix, script, directory):
114+
def tar_and_upload_dir(session, bucket, s3_key_prefix, script, directory, lib_dirs=None):
111115
"""Pack and upload source files to S3 only if directory is empty or local.
112116
113117
Note:
@@ -118,31 +122,48 @@ def tar_and_upload_dir(session, bucket, s3_key_prefix, script, directory):
118122
bucket (str): S3 bucket to which the compressed file is uploaded.
119123
s3_key_prefix (str): Prefix for the S3 key.
120124
script (str): Script filename.
121-
directory (str): Directory containing the source file. If it starts with "s3://", no action is taken.
125+
directory (str or None): Directory containing the source file. If it starts with
126+
"s3://", no action is taken.
127+
lib_dirs (List[str]): A list of paths to directories (absolute or relative)
128+
containing additional libraries that will be copied into
129+
/opt/ml/lib
122130
123131
Returns:
124132
sagemaker.fw_utils.UserCode: An object with the S3 bucket and key (S3 prefix) and script name.
125133
"""
126-
if directory:
127-
if directory.lower().startswith("s3://"):
128-
return UploadedCode(s3_prefix=directory, script_name=os.path.basename(script))
129-
else:
130-
script_name = script
131-
source_files = [os.path.join(directory, name) for name in os.listdir(directory)]
134+
lib_dirs = lib_dirs or []
135+
key = '%s/sourcedir.tar.gz' % s3_key_prefix
136+
137+
if directory and directory.lower().startswith('s3://'):
138+
return UploadedCode(s3_prefix=directory, script_name=os.path.basename(script))
132139
else:
133-
# If no directory is specified, the script parameter needs to be a valid relative path.
134-
os.path.exists(script)
135-
script_name = os.path.basename(script)
136-
source_files = [script]
140+
tmp = tempfile.mkdtemp()
141+
142+
try:
143+
source_files = list(_expand_files_to_compress(script, directory))
137144

138-
s3 = session.resource('s3')
139-
key = '{}/{}'.format(s3_key_prefix, 'sourcedir.tar.gz')
145+
tar_file = sagemaker.utils.create_tar_file(source_files + lib_dirs, os.path.join(tmp, _TAR_SOURCE_FILENAME))
140146

141-
tar_file = sagemaker.utils.create_tar_file(source_files)
142-
s3.Object(bucket, key).upload_file(tar_file)
143-
os.remove(tar_file)
147+
session.resource('s3').Object(bucket, key).upload_file(tar_file)
144148

145-
return UploadedCode(s3_prefix='s3://{}/{}'.format(bucket, key), script_name=script_name)
149+
finally:
150+
shutil.rmtree(tmp)
151+
152+
script_name = script if directory else os.path.basename(script)
153+
return UploadedCode(s3_prefix='s3://%s/%s' % (bucket, key), script_name=script_name)
154+
155+
156+
def _expand_files_to_compress(script, directory, additional_files=None):
157+
additional_files = additional_files or []
158+
basedir = directory if directory else os.path.dirname(script)
159+
files = [basedir] + additional_files
160+
161+
for file in files:
162+
if os.path.isfile(file):
163+
yield file
164+
else:
165+
for name in os.listdir(file):
166+
yield os.path.join(file, name)
146167

147168

148169
def framework_name_from_image(image_name):

src/sagemaker/model.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ class FrameworkModel(Model):
127127

128128
def __init__(self, model_data, image, role, entry_point, source_dir=None, predictor_cls=None, env=None, name=None,
129129
enable_cloudwatch_metrics=False, container_log_level=logging.INFO, code_location=None,
130-
sagemaker_session=None, **kwargs):
130+
sagemaker_session=None, lib_dirs=None, **kwargs):
131131
"""Initialize a ``FrameworkModel``.
132132
133133
Args:
@@ -140,6 +140,23 @@ def __init__(self, model_data, image, role, entry_point, source_dir=None, predic
140140
source code dependencies aside from tne entry point file (default: None). Structure within this
141141
directory will be preserved when training on SageMaker.
142142
If the directory points to S3, no code will be uploaded and the S3 location will be used instead.
143+
lib_dirs (list[str]): A list of paths to directories (absolute or relative) with
144+
any additional libraries that will be exported to the container (default: []).
145+
The library folders will be copied to SageMaker in the same folder where the entrypoint is copied.
146+
If the ```source_dir``` points to S3, code will be uploaded and the S3 location will be used
147+
instead. Example:
148+
149+
The following call
150+
>>> Estimator(entry_point='train.py', lib_dirs=['my/libs/common', 'virtual-env'])
151+
results in the following inside the container:
152+
153+
>>> $ ls
154+
155+
>>> opt/ml/code
156+
>>> ├── train.py
157+
>>> ├── common
158+
>>> └── virtual-env
159+
143160
predictor_cls (callable[string, sagemaker.session.Session]): A function to call to create
144161
a predictor (default: None). If not None, ``deploy`` will return the result of invoking
145162
this function on the created endpoint name.
@@ -160,6 +177,7 @@ def __init__(self, model_data, image, role, entry_point, source_dir=None, predic
160177
sagemaker_session=sagemaker_session, **kwargs)
161178
self.entry_point = entry_point
162179
self.source_dir = source_dir
180+
self.lib_dirs = lib_dirs or []
163181
self.enable_cloudwatch_metrics = enable_cloudwatch_metrics
164182
self.container_log_level = container_log_level
165183
if code_location:
@@ -194,7 +212,8 @@ def _upload_code(self, key_prefix):
194212
bucket=self.bucket or self.sagemaker_session.default_bucket(),
195213
s3_key_prefix=key_prefix,
196214
script=self.entry_point,
197-
directory=self.source_dir)
215+
directory=self.source_dir,
216+
lib_dirs=self.lib_dirs)
198217

199218
def _framework_env_vars(self):
200219
if self.uploaded_code:

src/sagemaker/mxnet/README.rst

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,23 @@ The following are optional arguments. When you create an ``MXNet`` object, you c
271271
other training source code dependencies including the entry point
272272
file. Structure within this directory will be preserved when training
273273
on SageMaker.
274+
- ``lib_dirs (list[str])`` A list of paths to directories (absolute or relative) with
275+
any additional libraries that will be exported to the container (default: []).
276+
The library folders will be copied to SageMaker in the same folder where the entrypoint is copied.
277+
If the ```source_dir``` points to S3, code will be uploaded and the S3 location will be used
278+
instead. Example:
279+
280+
The following call
281+
>>> Estimator(entry_point='train.py', lib_dirs=['my/libs/common', 'virtual-env'])
282+
results in the following inside the container:
283+
284+
>>> $ ls
285+
286+
>>> opt/ml/code
287+
>>> ├── train.py
288+
>>> ├── common
289+
>>> └── virtual-env
290+
274291
- ``hyperparameters`` Hyperparameters that will be used for training.
275292
Will be made accessible as a dict[str, str] to the training code on
276293
SageMaker. For convenience, accepts other types besides str, but

src/sagemaker/mxnet/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def create_model(self, model_server_workers=None, role=None, vpc_config_override
115115
container_log_level=self.container_log_level, code_location=self.code_location,
116116
py_version=self.py_version, framework_version=self.framework_version, image=self.image_name,
117117
model_server_workers=model_server_workers, sagemaker_session=self.sagemaker_session,
118-
vpc_config=self.get_vpc_config(vpc_config_override))
118+
vpc_config=self.get_vpc_config(vpc_config_override), lib_dirs=self.lib_dirs)
119119

120120
@classmethod
121121
def _prepare_init_params_from_job_description(cls, job_details, model_channel_name=None):

src/sagemaker/pytorch/README.rst

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,23 @@ The following are optional arguments. When you create a ``PyTorch`` object, you
175175
other training source code dependencies including the entry point
176176
file. Structure within this directory will be preserved when training
177177
on SageMaker.
178+
- ``lib_dirs (list[str])`` A list of paths to directories (absolute or relative) with
179+
any additional libraries that will be exported to the container (default: []).
180+
The library folders will be copied to SageMaker in the same folder where the entrypoint is copied.
181+
If the ```source_dir``` points to S3, code will be uploaded and the S3 location will be used
182+
instead. Example:
183+
184+
The following call
185+
>>> Estimator(entry_point='train.py', lib_dirs=['my/libs/common', 'virtual-env'])
186+
results in the following inside the container:
187+
188+
>>> $ ls
189+
190+
>>> opt/ml/code
191+
>>> ├── train.py
192+
>>> ├── common
193+
>>> └── virtual-env
194+
178195
- ``hyperparameters`` Hyperparameters that will be used for training.
179196
Will be made accessible as a dict[str, str] to the training code on
180197
SageMaker. For convenience, accepts other types besides strings, but

src/sagemaker/pytorch/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def create_model(self, model_server_workers=None, role=None, vpc_config_override
9696
container_log_level=self.container_log_level, code_location=self.code_location,
9797
py_version=self.py_version, framework_version=self.framework_version, image=self.image_name,
9898
model_server_workers=model_server_workers, sagemaker_session=self.sagemaker_session,
99-
vpc_config=self.get_vpc_config(vpc_config_override))
99+
vpc_config=self.get_vpc_config(vpc_config_override), lib_dirs=self.lib_dirs)
100100

101101
@classmethod
102102
def _prepare_init_params_from_job_description(cls, job_details, model_channel_name=None):

src/sagemaker/tensorflow/README.rst

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,23 @@ you can specify these as keyword arguments.
409409
other training source code dependencies including the entry point
410410
file. Structure within this directory will be preserved when training
411411
on SageMaker.
412+
- ``lib_dirs (list[str])`` A list of paths to directories (absolute or relative) with
413+
any additional libraries that will be exported to the container (default: []).
414+
The library folders will be copied to SageMaker in the same folder where the entrypoint is copied.
415+
If the ```source_dir``` points to S3, code will be uploaded and the S3 location will be used
416+
instead. Example:
417+
418+
The following call
419+
>>> Estimator(entry_point='train.py', lib_dirs=['my/libs/common', 'virtual-env'])
420+
results in the following inside the container:
421+
422+
>>> $ ls
423+
424+
>>> opt/ml/code
425+
>>> ├── train.py
426+
>>> ├── common
427+
>>> └── virtual-env
428+
412429
- ``requirements_file (str)`` Path to a ``requirements.txt`` file. The path should
413430
be within and relative to ``source_dir``. This is a file containing a list of items to be
414431
installed using pip install. Details on the format can be found in the

src/sagemaker/tensorflow/estimator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,8 @@ def _create_default_model(self, model_server_workers, role, vpc_config_override)
411411
framework_version=self.framework_version,
412412
model_server_workers=model_server_workers,
413413
sagemaker_session=self.sagemaker_session,
414-
vpc_config=self.get_vpc_config(vpc_config_override))
414+
vpc_config=self.get_vpc_config(vpc_config_override),
415+
lib_dirs=self.lib_dirs)
415416

416417
def hyperparameters(self):
417418
"""Return hyperparameters used by your custom TensorFlow code during model training."""
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
# import os
14+
# import sys
15+
# import tarfile
16+
#
17+
# lib_dir = '/opt/ml/lib'
18+
#
19+
# if not os.path.exists(lib_dir):
20+
# os.makedirs(lib_dir)
21+
#
22+
# with tarfile.open(name=os.path.join(os.path.dirname(__file__), 'opt_ml_lib.tar.gz'), mode='r:gz') as t:
23+
# t.extractall(path=lib_dir)
24+
#
25+
# sys.path.insert(0, lib_dir)
26+
27+
import alexa
28+
29+
30+
def model_fn(anything):
31+
return alexa
32+
33+
34+
def predict_fn(input_object, model):
35+
return input_object
36+
37+
38+
if __name__ == '__main__':
39+
with open('/opt/ml/model/answer', 'w') as model:
40+
model.write(str(alexa.question('How many roads must a man walk down?')))

0 commit comments

Comments
 (0)