Skip to content

Commit 817d1a1

Browse files
authored
Merge branch 'master' into jobNamePrefix
2 parents dd646ea + 530d21b commit 817d1a1

15 files changed

+171
-30
lines changed

CHANGELOG.md

+13
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,18 @@
11
# Changelog
22

3+
## v2.46.1 (2021-06-22)
4+
5+
### Bug Fixes and Other Changes
6+
7+
* Register model step tags
8+
9+
### Documentation Changes
10+
11+
* update to include new batch_get_record api call
12+
* Correct type annotation for TrainingStep inputs
13+
* introduce input mode FastFile
14+
* update hf transformer version
15+
316
## v2.46.0 (2021-06-15)
417

518
### Features

VERSION

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
2.46.1.dev0
1+
2.46.2.dev0

doc/amazon_sagemaker_featurestore.rst

+7
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,13 @@ example identifier to retrieve the record.
291291
record_identifier_value = str(2990130)
292292
featurestore_runtime.get_record(FeatureGroupName=transaction_feature_group_name, RecordIdentifierValueAsString=record_identifier_value)
293293
294+
You can use the ``batch_get_record`` function to retrieve multiple records simultaneously from your feature store. The following example uses this API to retrieve a batch of records.
295+
296+
.. code:: python
297+
298+
record_identifier_values = ["573291", "109382", "828400", "124013"]
299+
featurestore_runtime.batch_get_record(Identifiers=[{"FeatureGroupName": transaction_feature_group_name, "RecordIdentifiersValueAsString": record_identifier_values}])
300+
294301
An example response from the fraud detection example:
295302
296303
.. code:: python

src/sagemaker/estimator.py

