Skip to content

feature: create ProcessingJob from ARN and from name #1186

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Dec 19, 2019
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions src/sagemaker/model_monitor/model_monitoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -1761,7 +1761,7 @@ def _get_default_image_uri(region):
class BaseliningJob(ProcessingJob):
"""Provides functionality to retrieve baseline-specific files output from baselining job."""

def __init__(self, sagemaker_session, job_name, inputs, outputs):
def __init__(self, sagemaker_session, job_name, inputs, outputs, output_kms_key=None):
"""Initializes a Baselining job that tracks a baselining job kicked off by the suggest
workflow.

Expand All @@ -1773,12 +1773,17 @@ def __init__(self, sagemaker_session, job_name, inputs, outputs):
job_name (str): Name of the Amazon SageMaker Model Monitoring Baselining Job.
inputs ([sagemaker.processing.ProcessingInput]): A list of ProcessingInput objects.
outputs ([sagemaker.processing.ProcessingOutput]): A list of ProcessingOutput objects.
output_kms_key (str): The output kms key associated with the job.

"""
self.inputs = inputs
self.outputs = outputs
super(BaseliningJob, self).__init__(
sagemaker_session=sagemaker_session, job_name=job_name, inputs=inputs, outputs=outputs
sagemaker_session=sagemaker_session,
job_name=job_name,
inputs=inputs,
outputs=outputs,
output_kms_key=output_kms_key,
)

@classmethod
Expand All @@ -1799,6 +1804,7 @@ def from_processing_job(cls, processing_job):
processing_job.job_name,
processing_job.inputs,
processing_job.outputs,
processing_job.output_kms_key,
)

def baseline_statistics(self, file_name=STATISTICS_JSON_DEFAULT_FILE_NAME, kms_key=None):
Expand Down Expand Up @@ -1897,9 +1903,12 @@ def __init__(self, sagemaker_session, job_name, inputs, output, output_kms_key):

"""
self.output = output
self.output_kms_key = output_kms_key
super(MonitoringExecution, self).__init__(
sagemaker_session=sagemaker_session, job_name=job_name, inputs=inputs, outputs=[output]
sagemaker_session=sagemaker_session,
job_name=job_name,
inputs=inputs,
outputs=[output],
output_kms_key=output_kms_key,
)

@classmethod
Expand Down
81 changes: 79 additions & 2 deletions src/sagemaker/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,9 +502,10 @@ def _set_entrypoint(self, command, user_script_name):
class ProcessingJob(_Job):
"""Provides functionality to start, describe, and stop processing jobs."""

def __init__(self, sagemaker_session, job_name, inputs, outputs):
def __init__(self, sagemaker_session, job_name, inputs, outputs, output_kms_key=None):
    """Initializes a Processing job.

    Args:
        sagemaker_session (sagemaker.session.Session): Session object which
            manages interactions with Amazon SageMaker APIs and any other
            AWS services needed.
        job_name (str): Name of the processing job.
        inputs ([sagemaker.processing.ProcessingInput]): A list of ProcessingInput objects.
        outputs ([sagemaker.processing.ProcessingOutput]): A list of ProcessingOutput objects.
        output_kms_key (str): The output kms key associated with the job
            (default: None). Optional so that callers written against the
            previous four-argument signature keep working.
    """
    self.inputs = inputs
    self.outputs = outputs
    self.output_kms_key = output_kms_key
    super(ProcessingJob, self).__init__(sagemaker_session=sagemaker_session, job_name=job_name)

@classmethod
Expand Down Expand Up @@ -586,7 +587,83 @@ def start_new(cls, processor, inputs, outputs, experiment_config):
# Call sagemaker_session.process using the arguments dictionary.
processor.sagemaker_session.process(**process_request_args)

return cls(processor.sagemaker_session, processor._current_job_name, inputs, outputs)
return cls(
processor.sagemaker_session,
processor._current_job_name,
inputs,
outputs,
processor.output_kms_key,
)

@classmethod
def from_processing_name(cls, sagemaker_session, processing_job_name):
    """Initializes a Processing job from a Processing job name.

    The job is described through ``sagemaker_session`` and its inputs, outputs
    and output KMS key are rebuilt from that description.

    Args:
        sagemaker_session (sagemaker.session.Session): Session object which
            manages interactions with Amazon SageMaker APIs and any other
            AWS services needed. It is used to describe the processing job.
        processing_job_name (str): Name of the processing job.

    Returns:
        sagemaker.processing.ProcessingJob: The instance of ProcessingJob created
            from the description of the named job.
    """
    job_desc = sagemaker_session.describe_processing_job(job_name=processing_job_name)

    # A job may have been created without inputs and/or without an output
    # configuration, so do not assume these entries are present or non-empty.
    input_descs = job_desc.get("ProcessingInputs") or []
    output_config = job_desc.get("ProcessingOutputConfig") or {}
    output_descs = output_config.get("Outputs") or []

    inputs = []
    for input_desc in input_descs:
        s3_input = input_desc["S3Input"]
        inputs.append(
            ProcessingInput(
                source=s3_input["S3Uri"],
                destination=s3_input["LocalPath"],
                input_name=input_desc["InputName"],
                s3_data_type=s3_input.get("S3DataType"),
                s3_input_mode=s3_input.get("S3InputMode"),
                s3_data_distribution_type=s3_input.get("S3DataDistributionType"),
                s3_compression_type=s3_input.get("S3CompressionType"),
            )
        )

    # Rebuild every described output, not only the first one.
    outputs = [
        ProcessingOutput(
            source=output_desc["S3Output"]["LocalPath"],
            destination=output_desc["S3Output"]["S3Uri"],
            output_name=output_desc["OutputName"],
        )
        for output_desc in output_descs
    ]

    return cls(
        sagemaker_session=sagemaker_session,
        job_name=processing_job_name,
        inputs=inputs,
        outputs=outputs,
        output_kms_key=output_config.get("KmsKeyId"),
    )

@classmethod
def from_processing_arn(cls, sagemaker_session, processing_job_arn):
    """Initializes a Processing job from a Processing ARN.

    The job name is taken from the sixth colon-separated field of the ARN,
    after skipping the length of the ``processing-job/`` prefix, and is then
    handed to ``from_processing_name``.

    Args:
        sagemaker_session (sagemaker.session.Session): Session object which
            manages interactions with Amazon SageMaker APIs and any other
            AWS services needed.
        processing_job_arn (str): ARN of the processing job.

    Returns:
        sagemaker.processing.ProcessingJob: The instance of ProcessingJob created
            using the job name extracted from the ARN.
    """
    # This is necessary while the API only vends an arn.
    resource_field = processing_job_arn.split(":")[5]
    prefix_length = len("processing-job/")
    job_name = resource_field[prefix_length:]

    return cls.from_processing_name(
        sagemaker_session=sagemaker_session, processing_job_name=job_name
    )

def _is_local_channel(self, input_url):
"""Used for Local Mode. Not yet implemented.
Expand Down
75 changes: 71 additions & 4 deletions tests/unit/test_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,15 @@
from __future__ import absolute_import

