
Commit a2226b9

Merge branch 'master' into local_pull
2 parents: 09270f2 + beece5a

File tree

5 files changed: +20 -16 lines


CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@ CHANGELOG
 
 * feature: Estimators: dependencies attribute allows export of additional libraries into the container
 * feature: Add APIs to export Airflow transform and deploy config
+* bug-fix: Allow code_location argument to be S3 URI in training_config API
 * enhancement: Local Mode: add explicit pull for serving
 
 1.15.0

src/sagemaker/estimator.py

Lines changed: 7 additions & 5 deletions
@@ -658,9 +658,9 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cl
             >>> $ ls
 
             >>> opt/ml/code
-            >>> ├── train.py
-            >>> ├── common
-            >>> └── virtual-env
+            >>> |------ train.py
+            >>> |------ common
+            >>> |------ virtual-env
 
         hyperparameters (dict): Hyperparameters that will be used for training (default: None).
             The hyperparameters are made accessible as a dict[str, str] to the training code on SageMaker.
@@ -670,8 +670,10 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cl
             training jobs. This will be ignored for now and removed in a further release.
         container_log_level (int): Log level to use within the container (default: logging.INFO).
             Valid values are defined in the Python logging module.
-        code_location (str): Name of the S3 bucket where custom code is uploaded (default: None).
-            If not specified, default bucket created by ``sagemaker.session.Session`` is used.
+        code_location (str): The S3 prefix URI where custom code will be uploaded (default: None).
+            The code file uploaded in S3 is 'code_location/source/sourcedir.tar.gz'.
+            If not specified, the default code location is s3://default_bucket/job-name/. And code file
+            uploaded to S3 is s3://default_bucket/job-name/source/sourcedir.tar.gz
         image_name (str): An alternate image name to use instead of the official Sagemaker image
             for the framework. This is useful to run one of the Sagemaker supported frameworks
             with an image containing custom dependencies.
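The code_location behaviour documented in this diff can be exercised directly when constructing a framework estimator. The sketch below is illustrative only: the bucket, prefix, and role names are hypothetical placeholders, not values from this commit, and any Framework estimator (TensorFlow is used here, as in the tests below) behaves the same way.

# Illustrative sketch; bucket, prefix, and role below are placeholders.
from sagemaker.tensorflow import TensorFlow

estimator = TensorFlow(
    entry_point='train.py',
    role='SageMakerRole',                      # placeholder IAM role name
    train_instance_count=1,
    train_instance_type='ml.m4.xlarge',
    code_location='s3://my-bucket/my-prefix'   # now an S3 prefix URI, not just a bucket name
)

# Per the updated docstring, the packaged source is uploaded to
#   s3://my-bucket/my-prefix/source/sourcedir.tar.gz
# and, when code_location is omitted, to
#   s3://<default_bucket>/<job-name>/source/sourcedir.tar.gz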

src/sagemaker/model.py

Lines changed: 3 additions & 3 deletions
@@ -153,9 +153,9 @@ def __init__(self, model_data, image, role, entry_point, source_dir=None, predic
             >>> $ ls
 
             >>> opt/ml/code
-            >>> ├── train.py
-            >>> ├── common
-            >>> └── virtual-env
+            >>> |------ train.py
+            >>> |------ common
+            >>> |------ virtual-env
 
         predictor_cls (callable[string, sagemaker.session.Session]): A function to call to create
             a predictor (default: None). If not None, ``deploy`` will return the result of invoking

src/sagemaker/workflow/airflow.py

Lines changed: 6 additions & 2 deletions
@@ -27,8 +27,12 @@ def prepare_framework(estimator, s3_operations):
         estimator (sagemaker.estimator.Estimator): The framework estimator to get information from and update.
         s3_operations (dict): The dict to specify s3 operations (upload `source_dir`).
     """
-    bucket = estimator.code_location if estimator.code_location else estimator.sagemaker_session._default_bucket
-    key = '{}/source/sourcedir.tar.gz'.format(estimator._current_job_name)
+    if estimator.code_location is not None:
+        bucket, key = fw_utils.parse_s3_url(estimator.code_location)
+        key = os.path.join(key, 'source', 'sourcedir.tar.gz')
+    else:
+        bucket = estimator.sagemaker_session._default_bucket
+        key = os.path.join(estimator._current_job_name, 'source', 'sourcedir.tar.gz')
     script = os.path.basename(estimator.entry_point)
     if estimator.source_dir and estimator.source_dir.lower().startswith('s3://'):
        code_dir = estimator.source_dir
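For reference, here is a standalone sketch of the bucket/key derivation that prepare_framework now performs. It uses urllib.parse instead of the SDK's fw_utils.parse_s3_url, and the helper name derive_code_upload_location is invented for illustration; it is not part of this commit.

from urllib.parse import urlparse
import posixpath

def derive_code_upload_location(code_location, default_bucket, job_name):
    """Sketch of the bucket/key logic added to prepare_framework.

    code_location: optional 's3://bucket/prefix' URI supplied on the estimator.
    default_bucket / job_name: fallbacks mirroring the else branch of the diff.
    """
    if code_location is not None:
        parsed = urlparse(code_location)       # rough stand-in for fw_utils.parse_s3_url
        bucket = parsed.netloc
        prefix = parsed.path.lstrip('/')
        key = posixpath.join(prefix, 'source', 'sourcedir.tar.gz')
    else:
        bucket = default_bucket
        key = posixpath.join(job_name, 'source', 'sourcedir.tar.gz')
    return bucket, key

# The two paths side by side:
print(derive_code_upload_location('s3://my-bucket/my-prefix', 'default-bucket', 'job-1'))
# ('my-bucket', 'my-prefix/source/sourcedir.tar.gz')
print(derive_code_upload_location(None, 'default-bucket', 'job-1'))
# ('default-bucket', 'job-1/source/sourcedir.tar.gz')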

tests/unit/test_airflow.py

Lines changed: 3 additions & 6 deletions
@@ -244,7 +244,7 @@ def test_framework_training_config_all_args(sagemaker_session):
         source_dir="{{ source_dir }}",
         enable_cloudwatch_metrics=False,
         container_log_level="{{ log_level }}",
-        code_location="{{ bucket_name }}",
+        code_location="s3://{{ bucket_name }}/{{ prefix }}",
         training_steps=1000,
         evaluation_steps=100,
         checkpoint_path="{{ checkpoint_path }}",
@@ -304,9 +304,7 @@ def test_framework_training_config_all_args(sagemaker_session):
             'SecurityGroupIds': ['{{ security_group_ids }}']
         },
         'HyperParameters': {
-            'sagemaker_submit_directory': '"s3://{{ bucket_name }}/{{ base_job_name }}-'
-                                          '{{ execution_date.strftime(\'%Y-%m-%d-%H-%M-%S\') }}'
-                                          '/source/sourcedir.tar.gz"',
+            'sagemaker_submit_directory': '"s3://{{ bucket_name }}/{{ prefix }}/source/sourcedir.tar.gz"',
            'sagemaker_program': '"{{ entry_point }}"',
            'sagemaker_enable_cloudwatch_metrics': 'false',
            'sagemaker_container_log_level': '"{{ log_level }}"',
@@ -322,8 +320,7 @@ def test_framework_training_config_all_args(sagemaker_session):
        'S3Upload': [{
            'Path': '{{ source_dir }}',
            'Bucket': '{{ bucket_name }}',
-           'Key': "{{ base_job_name }}-{{ execution_date.strftime('%Y-%m-%d-%H-%M-%S') }}"
-                  "/source/sourcedir.tar.gz",
+           'Key': "{{ prefix }}/source/sourcedir.tar.gz",
            'Tar': True}]
        }
    }
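The updated test expects an S3 URI passed as code_location to flow through to the exported Airflow config, including when the URI contains Airflow template macros such as {{ bucket_name }} and {{ prefix }}. Below is a hedged usage sketch with placeholder bucket, prefix, role, and input values, assuming a default SageMaker session and configured AWS region; it is not taken from this commit.

# Hypothetical usage; bucket, prefix, role, and input path are placeholders.
from sagemaker.tensorflow import TensorFlow
from sagemaker.workflow import airflow as sm_airflow

estimator = TensorFlow(
    entry_point='train.py',
    role='SageMakerRole',
    train_instance_count=1,
    train_instance_type='ml.m4.xlarge',
    code_location='s3://my-bucket/my-prefix'
)

config = sm_airflow.training_config(estimator=estimator, inputs='s3://my-bucket/training-data')

# As asserted in the test above, the exported hyperparameters should point the
# submit directory at s3://my-bucket/my-prefix/source/sourcedir.tar.gz
print(config['HyperParameters']['sagemaker_submit_directory'])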
