Skip to content

Commit f64f5a9

Browse files
committed
fix: allow Airflow enabled estimators to use absolute path entry_point
1 parent d0c784a commit f64f5a9

File tree

2 files changed

+9
-4
lines changed

2 files changed

+9
-4
lines changed

src/sagemaker/workflow/airflow.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,12 @@ def prepare_framework(estimator, s3_operations):
4545
code_dir = "s3://{}/{}".format(bucket, key)
4646
estimator.uploaded_code = fw_utils.UploadedCode(s3_prefix=code_dir, script_name=script)
4747
s3_operations["S3Upload"] = [
48-
{"Path": estimator.source_dir or script, "Bucket": bucket, "Key": key, "Tar": True}
48+
{
49+
"Path": estimator.source_dir or estimator.entry_point,
50+
"Bucket": bucket,
51+
"Key": key,
52+
"Tar": True,
53+
}
4954
]
5055
estimator._hyperparameters[sagemaker.model.DIR_PARAM_NAME] = code_dir
5156
estimator._hyperparameters[sagemaker.model.SCRIPT_PARAM_NAME] = script

tests/unit/test_airflow.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ def test_byo_training_config_all_args(sagemaker_session):
164164
@patch("sagemaker.utils.sagemaker_timestamp", MagicMock(return_value=TIME_STAMP))
165165
def test_framework_training_config_required_args(sagemaker_session):
166166
tf = tensorflow.TensorFlow(
167-
entry_point="{{ entry_point }}",
167+
entry_point="/some/script.py",
168168
framework_version="1.10.0",
169169
training_steps=1000,
170170
evaluation_steps=100,
@@ -206,7 +206,7 @@ def test_framework_training_config_required_args(sagemaker_session):
206206
"HyperParameters": {
207207
"sagemaker_submit_directory": '"s3://output/sagemaker-tensorflow-%s/source/sourcedir.tar.gz"'
208208
% TIME_STAMP,
209-
"sagemaker_program": '"{{ entry_point }}"',
209+
"sagemaker_program": '"script.py"',
210210
"sagemaker_enable_cloudwatch_metrics": "false",
211211
"sagemaker_container_log_level": "20",
212212
"sagemaker_job_name": '"sagemaker-tensorflow-%s"' % TIME_STAMP,
@@ -219,7 +219,7 @@ def test_framework_training_config_required_args(sagemaker_session):
219219
"S3Operations": {
220220
"S3Upload": [
221221
{
222-
"Path": "{{ entry_point }}",
222+
"Path": "/some/script.py",
223223
"Bucket": "output",
224224
"Key": "sagemaker-tensorflow-%s/source/sourcedir.tar.gz" % TIME_STAMP,
225225
"Tar": True,

0 commit comments

Comments
 (0)