Skip to content

Commit bf06963

Browse files
authored
Merge branch 'master' into feat/jumpstart-model-table-update
2 parents d830446 + dcec9bb commit bf06963

21 files changed

+1479
-141
lines changed

CHANGELOG.md

+21
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,26 @@
11
# Changelog
22

3+
## v2.88.1 (2022-04-27)
4+
5+
### Bug Fixes and Other Changes
6+
7+
* Add encryption setting to tar_and_upload_dir method
8+
9+
## v2.88.0 (2022-04-26)
10+
11+
### Features
12+
13+
* jumpstart notebook utils -- list model ids, scripts, tasks, frameworks
14+
15+
### Bug Fixes and Other Changes
16+
17+
* local mode printing of credentials during docker login closes #2180
18+
* disable endpoint context test
19+
20+
### Documentation Changes
21+
22+
* sm model parallel 1.8.0 release notes
23+
324
## v2.87.0 (2022-04-20)
425

526
### Features

VERSION

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
2.87.1.dev0
1+
2.88.2.dev0

src/sagemaker/amazon/amazon_estimator.py

+2
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from sagemaker.estimator import EstimatorBase, _TrainingJob
2828
from sagemaker.inputs import FileSystemInput, TrainingInput
2929
from sagemaker.utils import sagemaker_timestamp
30+
from sagemaker.workflow.pipeline_context import runnable_by_pipeline
3031

3132
logger = logging.getLogger(__name__)
3233

@@ -192,6 +193,7 @@ def _prepare_for_training(self, records, mini_batch_size=None, job_name=None):
192193
self.feature_dim = feature_dim
193194
self.mini_batch_size = mini_batch_size
194195

