Skip to content

Commit 405461c

Browse files
authored
feature: create ProcessingJob from ARN and from name (#1186)
* Add ProcessingJob.from_processing_arn
* Add ProcessingJob.from_processing_name
* Add describe_processing_job to mock sagemaker_session in unit tests
* Add unit test to cover ProcessingJob.from_processing_arn and from_processing_name
* Update docstring
* Add output_kms_key to super call in MonitoringExecution.__init__
* Make output_kms_key optional in BaseliningJob.__init__
* Fix mocked describe_processing_job
* Add more stringent assert statements to test_processing_job_from_processing_arn
* Default output_kms_key to None
1 parent ce45f26 commit 405461c

File tree

3 files changed

+181
-12
lines changed

3 files changed

+181
-12
lines changed

src/sagemaker/model_monitor/model_monitoring.py

+17-6
Original file line numberDiff line numberDiff line change
@@ -1761,7 +1761,7 @@ def _get_default_image_uri(region):
17611761
class BaseliningJob(ProcessingJob):
17621762
"""Provides functionality to retrieve baseline-specific files output from baselining job."""
17631763

1764-
def __init__(self, sagemaker_session, job_name, inputs, outputs):
1764+
def __init__(self, sagemaker_session, job_name, inputs, outputs, output_kms_key=None):
17651765
"""Initializes a Baselining job that tracks a baselining job kicked off by the suggest
17661766
workflow.
17671767
@@ -1773,12 +1773,18 @@ def __init__(self, sagemaker_session, job_name, inputs, outputs):
17731773
job_name (str): Name of the Amazon SageMaker Model Monitoring Baselining Job.
17741774
inputs ([sagemaker.processing.ProcessingInput]): A list of ProcessingInput objects.
17751775
outputs ([sagemaker.processing.ProcessingOutput]): A list of ProcessingOutput objects.
1776+
output_kms_key (str): The output kms key associated with the job. Defaults to None
1777+
if not provided.
17761778
17771779
"""
17781780
self.inputs = inputs
17791781
self.outputs = outputs
17801782
super(BaseliningJob, self).__init__(
1781-
sagemaker_session=sagemaker_session, job_name=job_name, inputs=inputs, outputs=outputs
1783+
sagemaker_session=sagemaker_session,
1784+
job_name=job_name,
1785+
inputs=inputs,
1786+
outputs=outputs,
1787+
output_kms_key=output_kms_key,
17821788
)
17831789

17841790
@classmethod
@@ -1799,6 +1805,7 @@ def from_processing_job(cls, processing_job):
17991805
processing_job.job_name,
18001806
processing_job.inputs,
18011807
processing_job.outputs,
1808+
processing_job.output_kms_key,
18021809
)
18031810

