Skip to content

Commit ebbf6ac

Browse files
committed
Make code_location to be S3 URI instead of bucket in training_config()
1 parent b096cd1 commit ebbf6ac

File tree

4 files changed

+13
-9
lines changed

4 files changed

+13
-9
lines changed

CHANGELOG.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ CHANGELOG
77

88
* feature: Estimators: dependencies attribute allows export of additional libraries into the container
99
* feature: Add APIs to export Airflow transform and deploy config
10+
* bug-fix: Allow code_location argument to be S3 URI in training_config API
11+
*
1012

1113
1.15.0
1214
======

src/sagemaker/estimator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -670,8 +670,9 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cl
670670
training jobs. This will be ignored for now and removed in a further release.
671671
container_log_level (int): Log level to use within the container (default: logging.INFO).
672672
Valid values are defined in the Python logging module.
673-
code_location (str): Name of the S3 bucket where custom code is uploaded (default: None).
673+
code_location (str): The S3 URI where custom code is uploaded (default: None).
674674
If not specified, default bucket created by ``sagemaker.session.Session`` is used.
675+
The default S3 path is default_bucket/job-name/source/.
675676
image_name (str): An alternate image name to use instead of the official Sagemaker image
676677
for the framework. This is useful to run one of the Sagemaker supported frameworks
677678
with an image containing custom dependencies.

src/sagemaker/workflow/airflow.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,12 @@ def prepare_framework(estimator, s3_operations):
2727
estimator (sagemaker.estimator.Estimator): The framework estimator to get information from and update.
2828
s3_operations (dict): The dict to specify s3 operations (upload `source_dir`).
2929
"""
30-
bucket = estimator.code_location if estimator.code_location else estimator.sagemaker_session._default_bucket
31-
key = '{}/source/sourcedir.tar.gz'.format(estimator._current_job_name)
30+
if estimator.code_location is not None:
31+
bucket, key = fw_utils.parse_s3_url(estimator.code_location)
32+
key = os.path.join(key, 'source', 'sourcedir.tar.gz')
33+
else:
34+
bucket = estimator.sagemaker_session._default_bucket
35+
key = os.path.join(estimator._current_job_name, 'source', 'sourcedir.tar.gz')
3236
script = os.path.basename(estimator.entry_point)
3337
if estimator.source_dir and estimator.source_dir.lower().startswith('s3://'):
3438
code_dir = estimator.source_dir

tests/unit/test_airflow.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ def test_framework_training_config_all_args(sagemaker_session):
244244
source_dir="{{ source_dir }}",
245245
enable_cloudwatch_metrics=False,
246246
container_log_level="{{ log_level }}",
247-
code_location="{{ bucket_name }}",
247+
code_location="s3://{{ bucket_name }}/{{ prefix }}",
248248
training_steps=1000,
249249
evaluation_steps=100,
250250
checkpoint_path="{{ checkpoint_path }}",
@@ -304,9 +304,7 @@ def test_framework_training_config_all_args(sagemaker_session):
304304
'SecurityGroupIds': ['{{ security_group_ids }}']
305305
},
306306
'HyperParameters': {
307-
'sagemaker_submit_directory': '"s3://{{ bucket_name }}/{{ base_job_name }}-'
308-
'{{ execution_date.strftime(\'%Y-%m-%d-%H-%M-%S\') }}'
309-
'/source/sourcedir.tar.gz"',
307+
'sagemaker_submit_directory': '"s3://{{ bucket_name }}/{{ prefix }}/source/sourcedir.tar.gz"',
310308
'sagemaker_program': '"{{ entry_point }}"',
311309
'sagemaker_enable_cloudwatch_metrics': 'false',
312310
'sagemaker_container_log_level': '"{{ log_level }}"',
@@ -322,8 +320,7 @@ def test_framework_training_config_all_args(sagemaker_session):
322320
'S3Upload': [{
323321
'Path': '{{ source_dir }}',
324322
'Bucket': '{{ bucket_name }}',
325-
'Key': "{{ base_job_name }}-{{ execution_date.strftime('%Y-%m-%d-%H-%M-%S') }}"
326-
"/source/sourcedir.tar.gz",
323+
'Key': "{{ prefix }}/source/sourcedir.tar.gz",
327324
'Tar': True}]
328325
}
329326
}

0 commit comments

Comments
 (0)