diff --git a/src/sagemaker/model_monitor/model_monitoring.py b/src/sagemaker/model_monitor/model_monitoring.py index 92d3ea673c..d8f7b08df0 100644 --- a/src/sagemaker/model_monitor/model_monitoring.py +++ b/src/sagemaker/model_monitor/model_monitoring.py @@ -1761,7 +1761,7 @@ def _get_default_image_uri(region): class BaseliningJob(ProcessingJob): """Provides functionality to retrieve baseline-specific files output from baselining job.""" - def __init__(self, sagemaker_session, job_name, inputs, outputs): + def __init__(self, sagemaker_session, job_name, inputs, outputs, output_kms_key=None): """Initializes a Baselining job that tracks a baselining job kicked off by the suggest workflow. @@ -1773,12 +1773,18 @@ def __init__(self, sagemaker_session, job_name, inputs, outputs): job_name (str): Name of the Amazon SageMaker Model Monitoring Baselining Job. inputs ([sagemaker.processing.ProcessingInput]): A list of ProcessingInput objects. outputs ([sagemaker.processing.ProcessingOutput]): A list of ProcessingOutput objects. + output_kms_key (str): The output kms key associated with the job. Defaults to None + if not provided. """ self.inputs = inputs self.outputs = outputs super(BaseliningJob, self).__init__( - sagemaker_session=sagemaker_session, job_name=job_name, inputs=inputs, outputs=outputs + sagemaker_session=sagemaker_session, + job_name=job_name, + inputs=inputs, + outputs=outputs, + output_kms_key=output_kms_key, ) @classmethod @@ -1799,6 +1805,7 @@ def from_processing_job(cls, processing_job): processing_job.job_name, processing_job.inputs, processing_job.outputs, + processing_job.output_kms_key, ) def baseline_statistics(self, file_name=STATISTICS_JSON_DEFAULT_FILE_NAME, kms_key=None): @@ -1881,7 +1888,7 @@ class MonitoringExecution(ProcessingJob): executions """ - def __init__(self, sagemaker_session, job_name, inputs, output, output_kms_key): + def __init__(self, sagemaker_session, job_name, inputs, output, output_kms_key=None): """Initializes a MonitoringExecution job that tracks a monitoring execution kicked off by an Amazon SageMaker Model Monitoring Schedule. @@ -1893,13 +1900,17 @@ def __init__(self, sagemaker_session, job_name, inputs, output, output_kms_key): job_name (str): The name of the monitoring execution job. output (sagemaker.Processing.ProcessingOutput): The output associated with the monitoring execution. - output_kms_key (str): The output kms key associated with the job. + output_kms_key (str): The output kms key associated with the job. Defaults to None + if not provided. """ self.output = output - self.output_kms_key = output_kms_key super(MonitoringExecution, self).__init__( - sagemaker_session=sagemaker_session, job_name=job_name, inputs=inputs, outputs=[output] + sagemaker_session=sagemaker_session, + job_name=job_name, + inputs=inputs, + outputs=[output], + output_kms_key=output_kms_key, ) @classmethod diff --git a/src/sagemaker/processing.py b/src/sagemaker/processing.py index 952764b8bb..ddb902aec6 100644 --- a/src/sagemaker/processing.py +++ b/src/sagemaker/processing.py @@ -502,9 +502,24 @@ def _set_entrypoint(self, command, user_script_name): class ProcessingJob(_Job): """Provides functionality to start, describe, and stop processing jobs.""" - def __init__(self, sagemaker_session, job_name, inputs, outputs): + def __init__(self, sagemaker_session, job_name, inputs, outputs, output_kms_key=None): + """Initializes a Processing job. + + Args: + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, one is created using + the default AWS configuration chain. + job_name (str): Name of the Processing job. + inputs ([sagemaker.processing.ProcessingInput]): A list of ProcessingInput objects. + outputs ([sagemaker.processing.ProcessingOutput]): A list of ProcessingOutput objects. + output_kms_key (str): The output kms key associated with the job. Defaults to None + if not provided. + + """ self.inputs = inputs self.outputs = outputs + self.output_kms_key = output_kms_key super(ProcessingJob, self).__init__(sagemaker_session=sagemaker_session, job_name=job_name) @classmethod @@ -586,7 +601,83 @@ def start_new(cls, processor, inputs, outputs, experiment_config): # Call sagemaker_session.process using the arguments dictionary. processor.sagemaker_session.process(**process_request_args) - return cls(processor.sagemaker_session, processor._current_job_name, inputs, outputs) + return cls( + processor.sagemaker_session, + processor._current_job_name, + inputs, + outputs, + processor.output_kms_key, + ) + + @classmethod + def from_processing_name(cls, sagemaker_session, processing_job_name): + """Initializes a Processing job from a Processing job name. + + Args: + processing_job_name (str): Name of the processing job. + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, one is created using + the default AWS configuration chain. + + Returns: + sagemaker.processing.ProcessingJob: The instance of ProcessingJob created + using the current job name. + """ + job_desc = sagemaker_session.describe_processing_job(job_name=processing_job_name) + + return cls( + sagemaker_session=sagemaker_session, + job_name=processing_job_name, + inputs=[ + ProcessingInput( + source=processing_input["S3Input"]["S3Uri"], + destination=processing_input["S3Input"]["LocalPath"], + input_name=processing_input["InputName"], + s3_data_type=processing_input["S3Input"].get("S3DataType"), + s3_input_mode=processing_input["S3Input"].get("S3InputMode"), + s3_data_distribution_type=processing_input["S3Input"].get( + "S3DataDistributionType" + ), + s3_compression_type=processing_input["S3Input"].get("S3CompressionType"), + ) + for processing_input in job_desc["ProcessingInputs"] + ], + outputs=[ + ProcessingOutput( + source=job_desc["ProcessingOutputConfig"]["Outputs"][0]["S3Output"][ + "LocalPath" + ], + destination=job_desc["ProcessingOutputConfig"]["Outputs"][0]["S3Output"][ + "S3Uri" + ], + output_name=job_desc["ProcessingOutputConfig"]["Outputs"][0]["OutputName"], + ) + ], + output_kms_key=job_desc["ProcessingOutputConfig"].get("KmsKeyId"), + ) + + @classmethod + def from_processing_arn(cls, sagemaker_session, processing_job_arn): + """Initializes a Processing job from a Processing ARN. + + Args: + processing_job_arn (str): ARN of the processing job. + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, one is created using + the default AWS configuration chain. + + Returns: + sagemaker.processing.ProcessingJob: The instance of ProcessingJob created + using the current job name. + """ + processing_job_name = processing_job_arn.split(":")[5][ + len("processing-job/") : + ] # This is necessary while the API only vends an arn. + return cls.from_processing_name( + sagemaker_session=sagemaker_session, processing_job_name=processing_job_name + ) def _is_local_channel(self, input_url): """Used for Local Mode. Not yet implemented. diff --git a/tests/unit/test_processing.py b/tests/unit/test_processing.py index e44c9f6ad7..07f86c998d 100644 --- a/tests/unit/test_processing.py +++ b/tests/unit/test_processing.py @@ -13,9 +13,15 @@ from __future__ import absolute_import import pytest -from mock import Mock, patch - -from sagemaker.processing import ProcessingInput, ProcessingOutput, Processor, ScriptProcessor +from mock import Mock, patch, MagicMock + +from sagemaker.processing import ( + ProcessingInput, + ProcessingOutput, + Processor, + ScriptProcessor, + ProcessingJob, +) from sagemaker.sklearn.processing import SKLearnProcessor from sagemaker.network import NetworkConfig @@ -24,11 +30,51 @@ ROLE = "arn:aws:iam::012345678901:role/SageMakerRole" CUSTOM_IMAGE_URI = "012345678901.dkr.ecr.us-west-2.amazonaws.com/my-custom-image-uri" +PROCESSING_JOB_DESCRIPTION = { + "ProcessingInputs": [ + { + "InputName": "my_dataset", + "S3Input": { + "S3Uri": "s3://path/to/my/dataset/census.csv", + "LocalPath": "/container/path/", + "S3DataType": "S3Prefix", + "S3InputMode": "File", + "S3DataDistributionType": "FullyReplicated", + "S3CompressionType": "None", + }, + }, + { + "InputName": "code", + "S3Input": { + "S3Uri": "mocked_s3_uri_from_upload_data", + "LocalPath": "/opt/ml/processing/input/code", + "S3DataType": "S3Prefix", + "S3InputMode": "File", + "S3DataDistributionType": "FullyReplicated", + "S3CompressionType": "None", + }, + }, + ], + "ProcessingOutputConfig": { + "Outputs": [ + { + "OutputName": "my_output", + "S3Output": { + "S3Uri": "s3://uri/", + "LocalPath": "/container/path/", + "S3UploadMode": "EndOfJob", + }, + } + ], + "KmsKeyId": "arn:aws:kms:us-west-2:012345678901:key/output-kms-key", + }, +} + @pytest.fixture() def sagemaker_session(): boto_mock = Mock(name="boto_session", region_name=REGION) - session_mock = Mock( + session_mock = MagicMock( name="sagemaker_session", boto_session=boto_mock, boto_region_name=REGION, @@ -42,6 +88,9 @@ def sagemaker_session(): ) session_mock.download_data = Mock(name="download_data") session_mock.expand_role.return_value = ROLE + session_mock.describe_processing_job = MagicMock( + name="describe_processing_job", return_value=PROCESSING_JOB_DESCRIPTION + ) return session_mock @@ -388,6 +437,24 @@ def test_processor_with_all_parameters(sagemaker_session): sagemaker_session.process.assert_called_with(**expected_args) +def test_processing_job_from_processing_arn(sagemaker_session): + processing_job = ProcessingJob.from_processing_arn( + sagemaker_session=sagemaker_session, + processing_job_arn="arn:aws:sagemaker:dummy-region:dummy-account-number:processing-job/dummy-job-name", + ) + assert isinstance(processing_job, ProcessingJob) + assert [ + processing_input._to_request_dict() for processing_input in processing_job.inputs + ] == PROCESSING_JOB_DESCRIPTION["ProcessingInputs"] + assert [ + processing_output._to_request_dict() for processing_output in processing_job.outputs + ] == PROCESSING_JOB_DESCRIPTION["ProcessingOutputConfig"]["Outputs"] + assert ( + processing_job.output_kms_key + == PROCESSING_JOB_DESCRIPTION["ProcessingOutputConfig"]["KmsKeyId"] + ) + + def _get_script_processor(sagemaker_session): return ScriptProcessor( role=ROLE,