import pytest
from mock import Mock, patch

from sagemaker.processing import ProcessingInput, ProcessingOutput, Processor, ScriptProcessor
from mock import Mock, patch, MagicMock

from sagemaker.processing import (
ProcessingInput,
ProcessingOutput,
Processor,
ScriptProcessor,
ProcessingJob,
)
from sagemaker.sklearn.processing import SKLearnProcessor
from sagemaker.network import NetworkConfig

Expand All @@ -24,11 +30,51 @@
ROLE = "arn:aws:iam::012345678901:role/SageMakerRole"
CUSTOM_IMAGE_URI = "012345678901.dkr.ecr.us-west-2.amazonaws.com/my-custom-image-uri"

# Canned job description returned by the mocked ``describe_processing_job``
# call: two S3 inputs, a single S3 output and an output KMS key.
PROCESSING_JOB_DESCRIPTION = {
    "ProcessingInputs": [
        # Dataset input.
        {
            "InputName": "my_dataset",
            "S3Input": {
                "S3Uri": "s3://path/to/my/dataset/census.csv",
                "LocalPath": "/container/path/",
                "S3DataType": "S3Prefix",
                "S3InputMode": "File",
                "S3DataDistributionType": "FullyReplicated",
                "S3CompressionType": "None",
            },
        },
        # Code input.
        {
            "InputName": "code",
            "S3Input": {
                "S3Uri": "mocked_s3_uri_from_upload_data",
                "LocalPath": "/opt/ml/processing/input/code",
                "S3DataType": "S3Prefix",
                "S3InputMode": "File",
                "S3DataDistributionType": "FullyReplicated",
                "S3CompressionType": "None",
            },
        },
    ],
    "ProcessingOutputConfig": {
        "Outputs": [
            {
                "OutputName": "my_output",
                "S3Output": {
                    "S3Uri": "s3://uri/",
                    "LocalPath": "/container/path/",
                    "S3UploadMode": "EndOfJob",
                },
            }
        ],
        "KmsKeyId": "arn:aws:kms:us-west-2:012345678901:key/output-kms-key",
    },
}


@pytest.fixture()
def sagemaker_session():
boto_mock = Mock(name="boto_session", region_name=REGION)
session_mock = Mock(
session_mock = MagicMock(
name="sagemaker_session",
boto_session=boto_mock,
boto_region_name=REGION,
Expand All @@ -42,6 +88,9 @@ def sagemaker_session():
)
session_mock.download_data = Mock(name="download_data")
session_mock.expand_role.return_value = ROLE
session_mock.describe_processing_job = MagicMock(
name="describe_processing_job", return_value=PROCESSING_JOB_DESCRIPTION
)
return session_mock


Expand Down Expand Up @@ -388,6 +437,24 @@ def test_processor_with_all_parameters(sagemaker_session):
sagemaker_session.process.assert_called_with(**expected_args)


def test_processing_job_from_processing_arn(sagemaker_session):
    """Checks that a ProcessingJob built from an ARN mirrors the described job."""
    job_arn = "arn:aws:sagemaker:dummy-region:dummy-account-number:processing-job/dummy-job-name"

    processing_job = ProcessingJob.from_processing_arn(
        sagemaker_session=sagemaker_session, processing_job_arn=job_arn
    )

    assert isinstance(processing_job, ProcessingJob)

    # The job name must be parsed out of the ARN; without these checks the test
    # would pass for any name, because the description itself is mocked.
    sagemaker_session.describe_processing_job.assert_called_once_with(job_name="dummy-job-name")
    assert processing_job.job_name == "dummy-job-name"

    expected_output_config = PROCESSING_JOB_DESCRIPTION["ProcessingOutputConfig"]

    actual_inputs = [
        processing_input._to_request_dict() for processing_input in processing_job.inputs
    ]
    assert actual_inputs == PROCESSING_JOB_DESCRIPTION["ProcessingInputs"]

    actual_outputs = [
        processing_output._to_request_dict() for processing_output in processing_job.outputs
    ]
    assert actual_outputs == expected_output_config["Outputs"]

    assert processing_job.output_kms_key == expected_output_config["KmsKeyId"]


def _get_script_processor(sagemaker_session):
return ScriptProcessor(
role=ROLE,
Expand Down