-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Add support for additional files #494
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
4d2a7f5
7326b22
df542d3
e684021
ef0cc4e
2d27e67
a04affc
75d5715
cf52ca8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -64,9 +64,9 @@ def __init__(self, entry_point, use_mpi=None, num_processes=None, process_slots_ | |
set to the number of GPUs on the instance (on GPU instances), or one (on CPU instances). | ||
additional_mpi_options (str): String of options to the 'mpirun' command used to run the entry point. | ||
For example, '-X NCCL_DEBUG=WARN' will pass that option string to the mpirun command. | ||
source_dir (str): Path (absolute or relative) to a directory with any other training | ||
source code dependencies aside from tne entry point file (default: None). Structure within this | ||
directory are preserved when training on Amazon SageMaker. | ||
source_dir (str or [str]): Single path (absolute or relative) or a list of paths to directories with | ||
any other training source code dependencies aside from the entry point file (default: None). | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd change this line to "any source code (other than the entry point file) needed for training" |
||
The structures within these directories are preserved when training on Amazon SageMaker. | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Should we include an explanation about the structure of how to access each of the directories? I.e., I assume (but am still reading the PR) that it'll be something like:
based on reading the docstring, but it'd be good to be explicit about it. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yeah, more documentation and examples would be helpful. |
||
hyperparameters (dict): Hyperparameters that will be used for training (default: None). | ||
The hyperparameters are made accessible as a dict[str, str] to the training code on SageMaker. | ||
For convenience, this accepts other types for keys and values, but ``str()`` will be called | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -107,7 +107,7 @@ def validate_source_dir(script, directory): | |
return True | ||
|
||
|
||
def tar_and_upload_dir(session, bucket, s3_key_prefix, script, directory): | ||
def tar_and_upload_dir(session, bucket, s3_key_prefix, script, directory, additional_files=None): | ||
"""Pack and upload source files to S3 only if directory is empty or local. | ||
|
||
Note: | ||
|
@@ -118,31 +118,43 @@ def tar_and_upload_dir(session, bucket, s3_key_prefix, script, directory): | |
bucket (str): S3 bucket to which the compressed file is uploaded. | ||
s3_key_prefix (str): Prefix for the S3 key. | ||
script (str): Script filename. | ||
directory (str): Directory containing the source file. If it starts with "s3://", no action is taken. | ||
directory (str or None): Directory containing the source file. If it starts with "s3://", no action is taken. | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add |
||
Returns: | ||
sagemaker.fw_utils.UploadedCode: An object with the S3 bucket and key (S3 prefix) and script name. | ||
""" | ||
if directory: | ||
if directory.lower().startswith("s3://"): | ||
return UploadedCode(s3_prefix=directory, script_name=os.path.basename(script)) | ||
else: | ||
script_name = script | ||
source_files = [os.path.join(directory, name) for name in os.listdir(directory)] | ||
key = '%s/sourcedir.tar.gz' % s3_key_prefix | ||
|
||
if directory and directory.lower().startswith("s3://"): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. single quotes for the string |
||
return UploadedCode(s3_prefix=directory, script_name=os.path.basename(script)) | ||
else: | ||
# If no directory is specified, the script parameter needs to be a valid relative path. | ||
os.path.exists(script) | ||
script_name = os.path.basename(script) | ||
source_files = [script] | ||
source_files = _list_root_files(script, directory, additional_files) | ||
_upload_code(session, bucket, key, source_files) | ||
|
||
script_name = script if directory else os.path.basename(script) | ||
return UploadedCode(s3_prefix='s3://%s/%s' % (bucket, key), script_name=script_name) | ||
|
||
s3 = session.resource('s3') | ||
key = '{}/{}'.format(s3_key_prefix, 'sourcedir.tar.gz') | ||
|
||
def _upload_code(session, bucket, key, source_files): | ||
tar_file = sagemaker.utils.create_tar_file(source_files) | ||
s3.Object(bucket, key).upload_file(tar_file) | ||
os.remove(tar_file) | ||
|
||
return UploadedCode(s3_prefix='s3://{}/{}'.format(bucket, key), script_name=script_name) | ||
try: | ||
session.resource('s3').Object(bucket, key).upload_file(tar_file) | ||
finally: | ||
os.remove(tar_file) | ||
|
||
|
||
def _list_root_files(script, directory, additional_files): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is not independently unit tested. This logic is confusing (I'm not sure I completely understand it). I suggest unit testing this and providing some developer documentation describing the contract. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I am personally against testing private methods; it adds unnecessary coupling to non-public-facing signatures and does not provide a good grasp of the functionality. I do agree that testing should be extensive and cover any edge cases. The tests that I wrote are here: https://github.com/aws/sagemaker-python-sdk/pull/494/files#diff-3108f99e19f25f4c77ad4f63d486b174R147 Let me know if you have any suggestions for improving these methods. |
||
additional_files = additional_files or [] | ||
basedir = directory if directory else os.path.dirname(script) | ||
files = [basedir] + additional_files | ||
|
||
for file in files: | ||
if os.path.isfile(file): | ||
yield file | ||
else: | ||
for name in os.listdir(file): | ||
yield os.path.join(file, name) | ||
|
||
|
||
def framework_name_from_image(image_name): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -47,9 +47,9 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, py_versio | |
Args: | ||
entry_point (str): Path (absolute or relative) to the Python source file which should be executed | ||
as the entry point to training. This should be compatible with either Python 2.7 or Python 3.5. | ||
source_dir (str): Path (absolute or relative) to a directory with any other training | ||
source code dependencies aside from tne entry point file (default: None). Structure within this | ||
directory are preserved when training on Amazon SageMaker. | ||
source_dir (str or [str]): Single path (absolute or relative) or a list of paths to directories with | ||
any other training source code dependencies aside from the entry point file (default: None). | ||
The structures within these directories are preserved when training on Amazon SageMaker. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. these directories |
||
hyperparameters (dict): Hyperparameters that will be used for training (default: None). | ||
The hyperparameters are made accessible as a dict[str, str] to the training code on SageMaker. | ||
For convenience, this accepts other types for keys and values, but ``str()`` will be called | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is | ||
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
import alexa | ||
|
||
|
||
def model_fn(anything): | ||
return alexa | ||
|
||
|
||
def predict_fn(input_object, model): | ||
return input_object | ||
|
||
|
||
if __name__ == '__main__': | ||
with open('/opt/ml/model/answer', 'w') as model: | ||
model.write(str(alexa.question('How many roads must a man walk down?'))) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is | ||
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
from __future__ import absolute_import | ||
|
||
import os | ||
|
||
from sagemaker.pytorch.estimator import PyTorch | ||
from tests.integ import DATA_DIR, PYTHON_VERSION, TRAINING_DEFAULT_TIMEOUT_MINUTES | ||
from tests.integ.timeout import timeout | ||
|
||
|
||
def test_source_dirs(sagemaker_session, tmpdir): | ||
source_dir = os.path.join(DATA_DIR, 'pytorch_source_dirs') | ||
lib = os.path.join(str(tmpdir), 'alexa.py') | ||
|
||
with open(lib, 'w') as f: | ||
f.write('def question(to_anything): return 42') | ||
|
||
instance_type = 'local' | ||
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES): | ||
estimator = PyTorch(entry_point='train.py', role='SageMakerRole', source_dir=[source_dir, lib], | ||
py_version=PYTHON_VERSION, train_instance_count=1, train_instance_type=instance_type) | ||
|
||
estimator.fit() | ||
|
||
predictor = estimator.deploy(initial_instance_count=1, instance_type=instance_type) | ||
|
||
predict_response = predictor.predict([24]) | ||
|
||
assert predict_response == [24] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let's be more explicit and use
list[str]
for lists. also I'd change it to "A single path"