Skip to content

Commit dd48ee6

Browse files
authored
Merge pull request aws#11 from verdimrc/fp-get-run-args
FrameworkProcessor.get_run_args()
2 parents ea95f87 + 05bf6d7 commit dd48ee6

File tree

3 files changed

+151
-55
lines changed

3 files changed

+151
-55
lines changed

src/sagemaker/processing.py

+102-24
Original file line numberDiff line numberDiff line change
@@ -1230,7 +1230,7 @@ def __init__(
12301230
instance_type,
12311231
py_version="py3", # New kwarg
12321232
image_uri=None,
1233-
command=["python"],
1233+
command=["python3"],
12341234
volume_size_in_gb=30,
12351235
volume_kms_key=None,
12361236
output_kms_key=None,
@@ -1359,6 +1359,60 @@ def _pre_init_normalization(
13591359

13601360
return image_uri, base_job_name
13611361

1362+
def get_run_args(
1363+
self,
1364+
code,
1365+
source_dir=None,
1366+
dependencies=None,
1367+
git_config=None,
1368+
inputs=None,
1369+
outputs=None,
1370+
arguments=None,
1371+
job_name=None,
1372+
):
1373+
"""Returns a RunArgs object.
1374+
1375+
This object contains the normalized inputs, outputs and arguments needed
1376+
when using a ``FrameworkProcessor`` in a :class:`~sagemaker.workflow.steps.ProcessingStep`.
1377+
1378+
Args:
1379+
code (str): This can be an S3 URI or a local path to a file with the framework
1380+
script to run. See the ``code`` argument in
1381+
`sagemaker.processing.FrameworkProcessor.run()`.
1382+
source_dir (str): Path (absolute, relative, or an S3 URI) to a directory with
1383+
any other processing source code dependencies aside from the entrypoint
1384+
file (default: None). See the ``source_dir`` argument in
1385+
`sagemaker.processing.FrameworkProcessor.run()`.
1386+
dependencies (list[str]): A list of paths to directories (absolute or relative)
1387+
with any additional libraries that will be exported to the container
1388+
(default: []). See the ``dependencies`` argument in
1389+
`sagemaker.processing.FrameworkProcessor.run()`.
1390+
git_config (dict[str, str]): Git configurations used for cloning files. See the
1391+
`git_config` argument in `sagemaker.processing.FrameworkProcessor.run()`.
1392+
inputs (list[:class:`~sagemaker.processing.ProcessingInput`]): Input files for
1393+
the processing job. These must be provided as
1394+
:class:`~sagemaker.processing.ProcessingInput` objects (default: None).
1395+
outputs (list[:class:`~sagemaker.processing.ProcessingOutput`]): Outputs for
1396+
the processing job. These can be specified as either path strings or
1397+
:class:`~sagemaker.processing.ProcessingOutput` objects (default: None).
1398+
arguments (list[str]): A list of string arguments to be passed to a
1399+
processing job (default: None).
1400+
job_name (str): Processing job name. If not specified, the processor generates
1401+
a default job name, based on the base job name and current timestamp.
1402+
"""
1403+
# When job_name is None, the job_name to upload code (+payload) will
1404+
# differ from job_name used by run().
1405+
s3_runproc_sh, inputs, job_name = self._pack_and_upload_code(
1406+
code, source_dir, dependencies, git_config, job_name, inputs
1407+
)
1408+
1409+
return RunArgs(
1410+
s3_runproc_sh,
1411+
inputs=inputs,
1412+
outputs=outputs,
1413+
arguments=arguments,
1414+
)
1415+
13621416
def run( # type: ignore[override]
13631417
self,
13641418
code,
@@ -1377,15 +1431,17 @@ def run( # type: ignore[override]
13771431
"""Runs a processing job.
13781432
13791433
Args:
1380-
code (str): Path (absolute or relative) to the local Python source
1381-
file which should be executed as the entry point to training. If
1382-
``source_dir`` is specified, then ``code`` must point to a file
1383-
located at the root of ``source_dir``.
1434+
code (str): This can be an S3 URI or a local path to a file with the
1435+
framework script to run. Path (absolute or relative) to the local
1436+
Python source file which should be executed as the entry point
1437+
to processing. When `code` is an S3 URI, ignore `source_dir`,
1438+
`dependencies`, and `git_config`. If ``source_dir`` is specified,
1439+
then ``code`` must point to a file located at the root of ``source_dir``.
13841440
source_dir (str): Path (absolute, relative or an S3 URI) to a directory
1385-
with any other training source code dependencies aside from the entry
1441+
with any other processing source code dependencies aside from the entry
13861442
point file (default: None). If ``source_dir`` is an S3 URI, it must
13871443
point to a tar.gz file. Structure within this directory is preserved
1388-
when training on Amazon SageMaker (default: None).
1444+
when processing on Amazon SageMaker (default: None).
13891445
dependencies (list[str]): A list of paths to directories (absolute
13901446
or relative) with any additional libraries that will be exported
13911447
to the container (default: []). The library folders will be
@@ -1461,12 +1517,40 @@ def run( # type: ignore[override]
14611517
kms_key (str): The ARN of the KMS key that is used to encrypt the
14621518
user code file (default: None).
14631519
"""
1464-
if job_name is None:
1465-
job_name = self._generate_current_job_name()
1520+
s3_runproc_sh, inputs, job_name = self._pack_and_upload_code(
1521+
code, source_dir, dependencies, git_config, job_name, inputs
1522+
)
14661523

1467-
estimator = self._upload_payload(code, source_dir, dependencies, git_config, job_name)
1524+
# Submit a processing job.
1525+
super().run(
1526+
code=s3_runproc_sh,
1527+
inputs=inputs,
1528+
outputs=outputs,
1529+
arguments=arguments,
1530+
wait=wait,
1531+
logs=logs,
1532+
job_name=job_name,
1533+
experiment_config=experiment_config,
1534+
kms_key=kms_key,
1535+
)
1536+
1537+
def _pack_and_upload_code(self, code, source_dir, dependencies, git_config, job_name, inputs):
1538+
if code.startswith("s3://"):
1539+
return code, inputs, job_name
1540+
1541+
if job_name is None:
1542+
job_name = self._generate_current_job_name(job_name)
1543+
1544+
estimator = self._upload_payload(
1545+
code,
1546+
source_dir,
1547+
dependencies,
1548+
git_config,
1549+
job_name,
1550+
)
14681551
inputs = self._patch_inputs_with_payload(
1469-
inputs, estimator._hyperparameters["sagemaker_submit_directory"]
1552+
inputs,
1553+
estimator._hyperparameters["sagemaker_submit_directory"],
14701554
)
14711555

14721556
local_code = get_config_value("local.local_code", self.sagemaker_session.config)
@@ -1490,18 +1574,7 @@ def run( # type: ignore[override]
14901574
)
14911575
logger.info("runproc.sh uploaded to %s", s3_runproc_sh)
14921576

1493-
# Submit a processing job.
1494-
super().run(
1495-
code=s3_runproc_sh,
1496-
inputs=inputs,
1497-
outputs=outputs,
1498-
arguments=arguments,
1499-
wait=wait,
1500-
logs=logs,
1501-
job_name=job_name,
1502-
experiment_config=experiment_config,
1503-
kms_key=kms_key,
1504-
)
1577+
return s3_runproc_sh, inputs, job_name
15051578

15061579
def _generate_framework_script(self, user_script: str) -> str:
15071580
"""Generate the framework entrypoint file (as text) for a processing job.
@@ -1525,7 +1598,12 @@ def _generate_framework_script(self, user_script: str) -> str:
15251598
# Exit on any error. SageMaker uses error code to mark failed job.
15261599
set -e
15271600
1528-
[[ -f 'requirements.txt' ]] && pip install -r requirements.txt
1601+
if [[ -f 'requirements.txt' ]]; then
1602+
# Some py3 containers have typing, which may break pip install
1603+
pip uninstall --yes typing
1604+
1605+
pip install -r requirements.txt
1606+
fi
15291607
15301608
{entry_point_command} {entry_point} "$@"
15311609
"""

src/sagemaker/sklearn/processing.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def __init__(
4848
instance_type,
4949
py_version="py3", # New kwarg
5050
image_uri=None,
51-
command=["python"],
51+
command=["python3"],
5252
volume_size_in_gb=30,
5353
volume_kms_key=None,
5454
output_kms_key=None,

tests/unit/test_processing.py

+48-30
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ def test_sklearn_with_all_parameters(
162162
@patch("os.path.exists", return_value=True)
163163
@patch("os.path.isfile", return_value=True)
164164
def test_sklearn_with_all_parameters_via_run_args(
165-
exists_mock, isfile_mock, botocore_resolver, sklearn_version, sagemaker_session
165+
exists_mock, isfile_mock, botocore_resolver, sklearn_version, sagemaker_session, uploaded_code
166166
):
167167
botocore_resolver.return_value.construct_endpoint.return_value = {"hostname": ECR_HOSTNAME}
168168
custom_command = ["Rscript"]
@@ -190,28 +190,31 @@ def test_sklearn_with_all_parameters_via_run_args(
190190
sagemaker_session=sagemaker_session,
191191
)
192192

193-
# FIXME: to check FrameworkProcessor.get_run_args(), and possibly fix with
194-
# source_dir, dependencies.
195-
run_args = processor.get_run_args(
196-
code="/local/path/to/processing_code.py",
197-
inputs=_get_data_inputs_all_parameters(),
198-
outputs=_get_data_outputs_all_parameters(),
199-
arguments=["--drop-columns", "'SelfEmployed'"],
200-
)
193+
with patch("sagemaker.estimator.tar_and_upload_dir", return_value=uploaded_code):
194+
run_args = processor.get_run_args(
195+
code="processing_code.py",
196+
source_dir="/local/path/to/source_dir",
197+
dependencies=["/local/path/to/dep_01"],
198+
git_config=None,
199+
inputs=_get_data_inputs_all_parameters(),
200+
outputs=_get_data_outputs_all_parameters(),
201+
arguments=["--drop-columns", "'SelfEmployed'"],
202+
)
201203

202-
processor.run(
203-
code=run_args.code,
204-
inputs=run_args.inputs,
205-
outputs=run_args.outputs,
206-
arguments=run_args.arguments,
207-
wait=True,
208-
logs=False,
209-
experiment_config={"ExperimentName": "AnExperiment"},
210-
)
204+
processor.run(
205+
code=run_args.code,
206+
inputs=run_args.inputs,
207+
outputs=run_args.outputs,
208+
arguments=run_args.arguments,
209+
wait=True,
210+
logs=False,
211+
experiment_config={"ExperimentName": "AnExperiment"},
212+
)
211213

212214
expected_args = _get_expected_args_all_parameters_modular_code(
213215
processor._current_job_name,
214216
instance_count=2,
217+
code_s3_prefix=run_args.code.replace("/runproc.sh", ""),
215218
)
216219
sklearn_image_uri = (
217220
"246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-scikit-learn:{}-cpu-py3"
@@ -235,7 +238,7 @@ def test_sklearn_with_all_parameters_via_run_args(
235238
@patch("os.path.exists", return_value=True)
236239
@patch("os.path.isfile", return_value=True)
237240
def test_sklearn_with_all_parameters_via_run_args_called_twice(
238-
exists_mock, isfile_mock, botocore_resolver, sklearn_version, sagemaker_session
241+
exists_mock, isfile_mock, botocore_resolver, sklearn_version, sagemaker_session, uploaded_code
239242
):
240243
botocore_resolver.return_value.construct_endpoint.return_value = {"hostname": ECR_HOSTNAME}
241244

@@ -261,15 +264,22 @@ def test_sklearn_with_all_parameters_via_run_args_called_twice(
261264
sagemaker_session=sagemaker_session,
262265
)
263266

264-
run_args = processor.get_run_args(
265-
code="/local/path/to/processing_code.py",
266-
inputs=_get_data_inputs_all_parameters(),
267-
outputs=_get_data_outputs_all_parameters(),
268-
arguments=["--drop-columns", "'SelfEmployed'"],
269-
)
267+
with patch("sagemaker.estimator.tar_and_upload_dir", return_value=uploaded_code):
268+
run_args = processor.get_run_args(
269+
code="processing_code.py",
270+
source_dir="/local/path/to/source_dir",
271+
dependencies=["/local/path/to/dep_01"],
272+
git_config=None,
273+
inputs=_get_data_inputs_all_parameters(),
274+
outputs=_get_data_outputs_all_parameters(),
275+
arguments=["--drop-columns", "'SelfEmployed'"],
276+
)
270277

271278
run_args = processor.get_run_args(
272279
code="/local/path/to/processing_code.py",
280+
source_dir=None,
281+
dependencies=None,
282+
git_config=None,
273283
inputs=_get_data_inputs_all_parameters(),
274284
outputs=_get_data_outputs_all_parameters(),
275285
arguments=["--drop-columns", "'SelfEmployed'"],
@@ -285,7 +295,10 @@ def test_sklearn_with_all_parameters_via_run_args_called_twice(
285295
experiment_config={"ExperimentName": "AnExperiment"},
286296
)
287297

288-
expected_args = _get_expected_args_all_parameters_modular_code(processor._current_job_name)
298+
expected_args = _get_expected_args_all_parameters_modular_code(
299+
processor._current_job_name,
300+
code_s3_prefix=run_args.code.replace("/runproc.sh", ""),
301+
)
289302
sklearn_image_uri = (
290303
"246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-scikit-learn:{}-cpu-py3"
291304
).format(sklearn_version)
@@ -839,9 +852,14 @@ def _get_data_outputs_all_parameters():
839852

840853

841854
def _get_expected_args_all_parameters_modular_code(
842-
job_name, code_s3_uri=MOCKED_S3_URI, instance_count=1
855+
job_name,
856+
code_s3_uri=MOCKED_S3_URI,
857+
instance_count=1,
858+
code_s3_prefix=None,
843859
):
844-
# Add something to inputs
860+
if code_s3_prefix is None:
861+
code_s3_prefix = f"{code_s3_uri}/{job_name}/source"
862+
845863
return {
846864
"inputs": [
847865
{
@@ -911,7 +929,7 @@ def _get_expected_args_all_parameters_modular_code(
911929
"InputName": "code",
912930
"AppManaged": False,
913931
"S3Input": {
914-
"S3Uri": f"{code_s3_uri}/{job_name}/source/sourcedir.tar.gz",
932+
"S3Uri": f"{code_s3_prefix}/sourcedir.tar.gz",
915933
"LocalPath": "/opt/ml/processing/input/code/",
916934
"S3DataType": "S3Prefix",
917935
"S3InputMode": "File",
@@ -923,7 +941,7 @@ def _get_expected_args_all_parameters_modular_code(
923941
"InputName": "entrypoint",
924942
"AppManaged": False,
925943
"S3Input": {
926-
"S3Uri": f"{code_s3_uri}/{job_name}/source/runproc.sh",
944+
"S3Uri": f"{code_s3_prefix}/runproc.sh",
927945
"LocalPath": "/opt/ml/processing/input/entrypoint",
928946
"S3DataType": "S3Prefix",
929947
"S3InputMode": "File",

0 commit comments

Comments
 (0)