1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ """The pipeline context for workflow"""
1
14
from __future__ import absolute_import
2
15
3
16
import warnings
11
24
12
25
13
26
class PipelineSession (Session ):
14
- """Managing interactions with the Amazon SageMaker APIs and any other AWS services needed
15
- under SageMaker Model-Building Pipeline Context
27
+ """Managing interactions with SageMaker APIs and AWS services needed under SageMaker Model-Building Pipeline Context
28
+
29
+ This class inherits the SageMaker session, it provides convenient methods for manipulating entities
30
+ and resources that Amazon SageMaker uses, such as training jobs, endpoints, and input datasets in S3.
31
+ When composing SageMaker Model-Building Pipeline, PipelineSession is recommended over
32
+ regular SageMaker Session
33
+ """
16
34
17
- This class inherits the SageMaker session, it provides convenient methods for manipulating entities
18
- and resources that Amazon SageMaker uses, such as training jobs, endpoints, and input datasets in S3.
19
- When composing SageMaker Model-Building Pipeline, PipelineSession is recommended over
20
- regular SageMaker Session
21
- """
22
35
def __init__ (
23
- self ,
24
- boto_session = None ,
25
- sagemaker_client = None ,
26
- default_bucket = None ,
27
- settings = SessionSettings (),
36
+ self ,
37
+ boto_session = None ,
38
+ sagemaker_client = None ,
39
+ default_bucket = None ,
40
+ settings = SessionSettings (),
28
41
):
29
42
"""Initialize a ``PipelineSession``.
30
43
@@ -66,16 +79,17 @@ def context(self, args: Dict):
66
79
def runnable_by_pipeline (run_func ):
67
80
"""A convenient Decorator
68
81
69
- This is a decorator designed to annotate, during pipeline session, the methods that downstream managed to
70
- 1. preprocess user inputs, outputs, and configurations
71
- 2. generate the create request
72
- 3. start the job.
73
- For instance, `Processor.run`, `Estimator.fit`, or `Transformer.transform`. This decorator will
74
- essentially run 1, and capture the request shape from 2, then instead of starting a new job in 3, it will
75
- return request shape from 2 to `sagemaker.workflow.steps.Step`. The request shape will be used to construct
76
- the arguments needed to compose that particular step as part of the pipeline. The job will be started during
77
- pipeline execution.
82
+ This is a decorator designed to annotate, during pipeline session, the methods that downstream managed to
83
+ 1. preprocess user inputs, outputs, and configurations
84
+ 2. generate the create request
85
+ 3. start the job.
86
+ For instance, `Processor.run`, `Estimator.fit`, or `Transformer.transform`. This decorator will
87
+ essentially run 1, and capture the request shape from 2, then instead of starting a new job in 3, it will
88
+ return request shape from 2 to `sagemaker.workflow.steps.Step`. The request shape will be used to construct
89
+ the arguments needed to compose that particular step as part of the pipeline. The job will be started during
90
+ pipeline execution.
78
91
"""
92
+
79
93
def wrapper (* args , ** kwargs ):
80
94
if type (args [0 ].sagemaker_session ) is PipelineSession :
81
95
run_func_sig = inspect .signature (run_func )
@@ -85,29 +99,30 @@ def wrapper(*args, **kwargs):
85
99
for i , (arg_name , param ) in enumerate (run_func_sig .parameters .items ()):
86
100
if i >= len (arg_list ):
87
101
break
88
- if arg_name == ' wait' :
102
+ if arg_name == " wait" :
89
103
override_wait = True
90
104
arg_list [i ] = False
91
- elif arg_name == ' logs' :
105
+ elif arg_name == " logs" :
92
106
override_logs = True
93
107
arg_list [i ] = False
94
108
95
109
args = tuple (arg_list )
96
110
97
111
if not override_wait :
98
- kwargs [' wait' ] = False
112
+ kwargs [" wait" ] = False
99
113
if not override_logs :
100
- kwargs [' logs' ] = False
114
+ kwargs [" logs" ] = False
101
115
102
116
warnings .warn (
103
117
"Running within a PipelineSession, there will be No Wait, "
104
118
"No Logs, and No Job being started." ,
105
- UserWarning
119
+ UserWarning ,
106
120
)
107
121
run_func (* args , ** kwargs )
108
122
return args [0 ].sagemaker_session .context
109
123
else :
110
124
run_func (* args , ** kwargs )
125
+
111
126
return wrapper
112
127
113
128
@@ -119,4 +134,4 @@ def is_pipeline_entities(obj: Any) -> bool:
119
134
Returns:
120
135
bool: if the given object is a pipeline Parameter, Expression, or Properties
121
136
"""
122
- return isinstance (obj , (Parameter , Expression , Properties ))
137
+ return isinstance (obj , (Parameter , Expression , Properties ))
0 commit comments