Skip to content

Commit 3484073

Browse files
committed
fix formatting
1 parent 93718f9 commit 3484073

13 files changed

+246
-228
lines changed

src/sagemaker/estimator.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@
7777
from sagemaker.workflow.pipeline_context import (
7878
PipelineSession,
7979
runnable_by_pipeline,
80-
is_pipeline_entities
80+
is_pipeline_entities,
8181
)
8282

8383
logger = logging.getLogger(__name__)
@@ -1335,7 +1335,10 @@ def register(
13351335
@property
13361336
def model_data(self):
13371337
"""str: The model location in S3. Only set if Estimator has been ``fit()``."""
1338-
if self.latest_training_job is not None and type(self.sagemaker_session) is not PipelineSession:
1338+
if (
1339+
self.latest_training_job is not None
1340+
and type(self.sagemaker_session) is not PipelineSession
1341+
):
13391342
model_uri = self.sagemaker_session.sagemaker_client.describe_training_job(
13401343
TrainingJobName=self.latest_training_job.name
13411344
)["ModelArtifacts"]["S3ModelArtifacts"]
@@ -1762,7 +1765,7 @@ def start_new(cls, estimator, inputs, experiment_config):
17621765
"""
17631766
train_args = cls._get_train_args(estimator, inputs, experiment_config)
17641767
if type(estimator.sagemaker_session) is PipelineSession:
1765-
train_args['pipeline_session'] = estimator.sagemaker_session
1768+
train_args["pipeline_session"] = estimator.sagemaker_session
17661769

17671770
estimator.sagemaker_session.train(**train_args)
17681771

src/sagemaker/processing.py

-1
Original file line numberDiff line numberDiff line change
@@ -1759,4 +1759,3 @@ def _set_entrypoint(self, command, user_script_name):
17591759
)
17601760
)
17611761
self.entrypoint = self.framework_entrypoint_command + [user_script_location]
1762-

src/sagemaker/transformer.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,11 @@
1717

1818
from sagemaker.job import _Job
1919
from sagemaker.session import Session
20-
from sagemaker.workflow.pipeline_context import PipelineSession, runnable_by_pipeline, is_pipeline_entities
20+
from sagemaker.workflow.pipeline_context import (
21+
PipelineSession,
22+
runnable_by_pipeline,
23+
is_pipeline_entities,
24+
)
2125
from sagemaker.utils import base_name_from_image, name_from_base
2226

2327

@@ -375,7 +379,7 @@ def start_new(
375379
model_client_config,
376380
)
377381
if type(transformer.sagemaker_session) is PipelineSession:
378-
transform_args['pipeline_session'] = transformer.sagemaker_session
382+
transform_args["pipeline_session"] = transformer.sagemaker_session
379383
transformer.sagemaker_session.transform(**transform_args)
380384

381385
return cls(transformer.sagemaker_session, transformer._current_job_name)

src/sagemaker/tuner.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
from sagemaker.workflow.pipeline_context import (
4242
PipelineSession,
4343
runnable_by_pipeline,
44-
is_pipeline_entities
44+
is_pipeline_entities,
4545
)
4646

4747
from sagemaker.session import Session
@@ -366,9 +366,7 @@ def _prepare_static_hyperparameters(
366366
"""Prepare static hyperparameters for one estimator before tuning."""
367367
# Remove any hyperparameter that will be tuned
368368
static_hyperparameters = {
369-
str(k): str(v)
370-
if not is_pipeline_entities(v)
371-
else v
369+
str(k): str(v) if not is_pipeline_entities(v) else v
372370
for (k, v) in estimator.hyperparameters().items()
373371
}
374372
for hyperparameter_name in hyperparameter_ranges.keys():
@@ -475,7 +473,7 @@ def _fit_with_estimator_dict(self, inputs, job_name, include_cls_metadata, estim
475473
self._validate_dict_argument(
476474
name="include_cls_metadata",
477475
value=include_cls_metadata if include_cls_metadata else {},
478-
allowed_keys=estimator_names
476+
allowed_keys=estimator_names,
479477
)
480478
self._validate_dict_argument(
481479
name="estimator_kwargs", value=estimator_kwargs, allowed_keys=estimator_names
@@ -1480,7 +1478,7 @@ def start_new(cls, tuner, inputs):
14801478
"""
14811479
tuner_args = cls._get_tuner_args(tuner, inputs)
14821480
if type(tuner.sagemaker_session) is PipelineSession:
1483-
tuner_args['pipeline_session'] = PipelineSession
1481+
tuner_args["pipeline_session"] = PipelineSession
14841482

14851483
tuner.sagemaker_session.create_tuning_job(**tuner_args)
14861484

src/sagemaker/workflow/pipeline.py

+1
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
from sagemaker.workflow.step_collections import StepCollection
4242
from sagemaker.workflow.utilities import list_to_request
4343

44+
4445
@attr.s
4546
class Pipeline(Entity):
4647
"""Pipeline for workflow.

src/sagemaker/workflow/pipeline_context.py

+42-27
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,16 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""The pipeline context for workflow"""
114
from __future__ import absolute_import
215

316
import warnings
@@ -11,20 +24,20 @@
1124

1225

1326
class PipelineSession(Session):
14-
"""Managing interactions with the Amazon SageMaker APIs and any other AWS services needed
15-
under SageMaker Model-Building Pipeline Context
27+
"""Managing interactions with SageMaker APIs and AWS services needed under SageMaker Model-Building Pipeline Context
28+
29+
This class inherits the SageMaker session, it provides convenient methods for manipulating entities
30+
and resources that Amazon SageMaker uses, such as training jobs, endpoints, and input datasets in S3.
31+
When composing SageMaker Model-Building Pipeline, PipelineSession is recommended over
32+
regular SageMaker Session
33+
"""
1634

17-
This class inherits the SageMaker session, it provides convenient methods for manipulating entities
18-
and resources that Amazon SageMaker uses, such as training jobs, endpoints, and input datasets in S3.
19-
When composing SageMaker Model-Building Pipeline, PipelineSession is recommended over
20-
regular SageMaker Session
21-
"""
2235
def __init__(
23-
self,
24-
boto_session=None,
25-
sagemaker_client=None,
26-
default_bucket=None,
27-
settings=SessionSettings(),
36+
self,
37+
boto_session=None,
38+
sagemaker_client=None,
39+
default_bucket=None,
40+
settings=SessionSettings(),
2841
):
2942
"""Initialize a ``PipelineSession``.
3043
@@ -66,16 +79,17 @@ def context(self, args: Dict):
6679
def runnable_by_pipeline(run_func):
6780
"""A convenient Decorator
6881
69-
This is a decorator designed to annotate, during pipeline session, the methods that downstream managed to
70-
1. preprocess user inputs, outputs, and configurations
71-
2. generate the create request
72-
3. start the job.
73-
For instance, `Processor.run`, `Estimator.fit`, or `Transformer.transform`. This decorator will
74-
essentially run 1, and capture the request shape from 2, then instead of starting a new job in 3, it will
75-
return request shape from 2 to `sagemaker.workflow.steps.Step`. The request shape will be used to construct
76-
the arguments needed to compose that particular step as part of the pipeline. The job will be started during
77-
pipeline execution.
82+
This is a decorator designed to annotate, during pipeline session, the methods that downstream managed to
83+
1. preprocess user inputs, outputs, and configurations
84+
2. generate the create request
85+
3. start the job.
86+
For instance, `Processor.run`, `Estimator.fit`, or `Transformer.transform`. This decorator will
87+
essentially run 1, and capture the request shape from 2, then instead of starting a new job in 3, it will
88+
return request shape from 2 to `sagemaker.workflow.steps.Step`. The request shape will be used to construct
89+
the arguments needed to compose that particular step as part of the pipeline. The job will be started during
90+
pipeline execution.
7891
"""
92+
7993
def wrapper(*args, **kwargs):
8094
if type(args[0].sagemaker_session) is PipelineSession:
8195
run_func_sig = inspect.signature(run_func)
@@ -85,29 +99,30 @@ def wrapper(*args, **kwargs):
8599
for i, (arg_name, param) in enumerate(run_func_sig.parameters.items()):
86100
if i >= len(arg_list):
87101
break
88-
if arg_name == 'wait':
102+
if arg_name == "wait":
89103
override_wait = True
90104
arg_list[i] = False
91-
elif arg_name == 'logs':
105+
elif arg_name == "logs":
92106
override_logs = True
93107
arg_list[i] = False
94108

95109
args = tuple(arg_list)
96110

97111
if not override_wait:
98-
kwargs['wait'] = False
112+
kwargs["wait"] = False
99113
if not override_logs:
100-
kwargs['logs'] = False
114+
kwargs["logs"] = False
101115

102116
warnings.warn(
103117
"Running within a PipelineSession, there will be No Wait, "
104118
"No Logs, and No Job being started.",
105-
UserWarning
119+
UserWarning,
106120
)
107121
run_func(*args, **kwargs)
108122
return args[0].sagemaker_session.context
109123
else:
110124
run_func(*args, **kwargs)
125+
111126
return wrapper
112127

113128

@@ -119,4 +134,4 @@ def is_pipeline_entities(obj: Any) -> bool:
119134
Returns:
120135
bool: if the given object is a pipeline Parameter, Expression, or Properties
121136
"""
122-
return isinstance(obj, (Parameter, Expression, Properties))
137+
return isinstance(obj, (Parameter, Expression, Properties))

src/sagemaker/workflow/steps.py

+18-17
Original file line numberDiff line numberDiff line change
@@ -285,8 +285,9 @@ def __init__(
285285
self.cache_config = cache_config
286286

287287
if self.cache_config:
288-
if (self.run_args and "ProfilerConfig" in self.run_args) or \
289-
(self.estimator is not None and not self.estimator.disable_profiler):
288+
if (self.run_args and "ProfilerConfig" in self.run_args) or (
289+
self.estimator is not None and not self.estimator.disable_profiler
290+
):
290291
msg = (
291292
"Profiling is enabled on the provided estimator. "
292293
"The default profiler rule includes a timestamp "
@@ -299,10 +300,10 @@ def __init__(
299300
if not self.run_args:
300301
warnings.warn(
301302
(
302-
"We are deprecating the instantiation of TrainingStep using \"estimator\"."
303-
"Instead, simply using \"run_args\"."
303+
'We are deprecating the instantiation of TrainingStep using "estimator".'
304+
'Instead, simply using "run_args".'
304305
),
305-
DeprecationWarning
306+
DeprecationWarning,
306307
)
307308

308309
@property
@@ -464,10 +465,10 @@ def __init__(
464465
raise ValueError("Inputs can't be None when transformer is given.")
465466
warnings.warn(
466467
(
467-
"We are deprecating the instantiation of TransformStep using \"transformer\"."
468-
"Instead, simply using \"run_args\"."
468+
'We are deprecating the instantiation of TransformStep using "transformer".'
469+
'Instead, simply using "run_args".'
469470
),
470-
DeprecationWarning
471+
DeprecationWarning,
471472
)
472473

473474
@property
@@ -493,7 +494,9 @@ def arguments(self) -> RequestType:
493494
model_client_config=self.inputs.model_client_config,
494495
experiment_config=dict(),
495496
)
496-
request_dict = self.transformer.sagemaker_session._get_transform_request(**transform_args)
497+
request_dict = self.transformer.sagemaker_session._get_transform_request(
498+
**transform_args
499+
)
497500

498501
request_dict.pop("TransformJobName", None)
499502
return request_dict
@@ -598,10 +601,10 @@ def __init__(
598601

599602
warnings.warn(
600603
(
601-
"We are deprecating the instantiation of ProcessingStep using \"processor\"."
602-
"Instead, simply using \"run_args\"."
604+
'We are deprecating the instantiation of ProcessingStep using "processor".'
605+
'Instead, simply using "run_args".'
603606
),
604-
DeprecationWarning
607+
DeprecationWarning,
605608
)
606609

607610
@property
@@ -738,10 +741,10 @@ def __init__(
738741
if not self.run_args:
739742
warnings.warn(
740743
(
741-
"We are deprecating the instantiation of TuningStep using \"tuner\"."
742-
"Instead, simply using \"run_args\"."
744+
'We are deprecating the instantiation of TuningStep using "tuner".'
745+
'Instead, simply using "run_args".'
743746
),
744-
DeprecationWarning
747+
DeprecationWarning,
745748
)
746749

747750
@property
@@ -808,5 +811,3 @@ def get_top_model_s3_uri(self, top_k: int, s3_bucket: str, prefix: str = "") ->
808811
"output/model.tar.gz",
809812
],
810813
)
811-
812-

src/sagemaker/workflow/utilities.py

-1
Original file line numberDiff line numberDiff line change
@@ -58,4 +58,3 @@ def hash_file(path: str) -> str:
5858
md5.update(data)
5959

6060
return md5.hexdigest()
61-

tests/unit/sagemaker/workflow/test_pipeline_session.py

-2
Original file line numberDiff line numberDiff line change
@@ -33,5 +33,3 @@ def test_pipeline_session_init(sagemaker_client_config, boto_session):
3333
assert sess.sagemaker_client is not None
3434
assert sess.default_bucket() is not None
3535
assert sess.context is None
36-
37-

0 commit comments

Comments
 (0)