@@ -502,9 +502,24 @@ def _set_entrypoint(self, command, user_script_name):
502
502
class ProcessingJob (_Job ):
503
503
"""Provides functionality to start, describe, and stop processing jobs."""
504
504
505
def __init__(self, sagemaker_session, job_name, inputs, outputs, output_kms_key=None):
    """Create a handle for a Processing job.

    Args:
        sagemaker_session (sagemaker.session.Session): Session object which
            manages interactions with Amazon SageMaker APIs and any other
            AWS services needed.
        job_name (str): Name of the Processing job.
        inputs ([sagemaker.processing.ProcessingInput]): A list of
            ProcessingInput objects.
        outputs ([sagemaker.processing.ProcessingOutput]): A list of
            ProcessingOutput objects.
        output_kms_key (str): The output KMS key associated with the job.
            Defaults to None if not provided.

    """
    # Job-specific attributes are stored before delegating to the base
    # class, which receives only the session and the job name.
    self.output_kms_key = output_kms_key
    self.outputs = outputs
    self.inputs = inputs
    super(ProcessingJob, self).__init__(
        sagemaker_session=sagemaker_session,
        job_name=job_name,
    )
509
524
510
525
@classmethod
@@ -586,7 +601,83 @@ def start_new(cls, processor, inputs, outputs, experiment_config):
586
601
# Call sagemaker_session.process using the arguments dictionary.
587
602
processor .sagemaker_session .process (** process_request_args )
588
603
589
- return cls (processor .sagemaker_session , processor ._current_job_name , inputs , outputs )
604
+ return cls (
605
+ processor .sagemaker_session ,
606
+ processor ._current_job_name ,
607
+ inputs ,
608
+ outputs ,
609
+ processor .output_kms_key ,
610
+ )
611
+
612
@classmethod
def from_processing_name(cls, sagemaker_session, processing_job_name):
    """Initializes a Processing job from a Processing job name.

    The job is looked up with ``describe_processing_job`` and its inputs,
    outputs and output KMS key are rebuilt from the returned description.

    Args:
        sagemaker_session (sagemaker.session.Session): Session object used
            to describe the processing job.
        processing_job_name (str): Name of the processing job.

    Returns:
        sagemaker.processing.ProcessingJob: The instance of ProcessingJob created
            using the given job name. ``inputs`` and ``outputs`` are empty
            lists when the job description does not contain any.
    """
    job_desc = sagemaker_session.describe_processing_job(job_name=processing_job_name)

    # "ProcessingInputs" and "ProcessingOutputConfig" may be absent from the
    # description, so fall back to empty containers instead of raising KeyError.
    input_descs = job_desc.get("ProcessingInputs") or []
    output_config = job_desc.get("ProcessingOutputConfig") or {}
    output_descs = output_config.get("Outputs") or []

    inputs = [
        ProcessingInput(
            source=input_desc["S3Input"]["S3Uri"],
            destination=input_desc["S3Input"]["LocalPath"],
            input_name=input_desc["InputName"],
            s3_data_type=input_desc["S3Input"].get("S3DataType"),
            s3_input_mode=input_desc["S3Input"].get("S3InputMode"),
            s3_data_distribution_type=input_desc["S3Input"].get("S3DataDistributionType"),
            s3_compression_type=input_desc["S3Input"].get("S3CompressionType"),
        )
        for input_desc in input_descs
    ]

    # Rebuild every configured output, not just the first one, so jobs with
    # several outputs (or none at all) are represented faithfully.
    outputs = [
        ProcessingOutput(
            source=output_desc["S3Output"]["LocalPath"],
            destination=output_desc["S3Output"]["S3Uri"],
            output_name=output_desc["OutputName"],
        )
        for output_desc in output_descs
    ]

    return cls(
        sagemaker_session=sagemaker_session,
        job_name=processing_job_name,
        inputs=inputs,
        outputs=outputs,
        output_kms_key=output_config.get("KmsKeyId"),
    )
659
+
660
@classmethod
def from_processing_arn(cls, sagemaker_session, processing_job_arn):
    """Initializes a Processing job from a Processing ARN.

    The job name is taken from the ARN and the actual construction is
    delegated to ``from_processing_name``.

    Args:
        sagemaker_session (sagemaker.session.Session): Session object passed
            through to ``from_processing_name``.
        processing_job_arn (str): ARN of the processing job.

    Returns:
        sagemaker.processing.ProcessingJob: The instance of ProcessingJob created
            using the job name extracted from the ARN.
    """
    # The sixth colon-separated field of the ARN is the resource part.
    resource = processing_job_arn.split(":")[5]
    # Drop the leading "processing-job/" characters to obtain the bare name.
    # This is necessary while the API only vends an arn.
    prefix_length = len("processing-job/")
    job_name = resource[prefix_length:]
    return cls.from_processing_name(
        sagemaker_session=sagemaker_session,
        processing_job_name=job_name,
    )
590
681
591
682
def _is_local_channel (self , input_url ):
592
683
"""Used for Local Mode. Not yet implemented.
0 commit comments