Skip to content

Commit 425390e

Browse files
andremoeller and knakad
authored and committed
fix: allow processing users to run code in s3 (#1167)
* fix: allow processing users to run code in s3
  Add validation and unit tests for Processing
* Update src/sagemaker/processing.py
  Co-Authored-By: Karim Nakad <[email protected]>
1 parent 98fa76d commit 425390e

File tree

3 files changed

+415
-316
lines changed

3 files changed

+415
-316
lines changed

src/sagemaker/processing.py

Lines changed: 49 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -369,8 +369,9 @@ def run(
369369
"""
370370
self._current_job_name = self._generate_current_job_name(job_name=job_name)
371371

372-
user_script_name = self._get_user_script_name(code)
373-
user_code_s3_uri = self._upload_code(code)
372+
user_code_s3_uri = self._handle_user_code_url(code)
373+
user_script_name = self._get_user_code_name(code)
374+
374375
inputs_with_code = self._convert_code_and_add_to_inputs(inputs, user_code_s3_uri)
375376

376377
self._set_entrypoint(self.command, user_script_name)
@@ -389,25 +390,59 @@ def run(
389390
if wait:
390391
self.latest_job.wait(logs=logs)
391392

392-
def _get_user_script_name(self, code):
393-
"""Finds the user script name using the provided code file,
394-
directory, or script name.
393+
def _get_user_code_name(self, code):
394+
"""Gets the basename of the user's code from the URL the customer provided.
395395
396396
Args:
397-
code (str): This can be an S3 uri or a local path to either
398-
a directory or a file.
397+
code (str): A URL to the user's code.
398+
399+
Returns:
400+
str: The basename of the user's code.
401+
402+
"""
403+
code_url = urlparse(code)
404+
return os.path.basename(code_url.path)
405+
406+
def _handle_user_code_url(self, code):
407+
"""Gets the S3 URL containing the user's code.
408+
409+
Inspects the scheme the customer passed in ("s3://" for code in S3, "file://" or nothing
410+
for absolute or local file paths). Uploads the code to S3 if the code is a local file.
411+
412+
Args:
413+
code (str): A URL to the customer's code.
399414
400415
Returns:
401-
str: The script name from the S3 uri or from the file found
402-
on the user's local machine.
416+
str: The S3 URL to the customer's code.
417+
403418
"""
404-
if os.path.isdir(code) is None or not os.path.splitext(code)[1]:
419+
code_url = urlparse(code)
420+
if code_url.scheme == "s3":
421+
user_code_s3_uri = code
422+
elif code_url.scheme == "" or code_url.scheme == "file":
423+
# Validate that the file exists locally and is not a directory.
424+
if not os.path.exists(code):
425+
raise ValueError(
426+
"""code {} wasn't found. Please make sure that the file exists.
427+
""".format(
428+
code
429+
)
430+
)
431+
if not os.path.isfile(code):
432+
raise ValueError(
433+
"""code {} must be a file, not a directory. Please pass a path to a file.
434+
""".format(
435+
code
436+
)
437+
)
438+
user_code_s3_uri = self._upload_code(code)
439+
else:
405440
raise ValueError(
406-
"""'code' must be a file, not a directory. Please pass a path to a file, not a
407-
directory.
408-
"""
441+
"code {} url scheme {} is not recognized. Please pass a file path or S3 url".format(
442+
code, code_url.scheme
443+
)
409444
)
410-
return os.path.basename(code)
445+
return user_code_s3_uri
411446

412447
def _upload_code(self, code):
413448
"""Uploads a code file or directory specified as a string

src/sagemaker/sklearn/processing.py

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,8 @@
2626
class SKLearnProcessor(ScriptProcessor):
2727
"""Handles Amazon SageMaker processing tasks for jobs using scikit-learn."""
2828

29+
_valid_framework_versions = ["0.20.0"]
30+
2931
def __init__(
3032
self,
3133
framework_version,
@@ -84,6 +86,13 @@ def __init__(
8486
session = sagemaker_session or Session()
8587
region = session.boto_region_name
8688

89+
if framework_version not in self._valid_framework_versions:
90+
raise ValueError(
91+
"scikit-learn version {} is not supported. Supported versions are {}".format(
92+
framework_version, self._valid_framework_versions
93+
)
94+
)
95+
8796
if not command:
8897
command = ["python3"]
8998

0 commit comments

Comments
 (0)