18041811
def baseline_statistics(self, file_name=STATISTICS_JSON_DEFAULT_FILE_NAME, kms_key=None):
@@ -1881,7 +1888,7 @@ class MonitoringExecution(ProcessingJob):
18811888
executions
18821889
"""
18831890

1884-
def __init__(self, sagemaker_session, job_name, inputs, output, output_kms_key):
1891+
def __init__(self, sagemaker_session, job_name, inputs, output, output_kms_key=None):
18851892
"""Initializes a MonitoringExecution job that tracks a monitoring execution kicked off by
18861893
an Amazon SageMaker Model Monitoring Schedule.
18871894
@@ -1893,13 +1900,17 @@ def __init__(self, sagemaker_session, job_name, inputs, output, output_kms_key):
18931900
job_name (str): The name of the monitoring execution job.
18941901
output (sagemaker.Processing.ProcessingOutput): The output associated with the
18951902
monitoring execution.
1896-
output_kms_key (str): The output kms key associated with the job.
1903+
output_kms_key (str): The output kms key associated with the job. Defaults to None
1904+
if not provided.
18971905
18981906
"""
18991907
self.output = output
1900-
self.output_kms_key = output_kms_key
19011908
super(MonitoringExecution, self).__init__(
1902-
sagemaker_session=sagemaker_session, job_name=job_name, inputs=inputs, outputs=[output]
1909+
sagemaker_session=sagemaker_session,
1910+
job_name=job_name,
1911+
inputs=inputs,
1912+
outputs=[output],
1913+
output_kms_key=output_kms_key,
19031914
)
19041915

19051916
@classmethod

src/sagemaker/processing.py

+93-2
Original file line numberDiff line numberDiff line change
@@ -502,9 +502,24 @@ def _set_entrypoint(self, command, user_script_name):
502502
class ProcessingJob(_Job):
503503
"""Provides functionality to start, describe, and stop processing jobs."""
504504

505-
def __init__(self, sagemaker_session, job_name, inputs, outputs):
505+
def __init__(self, sagemaker_session, job_name, inputs, outputs, output_kms_key=None):
    """Initializes a Processing job.

    Args:
        sagemaker_session (sagemaker.session.Session): Session object which
            manages interactions with Amazon SageMaker APIs and any other
            AWS services needed. This argument is required; it is passed
            through to the base job initializer unchanged.
        job_name (str): Name of the Processing job.
        inputs ([sagemaker.processing.ProcessingInput]): A list of ProcessingInput objects.
        outputs ([sagemaker.processing.ProcessingOutput]): A list of ProcessingOutput objects.
        output_kms_key (str): The output kms key associated with the job. Defaults to None
            if not provided.

    """
    self.inputs = inputs
    self.outputs = outputs
    self.output_kms_key = output_kms_key
    super(ProcessingJob, self).__init__(sagemaker_session=sagemaker_session, job_name=job_name)
509524

510525
@classmethod
@@ -586,7 +601,83 @@ def start_new(cls, processor, inputs, outputs, experiment_config):
586601
# Call sagemaker_session.process using the arguments dictionary.
587602
processor.sagemaker_session.process(**process_request_args)
588603

589-
return cls(processor.sagemaker_session, processor._current_job_name, inputs, outputs)
604+
return cls(
605+
processor.sagemaker_session,
606+
processor._current_job_name,
607+
inputs,
608+
outputs,
609+
processor.output_kms_key,
610+
)
611+
612+
@classmethod
def from_processing_name(cls, sagemaker_session, processing_job_name):
    """Initializes a Processing job from a Processing job name.

    The job is described through ``sagemaker_session`` and its inputs, outputs
    and output KMS key are rebuilt from that description.

    Args:
        sagemaker_session (sagemaker.session.Session): Session object which
            manages interactions with Amazon SageMaker APIs and any other
            AWS services needed. Required: it is used to describe the job.
        processing_job_name (str): Name of the processing job.

    Returns:
        sagemaker.processing.ProcessingJob: The instance of ProcessingJob created
            from the description of the named processing job.
    """
    job_desc = sagemaker_session.describe_processing_job(job_name=processing_job_name)

    # ``ProcessingInputs`` and ``ProcessingOutputConfig`` may be absent from the
    # description, so fall back to empty collections instead of raising KeyError.
    inputs = [
        ProcessingInput(
            source=processing_input["S3Input"]["S3Uri"],
            destination=processing_input["S3Input"]["LocalPath"],
            input_name=processing_input["InputName"],
            s3_data_type=processing_input["S3Input"].get("S3DataType"),
            s3_input_mode=processing_input["S3Input"].get("S3InputMode"),
            s3_data_distribution_type=processing_input["S3Input"].get(
                "S3DataDistributionType"
            ),
            s3_compression_type=processing_input["S3Input"].get("S3CompressionType"),
        )
        for processing_input in job_desc.get("ProcessingInputs") or []
    ]

    output_config = job_desc.get("ProcessingOutputConfig") or {}
    # Rebuild every declared output, not only the first one, so that jobs with
    # several outputs (or none at all) round-trip correctly.
    outputs = [
        ProcessingOutput(
            source=processing_output["S3Output"]["LocalPath"],
            destination=processing_output["S3Output"]["S3Uri"],
            output_name=processing_output["OutputName"],
        )
        for processing_output in output_config.get("Outputs") or []
    ]

    return cls(
        sagemaker_session=sagemaker_session,
        job_name=processing_job_name,
        inputs=inputs,
        outputs=outputs,
        output_kms_key=output_config.get("KmsKeyId"),
    )
659+
660+
@classmethod
def from_processing_arn(cls, sagemaker_session, processing_job_arn):
    """Initializes a Processing job from a Processing ARN.

    Args:
        sagemaker_session (sagemaker.session.Session): Session object which
            manages interactions with Amazon SageMaker APIs and any other
            AWS services needed. It is forwarded to ``from_processing_name``.
        processing_job_arn (str): ARN of the processing job. Its sixth
            colon-separated field must be ``processing-job/<job name>``.

    Returns:
        sagemaker.processing.ProcessingJob: The instance of ProcessingJob created
            from the job name embedded in the ARN.

    Raises:
        ValueError: If ``processing_job_arn`` does not contain a
            ``processing-job/<job name>`` resource field.
    """
    resource_prefix = "processing-job/"
    arn_fields = processing_job_arn.split(":")

    # Validate before slicing: blindly dropping len(prefix) characters would
    # turn an ARN of another resource type into a meaningless job name.
    if len(arn_fields) < 6 or not arn_fields[5].startswith(resource_prefix):
        raise ValueError(
            "Not a valid processing job ARN: {!r}".format(processing_job_arn)
        )

    # This is necessary while the API only vends an arn.
    processing_job_name = arn_fields[5][len(resource_prefix) :]
    if not processing_job_name:
        raise ValueError(
            "Processing job ARN has an empty job name: {!r}".format(processing_job_arn)
        )

    return cls.from_processing_name(
        sagemaker_session=sagemaker_session, processing_job_name=processing_job_name
    )
590681

591682
def _is_local_channel(self, input_url):
592683
"""Used for Local Mode. Not yet implemented.

