Skip to content

Commit b31a0fe

Browse files
authored
Merge branch 'master' into feature/new_fg_utils
2 parents 7c02a32 + a94a3b1 commit b31a0fe

25 files changed

+1179
-108
lines changed

CHANGELOG.md

+16
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,21 @@
11
# Changelog
22

3+
## v2.130.0 (2023-01-26)
4+
5+
### Features
6+
7+
* Add PyTorch 1.13.1 to SDK
8+
* Adding image_uri config for DJL containers
9+
* Support specifying env-vars when creating model from model package
10+
* local download dir for Model and Estimator classes
11+
12+
### Bug Fixes and Other Changes
13+
14+
* increase creation time slack minutes
15+
* Enable load_run auto pass in experiment config
16+
* Add us-isob-east-1 accounts and configs
17+
* Clean up Pipeline unit tests
18+
319
## v2.129.0 (2023-01-19)
420

521
### Features

VERSION

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
2.129.1.dev0
1+
2.130.1.dev0

src/sagemaker/estimator.py

+70-4
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,8 @@ def __init__(
155155
entry_point: Optional[Union[str, PipelineVariable]] = None,
156156
dependencies: Optional[List[Union[str]]] = None,
157157
instance_groups: Optional[List[InstanceGroup]] = None,
158+
training_repository_access_mode: Optional[Union[str, PipelineVariable]] = None,
159+
training_repository_credentials_provider_arn: Optional[Union[str, PipelineVariable]] = None,
158160
**kwargs,
159161
):
160162
"""Initialize an ``EstimatorBase`` instance.
@@ -489,6 +491,18 @@ def __init__(
489491
`Train Using a Heterogeneous Cluster
490492
<https://docs.aws.amazon.com/sagemaker/latest/dg/train-heterogeneous-cluster.html>`_
491493
in the *Amazon SageMaker developer guide*.
494+
training_repository_access_mode (str): Optional. Specifies how SageMaker accesses the
495+
Docker image that contains the training algorithm (default: None).
496+
Set this to one of the following values:
497+
* 'Platform' - The training image is hosted in Amazon ECR.
498+
* 'Vpc' - The training image is hosted in a private Docker registry in your VPC.
499+
When it is left as the default None, the behavior is the same as 'Platform' - the image is hosted
500+
in ECR.
501+
training_repository_credentials_provider_arn (str): Optional. The Amazon Resource Name
502+
(ARN) of an AWS Lambda function that provides credentials to authenticate to the
503+
private Docker registry where your training image is hosted (default: None).
504+
When it is set to None, SageMaker does not authenticate before pulling the image
505+
in the private Docker registry.
492506
"""
493507
instance_count = renamed_kwargs(
494508
"train_instance_count", "instance_count", instance_count, kwargs
@@ -536,7 +550,9 @@ def __init__(
536550
self.dependencies = dependencies or []
537551
self.uploaded_code = None
538552
self.tags = add_jumpstart_tags(
539-
tags=tags, training_model_uri=self.model_uri, training_script_uri=self.source_dir
553+
tags=tags,
554+
training_model_uri=self.model_uri,
555+
training_script_uri=self.source_dir,
540556
)
541557
if self.instance_type in ("local", "local_gpu"):
542558
if self.instance_type == "local_gpu" and self.instance_count > 1:
@@ -571,6 +587,12 @@ def __init__(
571587
self.subnets = subnets
572588
self.security_group_ids = security_group_ids
573589

590+
# training image configs
591+
self.training_repository_access_mode = training_repository_access_mode
592+
self.training_repository_credentials_provider_arn = (
593+
training_repository_credentials_provider_arn
594+
)
595+
574596
self.encrypt_inter_container_traffic = encrypt_inter_container_traffic
575597
self.use_spot_instances = use_spot_instances
576598
self.max_wait = max_wait
@@ -651,7 +673,8 @@ def _ensure_base_job_name(self):
651673
self.base_job_name
652674
or get_jumpstart_base_name_if_jumpstart_model(self.source_dir, self.model_uri)
653675
or base_name_from_image(
654-
self.training_image_uri(), default_base_name=EstimatorBase.JOB_CLASS_NAME
676+
self.training_image_uri(),
677+
default_base_name=EstimatorBase.JOB_CLASS_NAME,
655678
)
656679
)
657680

@@ -1405,7 +1428,10 @@ def deploy(
14051428
self._ensure_base_job_name()
14061429

14071430
jumpstart_base_name = get_jumpstart_base_name_if_jumpstart_model(
1408-
kwargs.get("source_dir"), self.source_dir, kwargs.get("model_data"), self.model_uri
1431+
kwargs.get("source_dir"),
1432+
self.source_dir,
1433+
kwargs.get("model_data"),
1434+
self.model_uri,
14091435
)
14101436
default_name = (
14111437
name_from_base(jumpstart_base_name)
@@ -1638,6 +1664,15 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
16381664
init_params["algorithm_arn"] = job_details["AlgorithmSpecification"]["AlgorithmName"]
16391665
elif "TrainingImage" in job_details["AlgorithmSpecification"]:
16401666
init_params["image_uri"] = job_details["AlgorithmSpecification"]["TrainingImage"]
1667+
if "TrainingImageConfig" in job_details["AlgorithmSpecification"]:
1668+
init_params["training_repository_access_mode"] = job_details[
1669+
"AlgorithmSpecification"
1670+
]["TrainingImageConfig"].get("TrainingRepositoryAccessMode")
1671+
init_params["training_repository_credentials_provider_arn"] = (
1672+
job_details["AlgorithmSpecification"]["TrainingImageConfig"]
1673+
.get("TrainingRepositoryAuthConfig", {})
1674+
.get("TrainingRepositoryCredentialsProviderArn")
1675+
)
16411676
else:
16421677
raise RuntimeError(
16431678
"Invalid AlgorithmSpecification. Either TrainingImage or "
@@ -2118,6 +2153,17 @@ def _get_train_args(cls, estimator, inputs, experiment_config):
21182153
else:
21192154
train_args["retry_strategy"] = None
21202155

2156+
if estimator.training_repository_access_mode is not None:
2157+
training_image_config = {
2158+
"TrainingRepositoryAccessMode": estimator.training_repository_access_mode
2159+
}
2160+
if estimator.training_repository_credentials_provider_arn is not None:
2161+
training_image_config["TrainingRepositoryAuthConfig"] = {}
2162+
training_image_config["TrainingRepositoryAuthConfig"][
2163+
"TrainingRepositoryCredentialsProviderArn"
2164+
] = estimator.training_repository_credentials_provider_arn
2165+
train_args["training_image_config"] = training_image_config
2166+
21212167
# encrypt_inter_container_traffic may be a pipeline variable place holder object
21222168
# which is parsed in execution time
21232169
if estimator.encrypt_inter_container_traffic:
@@ -2182,7 +2228,11 @@ def _is_local_channel(cls, input_uri):
21822228

21832229
@classmethod
21842230
def update(
2185-
cls, estimator, profiler_rule_configs=None, profiler_config=None, resource_config=None
2231+
cls,
2232+
estimator,
2233+
profiler_rule_configs=None,
2234+
profiler_config=None,
2235+
resource_config=None,
21862236
):
21872237
"""Update a running Amazon SageMaker training job.
21882238
@@ -2321,6 +2371,8 @@ def __init__(
23212371
entry_point: Optional[Union[str, PipelineVariable]] = None,
23222372
dependencies: Optional[List[str]] = None,
23232373
instance_groups: Optional[List[InstanceGroup]] = None,
2374+
training_repository_access_mode: Optional[Union[str, PipelineVariable]] = None,
2375+
training_repository_credentials_provider_arn: Optional[Union[str, PipelineVariable]] = None,
23242376
**kwargs,
23252377
):
23262378
"""Initialize an ``Estimator`` instance.
@@ -2654,6 +2706,18 @@ def __init__(
26542706
`Train Using a Heterogeneous Cluster
26552707
<https://docs.aws.amazon.com/sagemaker/latest/dg/train-heterogeneous-cluster.html>`_
26562708
in the *Amazon SageMaker developer guide*.
2709+
training_repository_access_mode (str): Optional. Specifies how SageMaker accesses the
2710+
Docker image that contains the training algorithm (default: None).
2711+
Set this to one of the following values:
2712+
* 'Platform' - The training image is hosted in Amazon ECR.
2713+
* 'Vpc' - The training image is hosted in a private Docker registry in your VPC.
2714+
When it is left as the default None, the behavior is the same as 'Platform' - the image is hosted
2715+
in ECR.
2716+
training_repository_credentials_provider_arn (str): Optional. The Amazon Resource Name
2717+
(ARN) of an AWS Lambda function that provides credentials to authenticate to the
2718+
private Docker registry where your training image is hosted (default: None).
2719+
When it is set to None, SageMaker does not authenticate before pulling the image
2720+
in the private Docker registry.
26572721
"""
26582722
self.image_uri = image_uri
26592723
self._hyperparameters = hyperparameters.copy() if hyperparameters else {}
@@ -2698,6 +2762,8 @@ def __init__(
26982762
dependencies=dependencies,
26992763
hyperparameters=hyperparameters,
27002764
instance_groups=instance_groups,
2765+
training_repository_access_mode=training_repository_access_mode,
2766+
training_repository_credentials_provider_arn=training_repository_credentials_provider_arn, # noqa: E501 # pylint: disable=line-too-long
27012767
**kwargs,
27022768
)
27032769

src/sagemaker/feature_store/feature_group.py

+16-11
Original file line numberDiff line numberDiff line change
@@ -710,20 +710,25 @@ def ingest(
710710
) -> IngestionManagerPandas:
711711
"""Ingest the content of a pandas DataFrame to feature store.
712712
713-
``max_worker`` number of thread will be created to work on different partitions of
714-
the ``data_frame`` in parallel.
713+
``max_worker`` the number of threads created to work on different partitions of the
714+
``data_frame`` in parallel.
715715
716-
``max_processes`` number of processes will be created to work on different partitions
717-
of the ``data_frame`` in parallel, each with ``max_worker`` threads.
716+
``max_processes`` the number of processes that will be created to work on different
717+
partitions of the ``data_frame`` in parallel, each with ``max_worker`` threads.
718718
719-
The ingest function will attempt to ingest all records in the data frame. If ``wait``
720-
is True, then an exception is thrown after all records have been processed. If ``wait``
721-
is False, then a later call to the returned instance IngestionManagerPandas' ``wait()``
722-
function will throw an exception.
719+
The ingest function attempts to ingest all records in the data frame. SageMaker
720+
Feature Store throws an exception if it fails to ingest any records.
723721
724-
Zero based indices of rows that failed to be ingested can be found in the exception.
725-
They can also be found from the IngestionManagerPandas' ``failed_rows`` function after
726-
the exception is thrown.
722+
If ``wait`` is ``True``, Feature Store runs the ``ingest`` function synchronously.
723+
You receive an ``IngestionError`` if there are any records that can't be ingested.
724+
If ``wait`` is ``False``, Feature Store runs the ``ingest`` function asynchronously.
725+
726+
Instead of setting ``wait`` to ``True`` in the ``ingest`` function, you can invoke
727+
the ``wait`` function on the returned instance of ``IngestionManagerPandas`` to run
728+
the ``ingest`` function synchronously.
729+
730+
To access the rows that failed to ingest, set ``wait`` to ``False``. The
731+
``IngestionError.failed_rows`` object saves all of the rows that failed to ingest.
727732
728733
`profile_name` argument is an optional one. It will use the default credential if None is
729734
passed. This `profile_name` is used in the sagemaker_featurestore_runtime client only. See

src/sagemaker/feature_store/feature_store.py

+66
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,13 @@
2626
from sagemaker import Session
2727
from sagemaker.feature_store.dataset_builder import DatasetBuilder
2828
from sagemaker.feature_store.feature_group import FeatureGroup
29+
from sagemaker.feature_store.inputs import (
30+
Filter,
31+
ResourceEnum,
32+
SearchOperatorEnum,
33+
SortOrderEnum,
34+
Identifier,
35+
)
2936

3037

3138
@attr.s
@@ -114,6 +121,7 @@ def list_feature_groups(
114121
sort_by (str): The value on which the FeatureGroup list is sorted.
115122
max_results (int): The maximum number of results returned by ListFeatureGroups.
116123
next_token (str): A token to resume pagination of ListFeatureGroups results.
124+
117125
Returns:
118126
Response dict from service.
119127
"""
@@ -128,3 +136,61 @@ def list_feature_groups(
128136
max_results=max_results,
129137
next_token=next_token,
130138
)
139+
140+
def batch_get_record(self, identifiers: Sequence[Identifier]) -> Dict[str, Any]:
141+
"""Get records in batch from FeatureStore
142+
143+
Args:
144+
identifiers (Sequence[Identifier]): A list of identifiers to uniquely identify records
145+
in FeatureStore.
146+
147+
Returns:
148+
Response dict from service.
149+
"""
150+
batch_get_record_identifiers = [identifier.to_dict() for identifier in identifiers]
151+
return self.sagemaker_session.batch_get_record(identifiers=batch_get_record_identifiers)
152+
153+
def search(
154+
self,
155+
resource: ResourceEnum,
156+
filters: Sequence[Filter] = None,
157+
operator: SearchOperatorEnum = None,
158+
sort_by: str = None,
159+
sort_order: SortOrderEnum = None,
160+
next_token: str = None,
161+
max_results: int = None,
162+
) -> Dict[str, Any]:
163+
"""Search for FeatureGroups or FeatureMetadata satisfying given filters.
164+
165+
Args:
166+
resource (ResourceEnum): The name of the Amazon SageMaker resource to search for.
167+
Valid values are ``FeatureGroup`` or ``FeatureMetadata``.
168+
filters (Sequence[Filter]): A list of filter objects (Default: None).
169+
operator (SearchOperatorEnum): A Boolean operator used to evaluate the filters.
170+
Valid values are ``And`` or ``Or``. The default is ``And`` (Default: None).
171+
sort_by (str): The name of the resource property used to sort the ``SearchResults``.
172+
The default is ``LastModifiedTime``.
173+
sort_order (SortOrderEnum): How ``SearchResults`` are ordered.
174+
Valid values are ``Ascending`` or ``Descending``. The default is ``Descending``.
175+
next_token (str): If more than ``MaxResults`` resources match the specified
176+
filters, the response includes a ``NextToken``. The ``NextToken`` can be passed to
177+
the next ``SearchRequest`` to continue retrieving results (Default: None).
178+
max_results (int): The maximum number of results to return (Default: None).
179+
180+
Returns:
181+
Response dict from service.
182+
"""
183+
search_expression = {}
184+
if filters:
185+
search_expression["Filters"] = [filter.to_dict() for filter in filters]
186+
if operator:
187+
search_expression["Operator"] = str(operator)
188+
189+
return self.sagemaker_session.search(
190+
resource=str(resource),
191+
search_expression=search_expression,
192+
sort_by=sort_by,
193+
sort_order=None if not sort_order else str(sort_order),
194+
next_token=next_token,
195+
max_results=max_results,
196+
)

0 commit comments

Comments
 (0)