Skip to content

Commit a00e3ff

Browse files
claytonparnellmizanfiu
authored andcommitted
fix: Fix bug forcing uploaded tar to be named sourcedir (aws#3412)
1 parent e567682 commit a00e3ff

File tree

2 files changed

+31
-8
lines changed

2 files changed

+31
-8
lines changed

src/sagemaker/processing.py

+11-8
Original file line numberDiff line numberDiff line change
@@ -1587,13 +1587,13 @@ def run( # type: ignore[override]
15871587
framework script to run.Path (absolute or relative) to the local
15881588
Python source file which should be executed as the entry point
15891589
to training. When `code` is an S3 URI, ignore `source_dir`,
1590-
`dependencies, and `git_config`. If ``source_dir`` is specified,
1590+
`dependencies`, and `git_config`. If ``source_dir`` is specified,
15911591
then ``code`` must point to a file located at the root of ``source_dir``.
15921592
source_dir (str): Path (absolute, relative or an S3 URI) to a directory
15931593
with any other processing source code dependencies aside from the entry
15941594
point file (default: None). If ``source_dir`` is an S3 URI, it must
1595-
point to a tar.gz file. Structure within this directory are preserved
1596-
when processing on Amazon SageMaker (default: None).
1595+
point to a file named `sourcedir.tar.gz`. Structure within this directory
1596+
are preserved when processing on Amazon SageMaker (default: None).
15971597
dependencies (list[str]): A list of paths to directories (absolute
15981598
or relative) with any additional libraries that will be exported
15991599
to the container (default: []). The library folders will be
@@ -1730,12 +1730,15 @@ def _pack_and_upload_code(
17301730
"sagemaker_session unspecified when creating your Processor to have one set up "
17311731
"automatically."
17321732
)
1733+
if "/sourcedir.tar.gz" in estimator.uploaded_code.s3_prefix:
1734+
# Upload the bootstrapping code as s3://.../jobname/source/runproc.sh.
1735+
entrypoint_s3_uri = estimator.uploaded_code.s3_prefix.replace(
1736+
"sourcedir.tar.gz",
1737+
"runproc.sh",
1738+
)
1739+
else:
1740+
raise RuntimeError("S3 source_dir file must be named `sourcedir.tar.gz.`")
17331741

1734-
# Upload the bootstrapping code as s3://.../jobname/source/runproc.sh.
1735-
entrypoint_s3_uri = estimator.uploaded_code.s3_prefix.replace(
1736-
"sourcedir.tar.gz",
1737-
"runproc.sh",
1738-
)
17391742
script = estimator.uploaded_code.script_name
17401743
s3_runproc_sh = S3Uploader.upload_string_as_file_body(
17411744
self._generate_framework_script(script),

tests/integ/test_xgboost.py

+20
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,26 @@ def xgboost_training_job(
4040
)
4141

4242

43+
def test_sourcedir_naming(
44+
sagemaker_session,
45+
xgboost_latest_version,
46+
xgboost_latest_py_version,
47+
cpu_instance_type,
48+
):
49+
with pytest.raises(RuntimeError):
50+
processor = XGBoostProcessor(
51+
framework_version=xgboost_latest_version,
52+
role=ROLE,
53+
instance_count=1,
54+
instance_type=cpu_instance_type,
55+
sagemaker_session=sagemaker_session,
56+
)
57+
processor.run(
58+
source_dir="s3://bucket/deps.tar.gz",
59+
code="main_script.py",
60+
)
61+
62+
4363
@pytest.mark.release
4464
def test_framework_processing_job_with_deps(
4565
sagemaker_session,

0 commit comments

Comments
 (0)