Skip to content

fix: add kms key for processing job code upload #2329

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
May 20, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 26 additions & 11 deletions src/sagemaker/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
from sagemaker.local import LocalSession
from sagemaker.utils import base_name_from_image, name_from_base
from sagemaker.session import Session
from sagemaker.network import NetworkConfig # noqa: F401 # pylint: disable=unused-import
from sagemaker.workflow.properties import Properties
from sagemaker.workflow.parameters import Parameter
from sagemaker.workflow.entities import Expression
Expand Down Expand Up @@ -219,14 +218,14 @@ def _normalize_args(
"""
self._current_job_name = self._generate_current_job_name(job_name=job_name)

inputs_with_code = self._include_code_in_inputs(inputs, code)
inputs_with_code = self._include_code_in_inputs(inputs, code, kms_key)
normalized_inputs = self._normalize_inputs(inputs_with_code, kms_key)
normalized_outputs = self._normalize_outputs(outputs)
self.arguments = arguments

return normalized_inputs, normalized_outputs

def _include_code_in_inputs(self, inputs, _code):
def _include_code_in_inputs(self, inputs, _code, _kms_key):
"""A no op in the base class to include code in the processing job inputs.

Args:
Expand All @@ -235,6 +234,8 @@ def _include_code_in_inputs(self, inputs, _code):
:class:`~sagemaker.processing.ProcessingInput` objects.
_code (str): This can be an S3 URI or a local path to a file with the framework
script to run (default: None). A no op in the base class.
kms_key (str): The ARN of the KMS key that is used to encrypt the
user code file (default: None).

Returns:
list[:class:`~sagemaker.processing.ProcessingInput`]: inputs
Expand Down Expand Up @@ -528,7 +529,7 @@ def run(
if wait:
self.latest_job.wait(logs=logs)

def _include_code_in_inputs(self, inputs, code):
def _include_code_in_inputs(self, inputs, code, kms_key=None):
"""Converts code to appropriate input and includes in input list.

Side effects include:
Expand All @@ -541,12 +542,14 @@ def _include_code_in_inputs(self, inputs, code):
:class:`~sagemaker.processing.ProcessingInput` objects.
code (str): This can be an S3 URI or a local path to a file with the framework
script to run (default: None).
kms_key (str): The ARN of the KMS key that is used to encrypt the
user code file (default: None).

Returns:
list[:class:`~sagemaker.processing.ProcessingInput`]: inputs together with the
code as `ProcessingInput`.
"""
user_code_s3_uri = self._handle_user_code_url(code)
user_code_s3_uri = self._handle_user_code_url(code, kms_key)
user_script_name = self._get_user_code_name(code)

inputs_with_code = self._convert_code_and_add_to_inputs(inputs, user_code_s3_uri)
Expand All @@ -567,14 +570,16 @@ def _get_user_code_name(self, code):
code_url = urlparse(code)
return os.path.basename(code_url.path)

def _handle_user_code_url(self, code):
def _handle_user_code_url(self, code, kms_key=None):
"""Gets the S3 URL containing the user's code.

Inspects the scheme the customer passed in ("s3://" for code in S3, "file://" or nothing
for absolute or local file paths. Uploads the code to S3 if the code is a local file.

Args:
code (str): A URL to the customer's code.
kms_key (str): The ARN of the KMS key that is used to encrypt the
user code file (default: None).

Returns:
str: The S3 URL to the customer's code.
Expand Down Expand Up @@ -603,7 +608,7 @@ def _handle_user_code_url(self, code):
code
)
)
user_code_s3_uri = self._upload_code(code_path)
user_code_s3_uri = self._upload_code(code_path, kms_key)
else:
raise ValueError(
"code {} url scheme {} is not recognized. Please pass a file path or S3 url".format(
Expand All @@ -612,11 +617,13 @@ def _handle_user_code_url(self, code):
)
return user_code_s3_uri

def _upload_code(self, code):
def _upload_code(self, code, kms_key=None):
"""Uploads a code file or directory specified as a string and returns the S3 URI.

Args:
code (str): A file or directory to be uploaded to S3.
kms_key (str): The ARN of the KMS key that is used to encrypt the
user code file (default: None).

Returns:
str: The S3 URI of the uploaded file or directory.
Expand All @@ -630,7 +637,10 @@ def _upload_code(self, code):
self._CODE_CONTAINER_INPUT_NAME,
)
return s3.S3Uploader.upload(
local_path=code, desired_s3_uri=desired_s3_uri, sagemaker_session=self.sagemaker_session
local_path=code,
desired_s3_uri=desired_s3_uri,
kms_key=kms_key,
sagemaker_session=self.sagemaker_session,
)

def _convert_code_and_add_to_inputs(self, inputs, s3_uri):
Expand Down Expand Up @@ -666,7 +676,9 @@ def _set_entrypoint(self, command, user_script_name):
"""
user_script_location = str(
pathlib.PurePosixPath(
self._CODE_CONTAINER_BASE_PATH, self._CODE_CONTAINER_INPUT_NAME, user_script_name
self._CODE_CONTAINER_BASE_PATH,
self._CODE_CONTAINER_INPUT_NAME,
user_script_name,
)
)
self.entrypoint = command + [user_script_location]
Expand Down Expand Up @@ -1066,7 +1078,10 @@ def _to_request_dict(self):
"""Generates a request dictionary using the parameters provided to the class."""

# Create the request dictionary.
s3_input_request = {"InputName": self.input_name, "AppManaged": self.app_managed}
s3_input_request = {
"InputName": self.input_name,
"AppManaged": self.app_managed,
}

if self.s3_input:
# Check the compression type, then add it to the dictionary.
Expand Down