196+
@runnable_by_pipeline
195197
def fit(
196198
self,
197199
records,

src/sagemaker/clarify.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -803,7 +803,8 @@ def _run(
803803
output_name="analysis_result",
804804
s3_upload_mode="EndOfJob",
805805
)
806-
super().run(
806+
807+
return super().run(
807808
inputs=[data_input, config_input],
808809
outputs=[result_output],
809810
wait=wait,
@@ -871,7 +872,7 @@ def run_pre_training_bias(
871872
job_name = utils.name_from_base(self.job_name_prefix)
872873
else:
873874
job_name = utils.name_from_base("Clarify-Pretraining-Bias")
874-
self._run(
875+
return self._run(
875876
data_config,
876877
analysis_config,
877878
wait,
@@ -957,7 +958,7 @@ def run_post_training_bias(
957958
job_name = utils.name_from_base(self.job_name_prefix)
958959
else:
959960
job_name = utils.name_from_base("Clarify-Posttraining-Bias")
960-
self._run(
961+
return self._run(
961962
data_config,
962963
analysis_config,
963964
wait,
@@ -1060,7 +1061,7 @@ def run_bias(
10601061
job_name = utils.name_from_base(self.job_name_prefix)
10611062
else:
10621063
job_name = utils.name_from_base("Clarify-Bias")
1063-
self._run(
1064+
return self._run(
10641065
data_config,
10651066
analysis_config,
10661067
wait,
@@ -1167,7 +1168,7 @@ def run_explainability(
11671168
job_name = utils.name_from_base(self.job_name_prefix)
11681169
else:
11691170
job_name = utils.name_from_base("Clarify-Explainability")
1170-
self._run(
1171+
return self._run(
11711172
data_config,
11721173
analysis_config,
11731174
wait,

src/sagemaker/estimator.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@
1616
import json
1717
import logging
1818
import os
19-
from typing import Any, Dict
2019
import uuid
2120
from abc import ABCMeta, abstractmethod
21+
from typing import Any, Dict
2222

2323
from six import string_types, with_metaclass
2424
from six.moves.urllib.parse import urlparse
@@ -75,6 +75,10 @@
7575
name_from_base,
7676
)
7777
from sagemaker.workflow import is_pipeline_variable
78+
from sagemaker.workflow.pipeline_context import (
79+
PipelineSession,
80+
runnable_by_pipeline,
81+
)
7882

7983
logger = logging.getLogger(__name__)
8084

@@ -721,6 +725,7 @@ def _stage_user_code_in_s3(self) -> str:
721725
dependencies=self.dependencies,
722726
kms_key=kms_key,
723727
s3_resource=self.sagemaker_session.s3_resource,
728+
settings=self.sagemaker_session.settings,
724729
)
725730

726731
def _prepare_rules(self):
@@ -896,6 +901,7 @@ def latest_job_profiler_artifacts_path(self):
896901
)
897902
return None
898903

904+
@runnable_by_pipeline
899905
def fit(self, inputs=None, wait=True, logs="All", job_name=None, experiment_config=None):
900906
"""Train a model using the input training dataset.
901907
@@ -1341,7 +1347,9 @@ def register(
13411347
@property
13421348
def model_data(self):
13431349
"""str: The model location in S3. Only set if Estimator has been ``fit()``."""
1344-
if self.latest_training_job is not None:
1350+
if self.latest_training_job is not None and not isinstance(
1351+
self.sagemaker_session, PipelineSession
1352+
):
13451353
model_uri = self.sagemaker_session.sagemaker_client.describe_training_job(
13461354
TrainingJobName=self.latest_training_job.name
13471355
)["ModelArtifacts"]["S3ModelArtifacts"]
@@ -1767,6 +1775,7 @@ def start_new(cls, estimator, inputs, experiment_config):
17671775
all information about the started training job.
17681776
"""
17691777
train_args = cls._get_train_args(estimator, inputs, experiment_config)
1778+
17701779
estimator.sagemaker_session.train(**train_args)
17711780

17721781
return cls(estimator.sagemaker_session, estimator._current_job_name)

src/sagemaker/model.py

+1
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,7 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None:
447447
script=self.entry_point,
448448
directory=self.source_dir,
449449
dependencies=self.dependencies,
450+
settings=self.sagemaker_session.settings,
450451
)
451452

452453
if repack and self.model_data is not None and self.entry_point is not None:

src/sagemaker/processing.py

+7-8
Original file line numberDiff line numberDiff line change
@@ -28,16 +28,13 @@
2828

2929
from six.moves.urllib.parse import urlparse
3030
from six.moves.urllib.request import url2pathname
31-
3231
from sagemaker import s3
3332
from sagemaker.job import _Job
3433
from sagemaker.local import LocalSession
3534
from sagemaker.utils import base_name_from_image, get_config_value, name_from_base
3635
from sagemaker.session import Session
3736
from sagemaker.workflow import is_pipeline_variable
38-
from sagemaker.workflow.properties import Properties
39-
from sagemaker.workflow.parameters import Parameter
40-
from sagemaker.workflow.entities import Expression
37+
from sagemaker.workflow.pipeline_context import runnable_by_pipeline
4138
from sagemaker.dataset_definition.inputs import S3Input, DatasetDefinition
4239
from sagemaker.apiutils._base_types import ApiObject
4340
from sagemaker.s3 import S3Uploader
@@ -133,6 +130,7 @@ def __init__(
133130

134131
self.sagemaker_session = sagemaker_session or Session()
135132

133+
@runnable_by_pipeline
136134
def run(
137135
self,
138136
inputs=None,
@@ -314,10 +312,10 @@ def _normalize_inputs(self, inputs=None, kms_key=None):
314312
if file_input.input_name is None:
315313
file_input.input_name = "input-{}".format(count)
316314

317-
if isinstance(file_input.source, Properties) or file_input.dataset_definition:
315+
if is_pipeline_variable(file_input.source) or file_input.dataset_definition:
318316
normalized_inputs.append(file_input)
319317
continue
320-
if isinstance(file_input.s3_input.s3_uri, (Parameter, Expression, Properties)):
318+
if is_pipeline_variable(file_input.s3_input.s3_uri):
321319
normalized_inputs.append(file_input)
322320
continue
323321
# If the source is a local path, upload it to S3
@@ -367,7 +365,7 @@ def _normalize_outputs(self, outputs=None):
367365
# Generate a name for the ProcessingOutput if it doesn't have one.
368366
if output.output_name is None:
369367
output.output_name = "output-{}".format(count)
370-
if isinstance(output.destination, (Parameter, Expression, Properties)):
368+
if is_pipeline_variable(output.destination):
371369
normalized_outputs.append(output)
372370
continue
373371
# If the output's destination is not an s3_uri, create one.
@@ -497,6 +495,7 @@ def get_run_args(
497495
"""
498496
return RunArgs(code=code, inputs=inputs, outputs=outputs, arguments=arguments)
499497

498+
@runnable_by_pipeline
500499
def run(
501500
self,
502501
code,
@@ -1600,7 +1599,7 @@ def run( # type: ignore[override]
16001599
)
16011600

16021601
# Submit a processing job.
1603-
super().run(
1602+
return super().run(
16041603
code=s3_runproc_sh,
16051604
inputs=inputs,
16061605
outputs=outputs,

src/sagemaker/session.py

+39-13
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import re
2020
import sys
2121
import time
22+
import typing
2223
import warnings
2324
from typing import List, Dict, Any, Sequence
2425

@@ -551,7 +552,6 @@ def train( # noqa: C901
551552
retry_strategy(dict): Defines RetryStrategy for InternalServerFailures.
552553
* max_retry_attempts (int): Number of times a job should be retried.
553554
The key in RetryStrategy is 'MaxRetryAttempts'.
554-
555555
Returns:
556556
str: ARN of the training job, if it is created.
557557
"""
@@ -585,9 +585,13 @@ def train( # noqa: C901
585585
environment=environment,
586586
retry_strategy=retry_strategy,
587587
)
588-
LOGGER.info("Creating training-job with name: %s", job_name)
589-
LOGGER.debug("train request: %s", json.dumps(train_request, indent=4))
590-
self.sagemaker_client.create_training_job(**train_request)
588+
589+
def submit(request):
590+
LOGGER.info("Creating training-job with name: %s", job_name)
591+
LOGGER.debug("train request: %s", json.dumps(request, indent=4))
592+
self.sagemaker_client.create_training_job(**request)
593+
594+
self._intercept_create_request(train_request, submit)
591595

592596
def _get_train_request( # noqa: C901
593597
self,
@@ -912,9 +916,13 @@ def process(
912916
tags=tags,
913917
experiment_config=experiment_config,
914918
)
915-
LOGGER.info("Creating processing-job with name %s", job_name)
916-
LOGGER.debug("process request: %s", json.dumps(process_request, indent=4))
917-
self.sagemaker_client.create_processing_job(**process_request)
919+
920+
def submit(request):
921+
LOGGER.info("Creating processing-job with name %s", job_name)
922+
LOGGER.debug("process request: %s", json.dumps(request, indent=4))
923+
self.sagemaker_client.create_processing_job(**request)
924+
925+
self._intercept_create_request(process_request, submit)
918926

919927
def _get_process_request(
920928
self,
@@ -2086,9 +2094,12 @@ def create_tuning_job(
20862094
tags=tags,
20872095
)
20882096

2089-
LOGGER.info("Creating hyperparameter tuning job with name: %s", job_name)
2090-
LOGGER.debug("tune request: %s", json.dumps(tune_request, indent=4))
2091-
self.sagemaker_client.create_hyper_parameter_tuning_job(**tune_request)
2097+
def submit(request):
2098+
LOGGER.info("Creating hyperparameter tuning job with name: %s", job_name)
2099+
LOGGER.debug("tune request: %s", json.dumps(request, indent=4))
2100+
self.sagemaker_client.create_hyper_parameter_tuning_job(**request)
2101+
2102+
self._intercept_create_request(tune_request, submit)
20922103

20932104
def _get_tuning_request(
20942105
self,
@@ -2553,9 +2564,12 @@ def transform(
25532564
model_client_config=model_client_config,
25542565
)
25552566

2556-
LOGGER.info("Creating transform job with name: %s", job_name)
2557-
LOGGER.debug("Transform request: %s", json.dumps(transform_request, indent=4))
2558-
self.sagemaker_client.create_transform_job(**transform_request)
2567+
def submit(request):
2568+
LOGGER.info("Creating transform job with name: %s", job_name)
2569+
LOGGER.debug("Transform request: %s", json.dumps(request, indent=4))
2570+
self.sagemaker_client.create_transform_job(**request)
2571+
2572+
self._intercept_create_request(transform_request, submit)
25592573

25602574
def _create_model_request(
25612575
self,
@@ -4161,6 +4175,18 @@ def account_id(self) -> str:
41614175
)
41624176
return sts_client.get_caller_identity()["Account"]
41634177

4178+
def _intercept_create_request(self, request: typing.Dict, create):
4179+
"""This function intercepts the create job request.
4180+
4181+
PipelineSession inherits this Session class and will override
4182+
this function to intercept the create request.
4183+
4184+
Args:
4185+
request (dict): the create job request
4186+
create (functor): a functor that calls the sagemaker client create method
4187+
"""
4188+
create(request)
4189+
41644190

41654191
def get_model_package_args(
41664192
content_types,

src/sagemaker/spark/processing.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ def run(
249249
"""
250250
self._current_job_name = self._generate_current_job_name(job_name=job_name)
251251

252-
super().run(
252+
return super().run(
253253
submit_app,
254254
inputs,
255255
outputs,
@@ -868,7 +868,7 @@ def run(
868868
spark_event_logs_s3_uri=spark_event_logs_s3_uri,
869869
)
870870

871-
super().run(
871+
return super().run(
872872
submit_app=submit_app,
873873
inputs=extended_inputs,
874874
outputs=extended_outputs,
@@ -1125,7 +1125,7 @@ def run(
11251125
spark_event_logs_s3_uri=spark_event_logs_s3_uri,
11261126
)
11271127

1128-
super().run(
1128+
return super().run(
11291129
submit_app=submit_app,
11301130
inputs=extended_inputs,
11311131
outputs=extended_outputs,

src/sagemaker/transformer.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
from sagemaker.job import _Job
1919
from sagemaker.session import Session
20+
from sagemaker.workflow.pipeline_context import runnable_by_pipeline
21+
from sagemaker.workflow import is_pipeline_variable
2022
from sagemaker.utils import base_name_from_image, name_from_base
2123

2224

@@ -106,6 +108,7 @@ def __init__(
106108

107109
self.sagemaker_session = sagemaker_session or Session()
108110

111+
@runnable_by_pipeline
109112
def transform(
110113
self,
111114
data,
@@ -197,7 +200,11 @@ def transform(
197200
base_name = self.base_transform_job_name
198201

199202
if base_name is None:
200-
base_name = self._retrieve_base_name()
203+
base_name = (
204+
"transform-job"
205+
if is_pipeline_variable(self.model_name)
206+
else self._retrieve_base_name()
207+
)
201208

202209
self._current_job_name = name_from_base(base_name)
203210

@@ -370,6 +377,7 @@ def start_new(
370377
experiment_config,
371378
model_client_config,
372379
)
380+
373381
transformer.sagemaker_session.transform(**transform_args)
374382

375383
return cls(transformer.sagemaker_session, transformer._current_job_name)

0 commit comments

Comments
 (0)