+2
Original file line numberDiff line numberDiff line change
@@ -1007,6 +1007,8 @@ def register(
10071007
if compile_model_family is not None:
10081008
model = self._compiled_models[compile_model_family]
10091009
else:
1010+
if "model_kms_key" not in kwargs:
1011+
kwargs["model_kms_key"] = self.output_kms_key
10101012
model = self.create_model(image_uri=image_uri, **kwargs)
10111013
model.name = model_name
10121014
return model.register(

src/sagemaker/inputs.py

+2
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,8 @@ def __init__(
7070
a local directory.
7171
* 'Pipe' - Amazon SageMaker streams data directly from S3 to the container via
7272
a Unix-named pipe.
73+
* 'FastFile' - Amazon SageMaker streams data from S3 on demand instead of
74+
downloading the entire dataset before training begins.
7375
7476
attribute_names (list[str]): A list of one or more attribute names to use that are
7577
found in a specified AugmentedManifestFile.

src/sagemaker/model.py

+3
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,7 @@ def _get_model_package_args(
195195
marketplace_cert=False,
196196
approval_status=None,
197197
description=None,
198+
tags=None,
198199
):
199200
"""Get arguments for session.create_model_package method.
200201
@@ -250,6 +251,8 @@ def _get_model_package_args(
250251
model_package_args["approval_status"] = approval_status
251252
if description is not None:
252253
model_package_args["description"] = description
254+
if tags is not None:
255+
model_package_args["tags"] = tags
253256
return model_package_args
254257

255258
def _init_sagemaker_session_if_does_not_exist(self, instance_type):

src/sagemaker/session.py

+11
Original file line numberDiff line numberDiff line change
@@ -467,6 +467,8 @@ def train( # noqa: C901
467467
a directory in the Docker container.
468468
* 'Pipe' - Amazon SageMaker streams data directly from S3 to the container via a
469469
Unix-named pipe.
470+
* 'FastFile' - Amazon SageMaker streams data from S3 on demand instead of
471+
downloading the entire dataset before training begins.
470472
input_config (list): A list of Channel objects. Each channel is a named input source.
471473
Please refer to the format details described:
472474
https://botocore.readthedocs.io/en/latest/reference/services/sagemaker.html#SageMaker.Client.create_training_job
@@ -609,6 +611,8 @@ def _get_train_request( # noqa: C901
609611
a directory in the Docker container.
610612
* 'Pipe' - Amazon SageMaker streams data directly from S3 to the container via a
611613
Unix-named pipe.
614+
* 'FastFile' - Amazon SageMaker streams data from S3 on demand instead of
615+
downloading the entire dataset before training begins.
612616
input_config (list): A list of Channel objects. Each channel is a named input source.
613617
Please refer to the format details described:
614618
https://botocore.readthedocs.io/en/latest/reference/services/sagemaker.html#SageMaker.Client.create_training_job
@@ -1897,6 +1901,8 @@ def tune( # noqa: C901
18971901
a directory in the Docker container.
18981902
* 'Pipe' - Amazon SageMaker streams data directly from S3 to the container via a
18991903
Unix-named pipe.
1904+
* 'FastFile' - Amazon SageMaker streams data from S3 on demand instead of
1905+
downloading the entire dataset before training begins.
19001906
metric_definitions (list[dict]): A list of dictionaries that defines the metric(s)
19011907
used to evaluate the training jobs. Each dictionary contains two keys: 'Name' for
19021908
the name of the metric, and 'Regex' for the regular expression used to extract the
@@ -2180,6 +2186,8 @@ def _map_training_config(
21802186
a directory in the Docker container.
21812187
* 'Pipe' - Amazon SageMaker streams data directly from S3 to the container via a
21822188
Unix-named pipe.
2189+
* 'FastFile' - Amazon SageMaker streams data from S3 on demand instead of
2190+
downloading the entire dataset before training begins.
21832191
role (str): An AWS IAM role (either name or full ARN). The Amazon SageMaker training
21842192
jobs and APIs that create Amazon SageMaker endpoints use this role to access
21852193
training data and model artifacts. You must grant sufficient permissions to
@@ -2716,6 +2724,7 @@ def _get_create_model_package_request(
27162724
marketplace_cert=False,
27172725
approval_status="PendingManualApproval",
27182726
description=None,
2727+
tags=None,
27192728
):
27202729
"""Get request dictionary for CreateModelPackage API.
27212730
@@ -2753,6 +2762,8 @@ def _get_create_model_package_request(
27532762
request_dict["ModelPackageGroupName"] = model_package_group_name
27542763
if description is not None:
27552764
request_dict["ModelPackageDescription"] = description
2765+
if tags is not None:
2766+
request_dict["Tags"] = tags
27562767
if model_metrics:
27572768
request_dict["ModelMetrics"] = model_metrics
27582769
if metadata_properties:

src/sagemaker/workflow/_utils.py

+6
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def __init__(
6060
source_dir: str = None,
6161
dependencies: List = None,
6262
depends_on: List[str] = None,
63+
**kwargs,
6364
):
6465
"""Constructs a TrainingStep, given an `EstimatorBase` instance.
6566
@@ -98,6 +99,7 @@ def __init__(
9899
"inference_script": self._entry_point_basename,
99100
"model_archive": self._model_archive,
100101
},
102+
**kwargs,
101103
)
102104
repacker.disable_profiler = True
103105
inputs = TrainingInput(self._model_prefix)
@@ -225,6 +227,7 @@ def __init__(
225227
compile_model_family=None,
226228
description=None,
227229
depends_on: List[str] = None,
230+
tags=None,
228231
**kwargs,
229232
):
230233
"""Constructor of a register model step.
@@ -264,6 +267,7 @@ def __init__(
264267
self.inference_instances = inference_instances
265268
self.transform_instances = transform_instances
266269
self.model_package_group_name = model_package_group_name
270+
self.tags = tags
267271
self.model_metrics = model_metrics
268272
self.metadata_properties = metadata_properties
269273
self.approval_status = approval_status
@@ -324,10 +328,12 @@ def arguments(self) -> RequestType:
324328
metadata_properties=self.metadata_properties,
325329
approval_status=self.approval_status,
326330
description=self.description,
331+
tags=self.tags,
327332
)
328333
request_dict = model.sagemaker_session._get_create_model_package_request(
329334
**model_package_args
330335
)
336+
331337
# these are not available in the workflow service and will cause rejection
332338
if "CertifyForMarketplace" in request_dict:
333339
request_dict.pop("CertifyForMarketplace")

src/sagemaker/workflow/callback_step.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -58,19 +58,20 @@ def to_request(self) -> RequestType:
5858
"OutputType": self.output_type.value,
5959
}
6060

61-
@property
62-
def expr(self) -> Dict[str, str]:
63-
"""The 'Get' expression dict for a `Parameter`."""
64-
return CallbackOutput._expr(self.output_name)
61+
def expr(self, step_name) -> Dict[str, str]:
62+
"""The 'Get' expression dict for a `CallbackOutput`."""
63+
return CallbackOutput._expr(self.output_name, step_name)
6564

6665
@classmethod
67-
def _expr(cls, name):
66+
def _expr(cls, name, step_name):
6867
"""An internal classmethod for the 'Get' expression dict for a `CallbackOutput`.
6968
7069
Args:
7170
name (str): The name of the callback output.
71+
step_name (str): The name of the step the callback step associated
72+
with this output belongs to.
7273
"""
73-
return {"Get": f"Steps.{name}.OutputParameters['{name}']"}
74+
return {"Get": f"Steps.{step_name}.OutputParameters['{name}']"}
7475

7576

7677
class CallbackStep(Step):

src/sagemaker/workflow/pipeline.py

+36-9
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
from sagemaker._studio import _append_project_tags
2626
from sagemaker.session import Session
27-
from sagemaker.workflow.callback_step import CallbackOutput
27+
from sagemaker.workflow.callback_step import CallbackOutput, CallbackStep
2828
from sagemaker.workflow.entities import (
2929
Entity,
3030
Expression,
@@ -240,9 +240,12 @@ def definition(self) -> str:
240240
"""Converts a request structure to string representation for workflow service calls."""
241241
request_dict = self.to_request()
242242
request_dict["PipelineExperimentConfig"] = interpolate(
243-
request_dict["PipelineExperimentConfig"]
243+
request_dict["PipelineExperimentConfig"], {}
244+
)
245+
callback_output_to_step_map = _map_callback_outputs(self.steps)
246+
request_dict["Steps"] = interpolate(
247+
request_dict["Steps"], callback_output_to_step_map=callback_output_to_step_map
244248
)
245-
request_dict["Steps"] = interpolate(request_dict["Steps"])
246249

247250
return json.dumps(request_dict)
248251

@@ -263,38 +266,62 @@ def format_start_parameters(parameters: Dict[str, Any]) -> List[Dict[str, Any]]:
263266
return [{"Name": name, "Value": str(value)} for name, value in parameters.items()]
264267

265268

266-
def interpolate(request_obj: RequestType) -> RequestType:
269+
def interpolate(
270+
request_obj: RequestType, callback_output_to_step_map: Dict[str, str]
271+
) -> RequestType:
267272
"""Replaces Parameter values in a list of nested Dict[str, Any] with their workflow expression.
268273
269274
Args:
270275
request_obj (RequestType): The request dict.
276+
callback_output_to_step_map (Dict[str, str]): A dict of output name -> step name.
271277
272278
Returns:
273279
RequestType: The request dict with Parameter values replaced by their expression.
274280
"""
275281
request_obj_copy = deepcopy(request_obj)
276-
return _interpolate(request_obj_copy)
282+
return _interpolate(request_obj_copy, callback_output_to_step_map=callback_output_to_step_map)
277283

278284

279-
def _interpolate(obj: Union[RequestType, Any]):
285+
def _interpolate(obj: Union[RequestType, Any], callback_output_to_step_map: Dict[str, str]):
280286
"""Walks the nested request dict, replacing Parameter type values with workflow expressions.
281287
282288
Args:
283289
obj (Union[RequestType, Any]): The request dict.
290+
callback_output_to_step_map (Dict[str, str]): A dict of output name -> step name.
284291
"""
285-
if isinstance(obj, (Expression, Parameter, Properties, CallbackOutput)):
292+
if isinstance(obj, (Expression, Parameter, Properties)):
286293
return obj.expr
294+
if isinstance(obj, CallbackOutput):
295+
step_name = callback_output_to_step_map[obj.output_name]
296+
return obj.expr(step_name)
287297
if isinstance(obj, dict):
288298
new = obj.__class__()
289299
for key, value in obj.items():
290-
new[key] = interpolate(value)
300+
new[key] = interpolate(value, callback_output_to_step_map)
291301
elif isinstance(obj, (list, set, tuple)):
292-
new = obj.__class__(interpolate(value) for value in obj)
302+
new = obj.__class__(interpolate(value, callback_output_to_step_map) for value in obj)
293303
else:
294304
return obj
295305
return new
296306

297307

308+
def _map_callback_outputs(steps: List[Step]):
309+
"""Iterate over the provided steps, building a map of callback output parameters to step names.
310+
311+
Args:
312+
step (List[Step]): The steps list.
313+
"""
314+
315+
callback_output_map = {}
316+
for step in steps:
317+
if isinstance(step, CallbackStep):
318+
if step.outputs:
319+
for output in step.outputs:
320+
callback_output_map[output.output_name] = step.name
321+
322+
return callback_output_map
323+
324+
298325
def update_args(args: Dict[str, Any], **kwargs):
299326
"""Updates the request arguments dict with a value, if populated.
300327

src/sagemaker/workflow/step_collections.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ def __init__(
6767
image_uri=None,
6868
compile_model_family=None,
6969
description=None,
70+
tags=None,
7071
**kwargs,
7172
):
7273
"""Construct steps `_RepackModelStep` and `_RegisterModelStep` based on the estimator.
@@ -94,15 +95,21 @@ def __init__(
9495
compile_model_family (str): The instance family for the compiled model. If
9596
specified, a compiled model is used (default: None).
9697
description (str): Model Package description (default: None).
98+
tags (List[dict[str, str]]): The list of tags to attach to the model package group. Note
99+
that tags will only be applied to newly created model package groups; if the
100+
name of an existing group is passed to "model_package_group_name",
101+
tags will not be applied.
97102
**kwargs: additional arguments to `create_model`.
98103
"""
99104
steps: List[Step] = []
100105
repack_model = False
101106
if "entry_point" in kwargs:
102107
repack_model = True
103-
entry_point = kwargs["entry_point"]
108+
entry_point = kwargs.pop("entry_point", None)
104109
source_dir = kwargs.get("source_dir")
105110
dependencies = kwargs.get("dependencies")
111+
kwargs = dict(**kwargs, output_kms_key=kwargs.pop("model_kms_key", None))
112+
106113
repack_model_step = _RepackModelStep(
107114
name=f"{name}RepackModel",
108115
depends_on=depends_on,
@@ -111,6 +118,7 @@ def __init__(
111118
entry_point=entry_point,
112119
source_dir=source_dir,
113120
dependencies=dependencies,
121+
**kwargs,
114122
)
115123
steps.append(repack_model_step)
116124
model_data = repack_model_step.properties.ModelArtifacts.S3ModelArtifacts
@@ -119,6 +127,7 @@ def __init__(
119127
kwargs.pop("entry_point", None)
120128
kwargs.pop("source_dir", None)
121129
kwargs.pop("dependencies", None)
130+
kwargs.pop("output_kms_key", None)
122131

123132
register_model_step = _RegisterModelStep(
124133
name=name,
@@ -134,6 +143,7 @@ def __init__(
134143
image_uri=image_uri,
135144
compile_model_family=compile_model_family,
136145
description=description,
146+
tags=tags,
137147
**kwargs,
138148
)
139149
if not repack_model:

src/sagemaker/workflow/steps.py

+20-8
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,12 @@
1616
import abc
1717

1818
from enum import Enum
19-
from typing import Dict, List
19+
from typing import Dict, List, Union
2020

2121
import attr
2222

2323
from sagemaker.estimator import EstimatorBase, _TrainingJob
24-
from sagemaker.inputs import (
25-
CreateModelInput,
26-
TrainingInput,
27-
TransformInput,
28-
)
24+
from sagemaker.inputs import CreateModelInput, TrainingInput, TransformInput, FileSystemInput
2925
from sagemaker.model import Model
3026
from sagemaker.processing import (
3127
ProcessingInput,
@@ -145,7 +141,7 @@ def __init__(
145141
self,
146142
name: str,
147143
estimator: EstimatorBase,
148-
inputs: TrainingInput = None,
144+
inputs: Union[TrainingInput, dict, str, FileSystemInput] = None,
149145
cache_config: CacheConfig = None,
150146
depends_on: List[str] = None,
151147
):
@@ -157,7 +153,23 @@ def __init__(
157153
Args:
158154
name (str): The name of the training step.
159155
estimator (EstimatorBase): A `sagemaker.estimator.EstimatorBase` instance.
160-
inputs (TrainingInput): A `sagemaker.inputs.TrainingInput` instance. Defaults to `None`.
156+
inputs (str or dict or sagemaker.inputs.TrainingInput
157+
or sagemaker.inputs.FileSystemInput): Information
158+
about the training data. This can be one of three types:
159+
160+
* (str) the S3 location where training data is saved, or a file:// path in
161+
local mode.
162+
* (dict[str, str] or dict[str, sagemaker.inputs.TrainingInput]) If using multiple
163+
channels for training data, you can specify a dict mapping channel names to
164+
strings or :func:`~sagemaker.inputs.TrainingInput` objects.
165+
* (sagemaker.inputs.TrainingInput) - channel configuration for S3 data sources
166+
that can provide additional information as well as the path to the training
167+
dataset.
168+
See :func:`sagemaker.inputs.TrainingInput` for full details.
169+
* (sagemaker.inputs.FileSystemInput) - channel configuration for
170+
a file system data source that can provide additional information as well as
171+
the path to the training dataset.
172+
161173
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
162174
depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.TrainingStep`
163175
depends on

0 commit comments

Comments
 (0)