tests/unit/test_processing.py

+71-4
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,15 @@
1313
from __future__ import absolute_import
1414

1515
import pytest
16-
from mock import Mock, patch
17-
18-
from sagemaker.processing import ProcessingInput, ProcessingOutput, Processor, ScriptProcessor
16+
from mock import Mock, patch, MagicMock
17+
18+
from sagemaker.processing import (
19+
ProcessingInput,
20+
ProcessingOutput,
21+
Processor,
22+
ScriptProcessor,
23+
ProcessingJob,
24+
)
1925
from sagemaker.sklearn.processing import SKLearnProcessor
2026
from sagemaker.network import NetworkConfig
2127

@@ -24,11 +30,51 @@
2430
ROLE = "arn:aws:iam::012345678901:role/SageMakerRole"
2531
CUSTOM_IMAGE_URI = "012345678901.dkr.ecr.us-west-2.amazonaws.com/my-custom-image-uri"
2632

33+
PROCESSING_JOB_DESCRIPTION = {
34+
"ProcessingInputs": [
35+
{
36+
"InputName": "my_dataset",
37+
"S3Input": {
38+
"S3Uri": "s3://path/to/my/dataset/census.csv",
39+
"LocalPath": "/container/path/",
40+
"S3DataType": "S3Prefix",
41+
"S3InputMode": "File",
42+
"S3DataDistributionType": "FullyReplicated",
43+
"S3CompressionType": "None",
44+
},
45+
},
46+
{
47+
"InputName": "code",
48+
"S3Input": {
49+
"S3Uri": "mocked_s3_uri_from_upload_data",
50+
"LocalPath": "/opt/ml/processing/input/code",
51+
"S3DataType": "S3Prefix",
52+
"S3InputMode": "File",
53+
"S3DataDistributionType": "FullyReplicated",
54+
"S3CompressionType": "None",
55+
},
56+
},
57+
],
58+
"ProcessingOutputConfig": {
59+
"Outputs": [
60+
{
61+
"OutputName": "my_output",
62+
"S3Output": {
63+
"S3Uri": "s3://uri/",
64+
"LocalPath": "/container/path/",
65+
"S3UploadMode": "EndOfJob",
66+
},
67+
}
68+
],
69+
"KmsKeyId": "arn:aws:kms:us-west-2:012345678901:key/output-kms-key",
70+
},
71+
}
72+
2773

2874
@pytest.fixture()
2975
def sagemaker_session():
3076
boto_mock = Mock(name="boto_session", region_name=REGION)
31-
session_mock = Mock(
77+
session_mock = MagicMock(
3278
name="sagemaker_session",
3379
boto_session=boto_mock,
3480
boto_region_name=REGION,
@@ -42,6 +88,9 @@ def sagemaker_session():
4288
)
4389
session_mock.download_data = Mock(name="download_data")
4490
session_mock.expand_role.return_value = ROLE
91+
session_mock.describe_processing_job = MagicMock(
92+
name="describe_processing_job", return_value=PROCESSING_JOB_DESCRIPTION
93+
)
4594
return session_mock
4695

4796

@@ -388,6 +437,24 @@ def test_processor_with_all_parameters(sagemaker_session):
388437
sagemaker_session.process.assert_called_with(**expected_args)
389438

390439

440+
def test_processing_job_from_processing_arn(sagemaker_session):
    """Check that a ProcessingJob rebuilt from an ARN mirrors the described job."""
    processing_job = ProcessingJob.from_processing_arn(
        sagemaker_session=sagemaker_session,
        processing_job_arn="arn:aws:sagemaker:dummy-region:dummy-account-number:processing-job/dummy-job-name",
    )
    assert isinstance(processing_job, ProcessingJob)

    # The job name must be the part of the ARN after "processing-job/", and the
    # description must have been requested for exactly that name.
    assert processing_job.job_name == "dummy-job-name"
    sagemaker_session.describe_processing_job.assert_called_once_with(job_name="dummy-job-name")

    assert [
        processing_input._to_request_dict() for processing_input in processing_job.inputs
    ] == PROCESSING_JOB_DESCRIPTION["ProcessingInputs"]
    assert [
        processing_output._to_request_dict() for processing_output in processing_job.outputs
    ] == PROCESSING_JOB_DESCRIPTION["ProcessingOutputConfig"]["Outputs"]
    assert (
        processing_job.output_kms_key
        == PROCESSING_JOB_DESCRIPTION["ProcessingOutputConfig"]["KmsKeyId"]
    )
456+
457+
391458
def _get_script_processor(sagemaker_session):
392459
return ScriptProcessor(
393460
role=ROLE,

0 commit comments

Comments
 (0)