diff --git a/src/sagemaker/processing.py b/src/sagemaker/processing.py index 11be13fc88..0f3f432021 100644 --- a/src/sagemaker/processing.py +++ b/src/sagemaker/processing.py @@ -30,7 +30,6 @@ from sagemaker.local import LocalSession from sagemaker.utils import base_name_from_image, name_from_base from sagemaker.session import Session -from sagemaker.network import NetworkConfig # noqa: F401 # pylint: disable=unused-import from sagemaker.workflow.properties import Properties from sagemaker.workflow.parameters import Parameter from sagemaker.workflow.entities import Expression @@ -219,14 +218,14 @@ def _normalize_args( """ self._current_job_name = self._generate_current_job_name(job_name=job_name) - inputs_with_code = self._include_code_in_inputs(inputs, code) + inputs_with_code = self._include_code_in_inputs(inputs, code, kms_key) normalized_inputs = self._normalize_inputs(inputs_with_code, kms_key) normalized_outputs = self._normalize_outputs(outputs) self.arguments = arguments return normalized_inputs, normalized_outputs - def _include_code_in_inputs(self, inputs, _code): + def _include_code_in_inputs(self, inputs, _code, _kms_key): """A no op in the base class to include code in the processing job inputs. Args: @@ -235,6 +234,8 @@ def _include_code_in_inputs(self, inputs, _code): :class:`~sagemaker.processing.ProcessingInput` objects. _code (str): This can be an S3 URI or a local path to a file with the framework script to run (default: None). A no op in the base class. + kms_key (str): The ARN of the KMS key that is used to encrypt the + user code file (default: None). Returns: list[:class:`~sagemaker.processing.ProcessingInput`]: inputs @@ -528,7 +529,7 @@ def run( if wait: self.latest_job.wait(logs=logs) - def _include_code_in_inputs(self, inputs, code): + def _include_code_in_inputs(self, inputs, code, kms_key=None): """Converts code to appropriate input and includes in input list. Side effects include: @@ -541,12 +542,14 @@ def _include_code_in_inputs(self, inputs, code): :class:`~sagemaker.processing.ProcessingInput` objects. code (str): This can be an S3 URI or a local path to a file with the framework script to run (default: None). + kms_key (str): The ARN of the KMS key that is used to encrypt the + user code file (default: None). Returns: list[:class:`~sagemaker.processing.ProcessingInput`]: inputs together with the code as `ProcessingInput`. """ - user_code_s3_uri = self._handle_user_code_url(code) + user_code_s3_uri = self._handle_user_code_url(code, kms_key) user_script_name = self._get_user_code_name(code) inputs_with_code = self._convert_code_and_add_to_inputs(inputs, user_code_s3_uri) @@ -567,7 +570,7 @@ def _get_user_code_name(self, code): code_url = urlparse(code) return os.path.basename(code_url.path) - def _handle_user_code_url(self, code): + def _handle_user_code_url(self, code, kms_key=None): """Gets the S3 URL containing the user's code. Inspects the scheme the customer passed in ("s3://" for code in S3, "file://" or nothing @@ -575,6 +578,8 @@ def _handle_user_code_url(self, code): Args: code (str): A URL to the customer's code. + kms_key (str): The ARN of the KMS key that is used to encrypt the + user code file (default: None). Returns: str: The S3 URL to the customer's code. @@ -603,7 +608,7 @@ def _handle_user_code_url(self, code): code ) ) - user_code_s3_uri = self._upload_code(code_path) + user_code_s3_uri = self._upload_code(code_path, kms_key) else: raise ValueError( "code {} url scheme {} is not recognized. Please pass a file path or S3 url".format( @@ -612,11 +617,13 @@ def _handle_user_code_url(self, code): ) return user_code_s3_uri - def _upload_code(self, code): + def _upload_code(self, code, kms_key=None): """Uploads a code file or directory specified as a string and returns the S3 URI. Args: code (str): A file or directory to be uploaded to S3. + kms_key (str): The ARN of the KMS key that is used to encrypt the + user code file (default: None). Returns: str: The S3 URI of the uploaded file or directory. @@ -630,7 +637,10 @@ def _upload_code(self, code): self._CODE_CONTAINER_INPUT_NAME, ) return s3.S3Uploader.upload( - local_path=code, desired_s3_uri=desired_s3_uri, sagemaker_session=self.sagemaker_session + local_path=code, + desired_s3_uri=desired_s3_uri, + kms_key=kms_key, + sagemaker_session=self.sagemaker_session, ) def _convert_code_and_add_to_inputs(self, inputs, s3_uri): @@ -666,7 +676,9 @@ def _set_entrypoint(self, command, user_script_name): """ user_script_location = str( pathlib.PurePosixPath( - self._CODE_CONTAINER_BASE_PATH, self._CODE_CONTAINER_INPUT_NAME, user_script_name + self._CODE_CONTAINER_BASE_PATH, + self._CODE_CONTAINER_INPUT_NAME, + user_script_name, ) ) self.entrypoint = command + [user_script_location] @@ -1066,7 +1078,10 @@ def _to_request_dict(self): """Generates a request dictionary using the parameters provided to the class.""" # Create the request dictionary. - s3_input_request = {"InputName": self.input_name, "AppManaged": self.app_managed} + s3_input_request = { + "InputName": self.input_name, + "AppManaged": self.app_managed, + } if self.s3_input: # Check the compression type, then add it to the dictionary.